This chapter implements a complete RNA-seq differential expression analysis pipeline in Python. We cover the full journey from a raw count matrix to a biologically interpreted results table: quality control, normalization, differential expression testing, visualization, and pathway enrichment.
The pipeline uses:
- pydeseq2: Python reimplementation of DESeq2
- GSEApy: Python interface to gene set enrichment analysis
- pandas / numpy: data manipulation
- matplotlib / seaborn: visualization
We work with a simulated dataset that mirrors a real RNA-seq experiment — a cancer cell line treated with a drug versus DMSO control, 3 replicates each.
Setup
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import seaborn as sns
from scipy import stats
from pydeseq2.dds import DeseqDataSet
from pydeseq2.default_inference import DefaultInference
from pydeseq2.ds import DeseqStats
import gseapy as gp
import warnings
# Silence library warnings for a clean tutorial transcript.
# NOTE(review): a blanket filter also hides genuinely useful warnings;
# consider scoping to specific categories in production code.
warnings.filterwarnings('ignore')
# Global legacy-API seed (the simulation below uses its own default_rng).
np.random.seed(42)
Step 1: Simulate a Realistic Count Matrix
In practice you'd load a counts matrix from featureCounts, STAR/RSEM, or Salmon. Here we simulate one with realistic properties: overdispersion, differential expression in ~10% of genes, and realistic count magnitudes.
def simulate_rnaseq_counts(n_genes=10000, n_control=3, n_treated=3,
                           n_de_genes=1000, seed=42):
    """Simulate RNA-seq count data with realistic properties.

    Counts are drawn from a negative binomial with mean-dependent
    dispersion (higher for low-count genes, as in real RNA-seq),
    per-sample library-size factors, and a subset of genes carrying
    true log2 fold changes in the treated condition.

    Parameters
    ----------
    n_genes : int
        Total number of simulated genes.
    n_control, n_treated : int
        Number of replicates per condition.
    n_de_genes : int
        Number of truly differentially expressed genes (half up, half down).
    seed : int
        Seed for the random generator (reproducibility).

    Returns
    -------
    (counts_df, metadata_df, true_de) :
        genes x samples count DataFrame, samples x covariates metadata
        DataFrame ('condition', 'batch'), and the set of true DE gene names.
    """
    rng = np.random.default_rng(seed)
    n_samples = n_control + n_treated

    # Gene base expression levels: log-normal — most genes lowly
    # expressed, a few very highly expressed.
    base_means = rng.lognormal(mean=3.0, sigma=2.0, size=n_genes)
    # Dispersion inversely related to mean (as in real RNA-seq).
    dispersions = 0.5 / (base_means ** 0.5) + 0.05
    # Modest library-size differences between samples.
    size_factors = rng.uniform(0.8, 1.3, size=n_samples)

    # Assign true fold changes: half up-, half down-regulated.
    de_gene_indices = rng.choice(n_genes, size=n_de_genes, replace=False)
    log2fc = np.zeros(n_genes)
    n_up = n_de_genes // 2
    log2fc[de_gene_indices[:n_up]] = rng.uniform(0.8, 3.0, size=n_up)
    log2fc[de_gene_indices[n_up:]] = -rng.uniform(0.8, 3.0, size=n_de_genes - n_up)

    # Draw negative binomial counts, vectorized per sample (replaces the
    # former per-gene inner loop). Parameterization by mean and
    # dispersion: r = 1/disp, p = mean / (mean + r).
    counts = np.zeros((n_genes, n_samples), dtype=int)
    r = 1.0 / dispersions
    for j in range(n_samples):
        mean = base_means * size_factors[j]
        if j >= n_control:  # treated samples get the fold change applied
            mean = mean * 2.0 ** log2fc
        p = mean / (mean + r)
        counts[:, j] = rng.negative_binomial(r, 1 - p)

    # Build metadata
    sample_names = [f"control_{i+1}" for i in range(n_control)] + \
                   [f"treated_{i+1}" for i in range(n_treated)]
    gene_names = [f"GENE_{i:05d}" for i in range(n_genes)]
    counts_df = pd.DataFrame(counts, index=gene_names, columns=sample_names)

    # FIX: the batch column was hard-coded for exactly 6 samples and
    # crashed (length mismatch) for any other design. Keep the original
    # assignment for the 6-sample default; otherwise alternate batches.
    if n_samples == 6:
        batches = ['batch1', 'batch1', 'batch2', 'batch1', 'batch2', 'batch2']
    else:
        batches = [f'batch{(i % 2) + 1}' for i in range(n_samples)]
    metadata_df = pd.DataFrame({
        'condition': ['control'] * n_control + ['treated'] * n_treated,
        'batch': batches
    }, index=sample_names)

    true_de = set(gene_names[i] for i in de_gene_indices)
    return counts_df, metadata_df, true_de
# Build the simulated dataset and report its basic dimensions.
counts_df, metadata_df, true_de_genes = simulate_rnaseq_counts()
print(f"Count matrix shape: {counts_df.shape}")
print(f"\nSample metadata:")
print(metadata_df)
print(f"\nCount matrix (first 5 genes, all samples):")
print(counts_df.head())
print(f"\nTrue DE genes: {len(true_de_genes)}")
Count matrix shape: (10000, 6)
Sample metadata:
condition batch
control_1 control batch1
control_2 control batch1
control_3 control batch2
treated_1 treated batch1
treated_2 treated batch2
treated_3 treated batch2
True DE genes: 1000
Step 2: Quality Control
Always inspect the data before analysis. QC should catch: low-quality samples, contamination, batch effects visible before correction, and outlier samples.
def plot_library_sizes(counts_df, metadata_df):
    """Bar chart of total counts per sample (library size).

    Colors bars by condition (blue control, red treated), saves the
    figure to 'qc_library_sizes.png', and prints per-sample totals in
    millions of reads. A sample with a much smaller library than its
    peers is a QC red flag.
    """
    lib_sizes = counts_df.sum(axis=0) / 1e6  # in millions
    colors = ['#4C9BE8' if c == 'control' else '#E84C4C'
              for c in metadata_df['condition']]
    fig, ax = plt.subplots(figsize=(8, 4))
    # FIX: dropped the unused `bars =` binding — the return value was
    # never referenced.
    ax.bar(lib_sizes.index, lib_sizes.values, color=colors)
    ax.set_ylabel('Library size (million reads)')
    ax.set_title('Library Sizes by Sample')
    ax.tick_params(axis='x', rotation=45)
    # Manual legend: bars carry no label metadata, so build proxy patches.
    from matplotlib.patches import Patch
    legend_elements = [Patch(facecolor='#4C9BE8', label='Control'),
                       Patch(facecolor='#E84C4C', label='Treated')]
    ax.legend(handles=legend_elements)
    plt.tight_layout()
    plt.savefig('qc_library_sizes.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("\nLibrary sizes (millions reads):")
    for sample, size in lib_sizes.items():
        print(f" {sample}: {size:.1f}M")
def plot_count_distributions(counts_df, metadata_df):
    """Overlay per-sample density curves of log2(count + 1).

    Samples from the same condition should trace near-identical curves;
    a shifted or misshapen curve flags a problematic library. Saves the
    figure to 'qc_count_distributions.png'.
    """
    transformed = np.log2(counts_df + 1)  # +1 guards against log(0)
    fig, ax = plt.subplots(figsize=(10, 5))
    for name in counts_df.columns:
        is_control = metadata_df.loc[name, 'condition'] == 'control'
        transformed[name].plot.density(
            ax=ax,
            color='#4C9BE8' if is_control else '#E84C4C',
            alpha=0.6,
            label=name,
        )
    ax.set_xlabel('log2(count + 1)')
    ax.set_ylabel('Density')
    ax.set_title('Count Distributions by Sample')
    ax.legend(loc='upper right', fontsize=8)
    plt.tight_layout()
    plt.savefig('qc_count_distributions.png', dpi=150, bbox_inches='tight')
    plt.show()
def detect_zero_fraction(counts_df):
    """Report the fraction of zero-count genes in each sample.

    Prints one line per sample and returns the per-sample zero
    fractions as a Series (index = sample names).
    """
    zero_frac = counts_df.eq(0).mean(axis=0)
    print("\nZero count fraction per sample:")
    for name, value in zero_frac.items():
        print(f" {name}: {value:.1%}")
    return zero_frac
# Run the QC suite on the raw (unfiltered) count matrix.
plot_library_sizes(counts_df, metadata_df)
plot_count_distributions(counts_df, metadata_df)
zero_fracs = detect_zero_fraction(counts_df)
PCA for Sample-Level QC
def pca_plot(counts_df, metadata_df, title="PCA of samples"):
    """
    PCA on log-normalized counts for sample-level QC.

    Critical: PCA must be run on log-transformed, normalized counts.
    Raw counts give misleading PCA due to mean-variance dependence.

    Parameters
    ----------
    counts_df : genes x samples raw count DataFrame.
    metadata_df : samples x covariates DataFrame; the 'condition' and
        'batch' columns drive point color and marker shape.
    title : figure suptitle.

    Returns the fitted sklearn PCA object; also saves 'qc_pca.png' and
    prints the variance explained by the first four components.
    """
    from sklearn.preprocessing import StandardScaler
    from sklearn.decomposition import PCA
    # Normalize: log2(CPM + 1) removes library-size differences and
    # compresses the dynamic range before PCA.
    cpm = counts_df.divide(counts_df.sum(axis=0), axis=1) * 1e6
    log_cpm = np.log2(cpm + 1)
    # Restrict to the 2000 most variable genes — they carry most of the
    # between-sample signal and reduce noise and compute.
    gene_variances = log_cpm.var(axis=1)
    top_genes = gene_variances.nlargest(2000).index
    log_cpm_filtered = log_cpm.loc[top_genes]
    # PCA expects samples as rows; scale each gene to unit variance first.
    X = log_cpm_filtered.T.values
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)
    pca = PCA(n_components=4)
    pcs = pca.fit_transform(X_scaled)
    # Two panels: PC1 vs PC2 and PC1 vs PC3; axis labels carry the
    # explained-variance percentages.
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    color_map = {'control': '#4C9BE8', 'treated': '#E84C4C'}
    marker_map = {'batch1': 'o', 'batch2': 's'}
    sample_names = counts_df.columns.tolist()
    for ax_idx, (pc_x, pc_y, xlabel, ylabel) in enumerate([
        (0, 1, f'PC1 ({pca.explained_variance_ratio_[0]:.1%})',
         f'PC2 ({pca.explained_variance_ratio_[1]:.1%})'),
        (0, 2, f'PC1 ({pca.explained_variance_ratio_[0]:.1%})',
         f'PC3 ({pca.explained_variance_ratio_[2]:.1%})')
    ]):
        ax = axes[ax_idx]
        for i, sample in enumerate(sample_names):
            # Color encodes condition; marker shape encodes batch.
            condition = metadata_df.loc[sample, 'condition']
            batch = metadata_df.loc[sample, 'batch']
            ax.scatter(pcs[i, pc_x], pcs[i, pc_y],
                       c=color_map[condition],
                       marker=marker_map[batch],
                       s=120, zorder=5)
            ax.annotate(sample, (pcs[i, pc_x], pcs[i, pc_y]),
                        textcoords="offset points", xytext=(5, 5), fontsize=7)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.axhline(0, color='gray', lw=0.5)
        ax.axvline(0, color='gray', lw=0.5)
    axes[0].set_title('PC1 vs PC2')
    axes[1].set_title('PC1 vs PC3')
    # Shared legend built from proxy artists (points were drawn per-sample).
    from matplotlib.lines import Line2D
    legend_elements = [
        Line2D([0], [0], marker='o', color='w', markerfacecolor='#4C9BE8',
               markersize=10, label='Control'),
        Line2D([0], [0], marker='o', color='w', markerfacecolor='#E84C4C',
               markersize=10, label='Treated'),
        Line2D([0], [0], marker='o', color='gray', markersize=10, label='Batch 1'),
        Line2D([0], [0], marker='s', color='gray', markersize=10, label='Batch 2'),
    ]
    axes[1].legend(handles=legend_elements, loc='upper right', fontsize=8)
    plt.suptitle(title)
    plt.tight_layout()
    plt.savefig('qc_pca.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("\nVariance explained by each PC:")
    for i, var in enumerate(pca.explained_variance_ratio_[:4]):
        print(f" PC{i+1}: {var:.1%}")
    return pca
pca_result = pca_plot(counts_df, metadata_df)
Step 3: Pre-filtering
Remove genes with very low counts. These genes have no power to detect differential expression and inflate multiple testing correction burden.
def filter_low_counts(counts_df, min_count=10, min_samples=2):
    """Drop genes that never reach min_count in at least min_samples.

    DESeq2 also performs independent filtering internally; this manual
    pre-filter mainly reduces memory and computation on very large
    count matrices. Prints a before/after summary and returns the
    filtered genes x samples DataFrame.
    """
    passing = counts_df.ge(min_count).sum(axis=1).ge(min_samples)
    kept = counts_df[passing]
    n_before, n_after = counts_df.shape[0], kept.shape[0]
    print(f"\nPre-filtering:")
    print(f" Before: {n_before:,} genes")
    print(f" After: {n_after:,} genes")
    print(f" Removed: {n_before - n_after:,} low-count genes")
    return kept
counts_filtered = filter_low_counts(counts_df, min_count=10, min_samples=2)
Step 4: DESeq2 Analysis with pydeseq2
pydeseq2 is a Python reimplementation of the R DESeq2 package. It implements the same statistical model: negative binomial GLM with empirical Bayes dispersion shrinkage and Wald test.
def run_deseq2(counts_df, metadata_df, design_column='condition',
               reference_level='control'):
    """
    Run DESeq2 differential expression analysis via pydeseq2.

    pydeseq2 expects:
    - counts: samples × genes (transposed from our genes × samples)
    - metadata: samples × covariates

    Parameters
    ----------
    counts_df : genes × samples integer count DataFrame (pre-filtered).
    metadata_df : samples × covariates DataFrame; sample order must
        match counts_df columns.
    design_column : metadata column to use in the design formula.
    reference_level : level of design_column treated as the baseline.

    Returns the fitted DeseqDataSet.
    """
    # pydeseq2 expects samples as rows, genes as columns
    counts_t = counts_df.T
    # Fail fast on misaligned inputs — a silent mismatch would assign
    # counts to the wrong condition labels.
    assert all(counts_t.index == metadata_df.index), \
        "Sample order mismatch between counts and metadata"
    print(f"\nRunning DESeq2:")
    print(f" Samples: {counts_t.shape[0]}")
    print(f" Genes: {counts_t.shape[1]}")
    print(f" Design: ~{design_column}")
    print(f" Reference level: {reference_level}")
    # Create DESeqDataSet
    inference = DefaultInference(n_cpus=4)
    dds = DeseqDataSet(
        counts=counts_t,
        metadata=metadata_df,
        design_factors=design_column,
        ref_level=[design_column, reference_level],
        refit_cooks=True,
        inference=inference
    )
    # Run DESeq2: size factors → dispersions → GLM → Wald-ready model
    dds.deseq2()
    print("\nSize factors:")
    # FIX: dds.obsm['size_factors'] is a 1-D numpy array (AnnData-style
    # storage) and has no .items(); pair values with sample names instead.
    for sample, sf in zip(counts_t.index, dds.obsm['size_factors']):
        print(f" {sample}: {sf:.3f}")
    return dds
dds = run_deseq2(counts_filtered, metadata_df)
Extract and Format Results
def extract_results(dds, contrast=('condition', 'treated', 'control'),
                    alpha=0.05, lfc_threshold=0.0):
    """Extract DESeq2 results with LFC shrinkage.

    Parameters
    ----------
    dds : fitted DeseqDataSet (output of run_deseq2).
    contrast : (factor, test level, reference level) tuple.
    alpha : significance level used for the summary counts and for
        independent filtering.
    lfc_threshold : kept for interface stability; the summary below uses
        a fixed |log2FC| >= 1 cutoff.

    Returns (results DataFrame, DeseqStats object).
    """
    stat_res = DeseqStats(
        dds,
        contrast=list(contrast),
        alpha=alpha,
        cooks_filter=True,
        independent_filter=True
    )
    # FIX: summary() — not run_wald_test() alone — runs the Wald tests
    # AND assembles results_df (pvalue/padj, Cooks and independent
    # filtering applied). Calling run_wald_test() by itself left
    # results_df unpopulated for the .copy() below.
    stat_res.summary()
    # Apply LFC shrinkage (apeglm-equivalent in pydeseq2); coefficient
    # name follows pydeseq2's "<factor>_<level>_vs_<ref>" convention.
    stat_res.lfc_shrink(coeff=f"{contrast[0]}_{contrast[1]}_vs_{contrast[2]}")
    # Extract results table
    results = stat_res.results_df.copy()
    # Summary statistics (NaN padj compares False, so filtered genes are
    # excluded automatically).
    n_sig = (results['padj'] < alpha).sum()
    n_sig_lfc = ((results['padj'] < alpha) &
                 (results['log2FoldChange'].abs() >= 1)).sum()
    print(f"\nDESeq2 Results Summary:")
    print(f" Total genes tested: {len(results):,}")
    print(f" Genes with padj < {alpha}: {n_sig:,}")
    print(f" Genes with padj < {alpha} AND |log2FC| >= 1: {n_sig_lfc:,}")
    n_up = ((results['padj'] < alpha) & (results['log2FoldChange'] >= 1)).sum()
    n_down = ((results['padj'] < alpha) & (results['log2FoldChange'] <= -1)).sum()
    print(f" Upregulated: {n_up:,}")
    print(f" Downregulated: {n_down:,}")
    return results, stat_res
# Run the Wald tests and show the strongest hits.
results, stat_res = extract_results(dds)
print("\nResults table (top 10 by adjusted p-value):")
top_hits = results.dropna(subset=['padj']).nsmallest(10, 'padj')
print(top_hits[['baseMean', 'log2FoldChange', 'lfcSE', 'stat', 'pvalue', 'padj']]
      .round(4).to_string())
Step 5: Assessing Performance Against Ground Truth
Since we simulated the data with known DE genes, we can evaluate how well DESeq2 recovered them.
def evaluate_detection(results, true_de_genes, padj_threshold=0.05, lfc_threshold=1.0):
    """
    Score DE calls against the simulated ground truth.

    Computes precision, recall, and F1 for the genes called significant
    (padj and |log2FC| cutoffs) versus the known true DE set. Real
    experiments have no ground truth — this only validates that the
    statistical machinery behaves as expected on simulated data.

    Returns a dict with precision/recall/f1 and tp/fp/fn counts.
    """
    tested = results.dropna(subset=['padj'])
    universe = set(tested.index)
    passes = ((tested['padj'] < padj_threshold) &
              (tested['log2FoldChange'].abs() >= lfc_threshold))
    called_sig = set(tested.index[passes])
    # Only score against true DE genes that survived filtering.
    true_de_in_matrix = true_de_genes & universe
    # Confusion-matrix components via set algebra.
    tp = len(called_sig & true_de_in_matrix)
    fp = len(called_sig - true_de_in_matrix)
    fn = len(true_de_in_matrix - called_sig)
    tn = len(universe - called_sig - true_de_in_matrix)
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    print(f"\nPerformance evaluation (padj < {padj_threshold}, |log2FC| >= {lfc_threshold}):")
    print(f" True DE genes in matrix: {len(true_de_in_matrix):,}")
    print(f" Called significant: {len(called_sig):,}")
    print(f" True positives: {tp:,}")
    print(f" False positives: {fp:,}")
    print(f" False negatives: {fn:,}")
    print(f" Precision: {precision:.3f}")
    print(f" Recall: {recall:.3f}")
    print(f" F1: {f1:.3f}")
    return {'precision': precision, 'recall': recall, 'f1': f1,
            'tp': tp, 'fp': fp, 'fn': fn}
perf = evaluate_detection(results, true_de_genes)
Step 6: Visualization
MA Plot
def plot_ma(results, padj_threshold=0.05, title="MA Plot"):
    """
    MA plot: log2 fold change against mean normalized expression.

    Significant genes (padj below threshold) are drawn in red over the
    gray background cloud. With shrunken LFCs, noisy low-count genes
    sit near zero instead of dominating the plot. Saves 'ma_plot.png'.
    """
    clean = results.dropna(subset=['padj', 'log2FoldChange'])
    is_sig = clean['padj'] < padj_threshold
    fig, ax = plt.subplots(figsize=(8, 6))
    # Draw the non-significant cloud first so red points sit on top.
    layers = [
        (~is_sig, dict(alpha=0.3, s=3, color='#AAAAAA')),
        (is_sig, dict(alpha=0.7, s=8, color='#E84C4C')),
    ]
    for mask, style in layers:
        ax.scatter(
            np.log2(clean.loc[mask, 'baseMean'] + 1),
            clean.loc[mask, 'log2FoldChange'],
            rasterized=True, **style
        )
    ax.axhline(0, color='black', lw=0.8)
    for y_ref in (1, -1):
        ax.axhline(y_ref, color='blue', lw=0.5, linestyle='--', alpha=0.5)
    ax.set_xlabel('log2(Mean normalized count + 1)')
    ax.set_ylabel('log2 Fold Change (treated / control)')
    ax.set_title(title)
    ax.text(0.02, 0.98, f'Significant: {is_sig.sum():,}',
            transform=ax.transAxes, va='top', fontsize=9,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))
    plt.tight_layout()
    plt.savefig('ma_plot.png', dpi=150, bbox_inches='tight')
    plt.show()
plot_ma(results)
Volcano Plot
def plot_volcano(results, padj_threshold=0.05, lfc_threshold=1.0,
                 n_label=10, title="Volcano Plot"):
    """
    Volcano plot: -log10(p-value) vs. log2FC.

    Color coding:
    - Gray: neither significant nor large-effect
    - Blue: |log2FC| >= threshold but padj >= threshold
    - Orange: padj < threshold but |log2FC| < threshold
    - Red: significant AND |log2FC| >= threshold

    The top n_label red genes (by padj) are annotated. Saves
    'volcano_plot.png'.

    NOTE(review): the horizontal dashed line is drawn at
    -log10(padj_threshold) while the y-axis shows raw p-values — it is
    a visual guide, not the exact significance boundary. Confirm this
    is intended before publication-quality use.
    """
    results_clean = results.dropna(subset=['padj', 'pvalue', 'log2FoldChange'])
    # Categorize genes into the three colored groups.
    sig_lfc = ((results_clean['padj'] < padj_threshold) &
               (results_clean['log2FoldChange'].abs() >= lfc_threshold))
    sig_only = ((results_clean['padj'] < padj_threshold) &
                (results_clean['log2FoldChange'].abs() < lfc_threshold))
    lfc_only = ((results_clean['padj'] >= padj_threshold) &
                (results_clean['log2FoldChange'].abs() >= lfc_threshold))
    # Clip p-values at 1e-300 so p == 0 does not yield infinity.
    neg_log_p = -np.log10(results_clean['pvalue'].clip(lower=1e-300))
    fig, ax = plt.subplots(figsize=(8, 7))
    neither = ~sig_lfc & ~sig_only & ~lfc_only
    # Draw least-interesting layers first so red points end up on top.
    ax.scatter(results_clean.loc[neither, 'log2FoldChange'],
               neg_log_p[neither], alpha=0.3, s=4, color='#AAAAAA', rasterized=True)
    ax.scatter(results_clean.loc[lfc_only, 'log2FoldChange'],
               neg_log_p[lfc_only], alpha=0.5, s=6, color='#4C9BE8', rasterized=True)
    ax.scatter(results_clean.loc[sig_only, 'log2FoldChange'],
               neg_log_p[sig_only], alpha=0.5, s=6, color='#FFA500', rasterized=True)
    ax.scatter(results_clean.loc[sig_lfc, 'log2FoldChange'],
               neg_log_p[sig_lfc], alpha=0.7, s=8, color='#E84C4C', rasterized=True)
    # Threshold guide lines
    ax.axhline(-np.log10(padj_threshold), color='gray', lw=0.8, linestyle='--')
    ax.axvline(lfc_threshold, color='gray', lw=0.8, linestyle='--')
    ax.axvline(-lfc_threshold, color='gray', lw=0.8, linestyle='--')
    # Annotate the strongest hits
    top = results_clean[sig_lfc].nsmallest(n_label, 'padj')
    for gene, row in top.iterrows():
        ax.annotate(gene, (row['log2FoldChange'], -np.log10(row['pvalue'])),
                    fontsize=6, xytext=(3, 3), textcoords='offset points')
    ax.set_xlabel('log2 Fold Change (treated / control)')
    ax.set_ylabel('-log10(p-value)')
    ax.set_title(title)
    # Up/down counts shown in the figure corners
    n_up = (sig_lfc & (results_clean['log2FoldChange'] > 0)).sum()
    n_down = (sig_lfc & (results_clean['log2FoldChange'] < 0)).sum()
    ax.text(0.98, 0.98, f'Up: {n_up}', transform=ax.transAxes,
            ha='right', va='top', color='#E84C4C', fontsize=10, fontweight='bold')
    ax.text(0.02, 0.98, f'Down: {n_down}', transform=ax.transAxes,
            ha='left', va='top', color='#E84C4C', fontsize=10, fontweight='bold')
    plt.tight_layout()
    plt.savefig('volcano_plot.png', dpi=150, bbox_inches='tight')
    plt.show()
plot_volcano(results)
Heatmap of Top Differentially Expressed Genes
def plot_de_heatmap(counts_df, results, metadata_df,
                    n_genes=50, padj_threshold=0.05):
    """
    Heatmap of top DE genes with hierarchical clustering.

    Expression is shown as per-gene Z-scores of log2(CPM + 1) so genes
    with very different absolute expression remain comparable. Columns
    carry a condition color bar. Saves 'de_heatmap.png'.
    """
    # Normalize counts: log2(CPM + 1)
    cpm = counts_df.divide(counts_df.sum(axis=0), axis=1) * 1e6
    log_cpm = np.log2(cpm + 1)
    # Select top DE genes by adjusted p-value (NaN padj compares False).
    sig = (results['padj'] < padj_threshold) & (results['log2FoldChange'].abs() >= 1)
    sig_genes = results[sig].nsmallest(n_genes, 'padj').index
    sig_genes = [g for g in sig_genes if g in log_cpm.index]
    heatmap_data = log_cpm.loc[sig_genes]
    # Z-score each gene (row) for visualization.
    # FIX: guard zero-variance genes — dividing by a zero std produced
    # NaN/inf rows that break clustermap's distance computation.
    row_std = heatmap_data.std(axis=1).replace(0, 1.0)
    heatmap_data_z = heatmap_data.subtract(heatmap_data.mean(axis=1), axis=0)
    heatmap_data_z = heatmap_data_z.divide(row_std, axis=0)
    # Column annotation: condition color bar
    condition_colors = {'control': '#4C9BE8', 'treated': '#E84C4C'}
    col_colors = pd.Series(
        [condition_colors[metadata_df.loc[s, 'condition']] for s in heatmap_data_z.columns],
        index=heatmap_data_z.columns
    )
    g = sns.clustermap(
        heatmap_data_z,
        col_colors=col_colors,
        col_cluster=True,
        row_cluster=True,
        cmap='RdBu_r',
        vmin=-2, vmax=2,
        figsize=(10, 12),
        xticklabels=True,
        yticklabels=False,
        cbar_kws={'label': 'Z-score'}
    )
    g.fig.suptitle(f'Top {len(sig_genes)} DE Genes (Z-score)', y=1.02)
    plt.savefig('de_heatmap.png', dpi=150, bbox_inches='tight')
    plt.show()
plot_de_heatmap(counts_filtered, results, metadata_df)
Step 7: Pathway Enrichment Analysis
Differential expression gives a gene list. Pathway enrichment reveals the biological processes those genes are involved in.
We'll implement two approaches:
- Over-Representation Analysis (ORA): hypergeometric test on gene sets
- GSEA pre-ranked: Gene Set Enrichment Analysis using the ranked log2FC
def run_ora(results, padj_threshold=0.05, lfc_threshold=1.0,
            organism='human'):
    """
    Over-Representation Analysis entry point (GSEApy / enrichr style).

    Splits the DESeq2 results into significantly up- and down-regulated
    gene lists and reports their sizes. The actual gp.enrichr() call is
    shown but left disabled because the simulated gene IDs are not real
    HGNC symbols; with real data, pass these lists plus MSigDB/KEGG
    gene-set names.

    Returns (upregulated gene list, downregulated gene list).
    """
    is_sig = results['padj'] < padj_threshold
    lfc = results['log2FoldChange']
    sig_up = results.index[is_sig & (lfc >= lfc_threshold)].tolist()
    sig_down = results.index[is_sig & (lfc <= -lfc_threshold)].tolist()
    print(f"\nORA input:")
    print(f" Upregulated genes: {len(sig_up)}")
    print(f" Downregulated genes: {len(sig_down)}")
    # With real HGNC symbols (e.g. "TP53", "BRCA1") the call would be:
    # enr_up = gp.enrichr(
    #     gene_list=sig_up,
    #     gene_sets=['KEGG_2021_Human', 'MSigDB_Hallmark_2020'],
    #     organism=organism,
    #     outdir='enrichr_results',
    #     no_plot=True
    # )
    # return enr_up.results
    print("\n (In real analysis: pass HGNC gene symbols to gp.enrichr)")
    print(" Gene sets: KEGG_2021_Human, MSigDB_Hallmark_2020, GO_Biological_Process_2023")
    return sig_up, sig_down
sig_up_genes, sig_down_genes = run_ora(results)
GSEA Pre-ranked
def run_gsea_preranked(results, gene_set_name='MSigDB_Hallmark_2020'):
    """
    Prepare (and document) a GSEA pre-ranked analysis.

    All genes are ranked by the Wald statistic (log2FC / SE), which
    blends effect size with estimation confidence. Unlike ORA, GSEA
    consumes the whole ranked list, so it can detect coordinated but
    subtle shifts in a gene set even when no single member passes the
    significance cutoff.

    The actual gp.prerank() call is shown but disabled (simulated gene
    IDs are not HGNC symbols). Returns the ranked results DataFrame,
    most-upregulated first.
    """
    ranked = (results
              .dropna(subset=['stat'])
              .sort_values('stat', ascending=False))
    print("\nGSEA Pre-ranked setup:")
    print(f" Ranking metric: Wald statistic (log2FC / SE)")
    print(f" Genes ranked: {len(ranked):,}")
    print(f" Top 5 (most upregulated):")
    print(ranked[['log2FoldChange', 'stat', 'padj']].head().round(4).to_string())
    print(f"\n Bottom 5 (most downregulated):")
    print(ranked[['log2FoldChange', 'stat', 'padj']].tail().round(4).to_string())
    # With real HGNC symbols the call would be:
    # gsea_res = gp.prerank(
    #     rnk=ranked['stat'],
    #     gene_sets=gene_set_name,
    #     seed=42,
    #     permutation_num=1000,
    #     outdir='gsea_results'
    # )
    # return gsea_res.res2d
    print("\n (In real analysis: pass ranked stats with HGNC gene symbols to gp.prerank)")
    return ranked
gsea_ranked = run_gsea_preranked(results)
Manual Hypergeometric Test
For understanding what ORA does internally:
from scipy.stats import hypergeom
def hypergeometric_enrichment(gene_set, de_genes, background_genes):
    """
    Test whether a gene set is enriched among DE genes.

    Parameters:
    - gene_set: set of genes in the pathway
    - de_genes: set of differentially expressed genes
    - background_genes: all tested genes (universe)

    Returns (p-value, odds ratio, overlap k, expected overlap).
    """
    M = len(background_genes)              # total genes in universe
    n = len(gene_set & background_genes)   # pathway genes in universe
    N = len(de_genes & background_genes)   # DE genes in universe
    k = len(gene_set & de_genes)           # observed overlap
    if n == 0 or N == 0:
        # FIX: this degenerate branch previously returned a 3-tuple while
        # the normal path returns 4 values, crashing callers that unpack
        # four (pval, odds_ratio, k, expected).
        return 1.0, 1.0, 0, 0.0
    # P(X >= k) where X ~ Hypergeometric(M, n, N)
    pval = hypergeom.sf(k - 1, M, n, N)
    # Expected overlap under the null, and the sample odds ratio
    expected = n * N / M
    odds_ratio = (k / (N - k)) / (n / (M - n)) if (N - k) > 0 and (M - n) > 0 else np.inf
    return pval, odds_ratio, k, expected
# --- Worked example: ORA from first principles ---
# Universe = all genes with a non-NA adjusted p-value; DE set = genes
# passing both significance and effect-size cutoffs.
background = set(results.dropna(subset=['padj']).index)
de_sig = set(results[
    (results['padj'] < 0.05) & (results['log2FoldChange'].abs() >= 1)
].dropna(subset=['padj']).index)
# Create mock gene sets (random subsets of our genes)
rng = np.random.default_rng(42)
all_genes = list(background)
mock_pathways = {
    'Pathway_A': set(rng.choice(all_genes, size=100, replace=False)),
    'Pathway_B': set(rng.choice(all_genes, size=200, replace=False)),
    # Simulate an "enriched" pathway: mostly drawn from DE genes.
    # NOTE(review): list(de_sig)[:200] and list(background) depend on set
    # iteration order, which varies across Python processes — confirm
    # whether exact reproducibility of these mock sets matters.
    'Pathway_C': set(rng.choice(list(de_sig)[:200], size=80, replace=False)) |
                 set(rng.choice(list(background - de_sig), size=20, replace=False)),
}
print("\nManual hypergeometric enrichment test:")
print(f" Background size: {len(background):,}")
print(f" DE gene set size: {len(de_sig):,}")
print(f"\n {'Pathway':<15} {'Size':>6} {'Overlap':>8} {'Expected':>9} {'P-value':>12} {'Odds Ratio':>12}")
print(" " + "-" * 70)
pvals = []
for pathway_name, gene_set in mock_pathways.items():
    pval, odds_ratio, k, expected = hypergeometric_enrichment(
        gene_set, de_sig, background
    )
    pvals.append(pval)
    print(f" {pathway_name:<15} {len(gene_set):>6} {k:>8} {expected:>9.1f} "
          f"{pval:>12.2e} {odds_ratio:>12.2f}")
# Adjust the pathway p-values for multiple testing (Benjamini-Hochberg).
from statsmodels.stats.multitest import multipletests
reject, pvals_adj, _, _ = multipletests(pvals, method='fdr_bh')
print(f"\n After BH correction:")
for i, (pathway_name, padj) in enumerate(zip(mock_pathways.keys(), pvals_adj)):
    print(f" {pathway_name}: padj = {padj:.4f} {'*' if padj < 0.05 else ''}")
Step 8: Normalized Expression Visualization
Expression of Top Hits
def plot_top_gene_expression(counts_df, results, metadata_df, n_genes=9):
    """
    Box plots of normalized expression for top DE genes.

    Shows the actual log2(CPM + 1) values driving each LFC estimate,
    with individual replicate points overlaid. Picks an even mix of the
    strongest up- and down-regulated genes (by padj). Saves
    'top_gene_expression.png'.
    """
    # Normalize to log2(CPM + 1) for display
    cpm = counts_df.divide(counts_df.sum(axis=0), axis=1) * 1e6
    log_cpm = np.log2(cpm + 1)
    # Top DE genes: half upregulated, half downregulated
    sig = (results['padj'] < 0.05) & (results['log2FoldChange'].abs() >= 1)
    top_up = results[sig & (results['log2FoldChange'] > 0)].nsmallest(
        n_genes // 2, 'padj').index
    top_down = results[sig & (results['log2FoldChange'] < 0)].nsmallest(
        n_genes // 2, 'padj').index
    top_genes = list(top_up) + list(top_down)
    top_genes = [g for g in top_genes if g in log_cpm.index][:n_genes]
    n_cols = 3
    n_rows = (len(top_genes) + n_cols - 1) // n_cols  # ceiling division
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 4 * n_rows))
    axes = axes.flatten()
    for idx, gene in enumerate(top_genes):
        ax = axes[idx]
        # Boolean-select replicate columns per condition
        ctrl_expr = log_cpm.loc[gene, metadata_df['condition'] == 'control'].values
        trt_expr = log_cpm.loc[gene, metadata_df['condition'] == 'treated'].values
        # NOTE(review): the `labels=` kwarg of boxplot was renamed to
        # `tick_labels` in matplotlib 3.9 — confirm the pinned matplotlib
        # version before upgrading.
        ax.boxplot([ctrl_expr, trt_expr], labels=['Control', 'Treated'],
                   notch=False, patch_artist=True,
                   boxprops=dict(facecolor='lightblue', alpha=0.7))
        # Overlay individual points (boxplot positions are 1-based)
        for j, (vals, color) in enumerate([(ctrl_expr, '#4C9BE8'), (trt_expr, '#E84C4C')]):
            ax.scatter([j + 1] * len(vals), vals, color=color, zorder=5, s=40)
        lfc = results.loc[gene, 'log2FoldChange']
        padj = results.loc[gene, 'padj']
        ax.set_title(f'{gene}\nlog2FC={lfc:.2f}, padj={padj:.2e}', fontsize=9)
        ax.set_ylabel('log2(CPM + 1)')
    # Hide unused subplots in the grid
    for idx in range(len(top_genes), len(axes)):
        axes[idx].set_visible(False)
    plt.suptitle('Expression of Top Differentially Expressed Genes', y=1.02)
    plt.tight_layout()
    plt.savefig('top_gene_expression.png', dpi=150, bbox_inches='tight')
    plt.show()
plot_top_gene_expression(counts_filtered, results, metadata_df)
Step 9: Export Results
def export_results(results, metadata_df, output_prefix='deseq2'):
    """
    Write analysis outputs to disk.

    Produces four artifacts next to output_prefix:
    - <prefix>_all_results.tsv: every tested gene, sorted by padj
    - <prefix>_significant.tsv: padj < 0.05 and |log2FC| >= 1
    - <prefix>_{up,down}regulated_genes.txt: plain gene-name lists
    - <prefix>_summary.json: headline numbers for the run

    metadata_df is currently unused; it is kept in the signature for
    interface stability.
    """
    # Full table, best hits first (NaN padj sorts to the bottom)
    ordered = results.sort_values('padj', na_position='last')
    ordered.to_csv(f'{output_prefix}_all_results.tsv', sep='\t')
    print(f"Saved: {output_prefix}_all_results.tsv ({len(ordered):,} genes)")
    # Significant subset
    sig = (ordered['padj'] < 0.05) & (ordered['log2FoldChange'].abs() >= 1)
    sig_results = ordered[sig].dropna(subset=['padj'])
    sig_results.to_csv(f'{output_prefix}_significant.tsv', sep='\t')
    print(f"Saved: {output_prefix}_significant.tsv ({len(sig_results):,} genes)")
    # Plain gene-name lists for enrichment tools
    up = sig_results.index[sig_results['log2FoldChange'] > 0].tolist()
    down = sig_results.index[sig_results['log2FoldChange'] < 0].tolist()
    for path, genes in ((f'{output_prefix}_upregulated_genes.txt', up),
                        (f'{output_prefix}_downregulated_genes.txt', down)):
        with open(path, 'w') as f:
            f.write('\n'.join(genes))
    print(f"Saved: gene lists ({len(up)} up, {len(down)} down)")
    # Machine-readable summary
    summary = {
        'total_genes_tested': len(ordered),
        'significant_padj05': int((ordered['padj'] < 0.05).sum()),
        'significant_padj05_lfc1': int(sig.sum()),
        'upregulated': len(up),
        'downregulated': len(down),
        'median_baseMean_sig': float(sig_results['baseMean'].median()),
        'median_absLFC_sig': float(sig_results['log2FoldChange'].abs().median())
    }
    import json
    with open(f'{output_prefix}_summary.json', 'w') as f:
        json.dump(summary, f, indent=2)
    print(f"\nAnalysis summary:")
    for key, val in summary.items():
        print(f" {key}: {val}")
export_results(results, metadata_df)
Step 10: Complete Pipeline Function
Putting it all together:
def full_deseq2_pipeline(counts_path_or_df, metadata_path_or_df,
                         design_column='condition',
                         reference_level='control',
                         output_dir='deseq2_results'):
    """
    Complete RNA-seq differential expression pipeline.

    Parameters:
    - counts_path_or_df: path to TSV or DataFrame (genes × samples)
    - metadata_path_or_df: path to TSV or DataFrame (samples × covariates)
    - design_column: metadata column to test
    - reference_level: reference condition in design_column
    - output_dir: directory for output files

    Returns:
    - results DataFrame with all DESeq2 statistics
    """
    import os
    os.makedirs(output_dir, exist_ok=True)

    # Accept either file paths or in-memory DataFrames; copy to avoid
    # mutating caller-owned objects.
    if isinstance(counts_path_or_df, str):
        counts = pd.read_csv(counts_path_or_df, index_col=0, sep='\t')
    else:
        counts = counts_path_or_df.copy()
    if isinstance(metadata_path_or_df, str):
        metadata = pd.read_csv(metadata_path_or_df, index_col=0, sep='\t')
    else:
        metadata = metadata_path_or_df.copy()

    print("=" * 60)
    print("RNA-seq Differential Expression Pipeline")
    print("=" * 60)

    # QC
    print("\n[1/6] Quality Control")
    zero_fracs = detect_zero_fraction(counts)

    # Pre-filter
    print("\n[2/6] Pre-filtering low-count genes")
    counts_filt = filter_low_counts(counts, min_count=10, min_samples=2)

    # Fit the DESeq2 model
    print("\n[3/6] Running DESeq2")
    dds = run_deseq2(counts_filt, metadata, design_column, reference_level)

    # Wald tests + LFC shrinkage
    print("\n[4/6] Extracting results")
    results, stat_res = extract_results(dds)

    # Visualizations are best-effort: a plotting failure should not
    # abort the pipeline after the expensive model fit.
    print("\n[5/6] Generating visualizations")
    plot_ma(results, title="MA Plot — treated vs. control")
    plot_volcano(results, title="Volcano Plot — treated vs. control")
    try:
        plot_de_heatmap(counts_filt, results, metadata)
        plot_top_gene_expression(counts_filt, results, metadata)
    except Exception as e:
        print(f" Warning: visualization failed: {e}")

    # Export
    print("\n[6/6] Exporting results")
    # FIX: removed a redundant second `import os` — the module is
    # already imported at the top of this function.
    prefix = os.path.join(output_dir, f'deseq2_{design_column}')
    export_results(results, metadata, prefix)

    print("\n" + "=" * 60)
    print("Pipeline complete.")
    print("=" * 60)
    return results
# Run the complete pipeline end-to-end on the simulated dataset.
final_results = full_deseq2_pipeline(
    counts_df,
    metadata_df,
    design_column='condition',
    reference_level='control',
    output_dir='pipeline_output'
)
Common Pitfalls and How to Avoid Them
def check_common_mistakes(counts_df, metadata_df, results):
    """
    Diagnostic checks for common RNA-seq analysis errors.

    Implemented checks:
    1. Sample-name consistency between counts and metadata
    2. p-value histogram shape (spike near 0, roughly flat elsewhere)
    3. Fraction of genes assigned NA padj by independent filtering

    Saves 'pvalue_histogram.png' and prints [OK]/[WARNING]/[FAIL] lines.
    """
    print("\nDiagnostic checks:")
    # 1. Sample membership (order is checked separately in run_deseq2)
    counts_samples = set(counts_df.columns)
    metadata_samples = set(metadata_df.index)
    if counts_samples != metadata_samples:
        print(" [FAIL] Sample mismatch between counts and metadata!")
        print(f" In counts not metadata: {counts_samples - metadata_samples}")
        print(f" In metadata not counts: {metadata_samples - counts_samples}")
    else:
        print(" [OK] Sample names consistent between counts and metadata")
    # 2. p-value histogram — a healthy analysis shows a spike at 0
    # (true signal) over a uniform background (true nulls).
    pvals_clean = results['pvalue'].dropna()
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.hist(pvals_clean, bins=50, edgecolor='black', linewidth=0.5, color='steelblue')
    ax.set_xlabel('p-value')
    ax.set_ylabel('Number of genes')
    ax.set_title('P-value Histogram\n(should have spike at 0, flat elsewhere)')
    # Fraction of small p-values drives the warnings below
    frac_small = (pvals_clean < 0.05).mean()
    ax.text(0.5, 0.95, f'Fraction p<0.05: {frac_small:.1%}',
            transform=ax.transAxes, ha='center', va='top')
    plt.tight_layout()
    plt.savefig('pvalue_histogram.png', dpi=150, bbox_inches='tight')
    plt.show()
    if frac_small < 0.05:
        print(f" [WARNING] Only {frac_small:.1%} of genes have p < 0.05.")
        print(" This may indicate insufficient power or very few DE genes.")
    elif frac_small > 0.5:
        print(f" [WARNING] {frac_small:.1%} of genes have p < 0.05.")
        print(" This may indicate batch effects or sample swap confounding biology.")
    else:
        print(f" [OK] {frac_small:.1%} of genes have p < 0.05 (reasonable signal)")
    # 3. Genes assigned NA padj by DESeq2's independent filtering
    n_na = results['padj'].isna().sum()
    n_total = len(results)
    print(f" [INFO] Genes filtered by independent filtering: {n_na:,} / {n_total:,} "
          f"({n_na/n_total:.1%})")
    if n_na / n_total > 0.5:
        print(" [WARNING] >50% of genes filtered — consider less aggressive pre-filtering")
check_common_mistakes(counts_filtered, metadata_df, results)
Summary: The Complete Workflow
Raw count matrix (genes × samples)
↓ filter_low_counts()
Pre-filtered counts
↓ pca_plot() — check for outliers, batch effects
↓ run_deseq2()
DESeqDataSet with size factors + dispersion estimates
↓ extract_results()
Results table: log2FC, SE, stat, pvalue, padj
↓ plot_ma(), plot_volcano(), plot_de_heatmap()
Visualizations
↓ run_ora() / run_gsea_preranked()
Enriched pathways
↓ export_results()
TSV files + gene lists for downstream analysis
The pydeseq2 implementation faithfully replicates R DESeq2 results, enabling Python-native bioinformatics workflows without R dependencies. For production pipelines, wrap these steps in a Snakemake rule to enable parallelization across multiple contrasts and automatic re-running when inputs change.