Part 8 · 8.8 · 45 min read

In Practice: Full RNA-seq Pipeline in Python

A complete RNA-seq analysis pipeline from raw counts to biological interpretation — normalization, QC, differential expression, visualization, and pathway enrichment, implemented in Python with pydeseq2 and GSEApy.

Tags: RNA-seq · pydeseq2 · GSEApy · Python · bioinformatics pipeline · differential expression

This chapter implements a complete RNA-seq differential expression analysis pipeline in Python. We cover the full journey from a raw count matrix to a biologically interpreted results table: quality control, normalization, differential expression testing, visualization, and pathway enrichment.

The pipeline uses:

  • pydeseq2: Python reimplementation of DESeq2
  • GSEApy: Python interface to gene set enrichment analysis
  • pandas / numpy: data manipulation
  • matplotlib / seaborn: visualization

We work with a simulated dataset that mirrors a real RNA-seq experiment — a cancer cell line treated with a drug versus DMSO control, 3 replicates each.

Setup

python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import seaborn as sns
from scipy import stats
from pydeseq2.dds import DeseqDataSet
from pydeseq2.default_inference import DefaultInference
from pydeseq2.ds import DeseqStats
import gseapy as gp
import warnings
warnings.filterwarnings('ignore')

np.random.seed(42)

Step 1: Simulate a Realistic Count Matrix

In practice you'd load a counts matrix from featureCounts, STAR/RSEM, or Salmon. Here we simulate one with realistic properties: overdispersion, differential expression in ~10% of genes, and realistic count magnitudes.
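For reference, loading a real featureCounts table is a short pandas call. The sketch below assumes featureCounts' standard output layout (a `#` comment line, then `Geneid`, annotation columns, and one count column per BAM); the path is hypothetical:

```python
import pandas as pd

def load_featurecounts(path):
    """Load a featureCounts output table into a genes x samples count matrix."""
    # featureCounts writes one '#' comment line, then columns:
    # Geneid, Chr, Start, End, Strand, Length, <one column per BAM>
    df = pd.read_csv(path, sep='\t', comment='#', index_col='Geneid')
    # Drop the annotation columns, keeping only per-sample counts
    return df.drop(columns=['Chr', 'Start', 'End', 'Strand', 'Length'])

# counts_df = load_featurecounts('counts.txt')  # hypothetical path
```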

python
def simulate_rnaseq_counts(n_genes=10000, n_control=3, n_treated=3,
                            n_de_genes=1000, seed=42):
    """Simulate RNA-seq count data with realistic properties."""
    rng = np.random.default_rng(seed)
    
    n_samples = n_control + n_treated
    
    # Gene base expression levels: log-normal distribution
    # Most genes are lowly expressed; a few are highly expressed
    base_means = rng.lognormal(mean=3.0, sigma=2.0, size=n_genes)
    
    # Dispersion: inversely related to mean (as in real RNA-seq)
    # High dispersion for low-count genes, low dispersion for high-count
    dispersions = 0.5 / (base_means ** 0.5) + 0.05
    
    # Sample size factors (simulate modest library size differences)
    size_factors = rng.uniform(0.8, 1.3, size=n_samples)
    
    # Fold changes for DE genes
    de_gene_indices = rng.choice(n_genes, size=n_de_genes, replace=False)
    
    # Mix of up and down regulated genes
    log2fc = np.zeros(n_genes)
    n_up = n_de_genes // 2
    log2fc[de_gene_indices[:n_up]] = rng.uniform(0.8, 3.0, size=n_up)
    log2fc[de_gene_indices[n_up:]] = -rng.uniform(0.8, 3.0, size=n_de_genes - n_up)
    
    # Generate counts: negative binomial
    counts = np.zeros((n_genes, n_samples), dtype=int)
    
    for j in range(n_samples):
        sf = size_factors[j]
        is_treated = j >= n_control
        
        for i in range(n_genes):
            mean = base_means[i] * sf
            if is_treated:
                mean *= 2 ** log2fc[i]
            
            # NB: parameterized by mean and dispersion
            # p = mean / (mean + 1/disp); r = 1/disp
            disp = dispersions[i]
            r = 1.0 / disp
            p = mean / (mean + r)
            counts[i, j] = rng.negative_binomial(r, 1 - p)
    
    # Build metadata
    sample_names = [f"control_{i+1}" for i in range(n_control)] + \
                   [f"treated_{i+1}" for i in range(n_treated)]
    gene_names = [f"GENE_{i:05d}" for i in range(n_genes)]
    
    counts_df = pd.DataFrame(counts, index=gene_names, columns=sample_names)
    # Note: batch labels are hardcoded for the default 3 + 3 design
    batches = ['batch1', 'batch1', 'batch2', 'batch1', 'batch2', 'batch2'][:n_samples]
    metadata_df = pd.DataFrame({
        'condition': ['control'] * n_control + ['treated'] * n_treated,
        'batch': batches
    }, index=sample_names)
    
    true_de = set(gene_names[i] for i in de_gene_indices)
    
    return counts_df, metadata_df, true_de

counts_df, metadata_df, true_de_genes = simulate_rnaseq_counts()

print(f"Count matrix shape: {counts_df.shape}")
print(f"\nSample metadata:")
print(metadata_df)
print(f"\nCount matrix (first 5 genes, all samples):")
print(counts_df.head())
print(f"\nTrue DE genes: {len(true_de_genes)}")
Count matrix shape: (10000, 6)

Sample metadata:
           condition   batch
control_1    control  batch1
control_2    control  batch1
control_3    control  batch2
treated_1    treated  batch1
treated_2    treated  batch2
treated_3    treated  batch2

True DE genes: 1000
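As a sanity check on the parameterization above, the simulated counts should follow the DESeq2 variance model var = μ + α·μ² (α is the dispersion). A quick numeric check with made-up values for one gene:

```python
import numpy as np

rng = np.random.default_rng(0)
mu, disp = 100.0, 0.2            # mean and dispersion (alpha)
r = 1.0 / disp                   # NB size parameter, as in the simulator
p = mu / (mu + r)
draws = rng.negative_binomial(r, 1 - p, size=200_000)

print(draws.mean())              # close to mu = 100
print(draws.var())               # close to mu + disp * mu**2 = 2100
```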

Step 2: Quality Control

Always inspect the data before analysis. QC should catch: low-quality samples, contamination, batch effects visible before correction, and outlier samples.
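Several of the plots below normalize to log2(CPM + 1); pulled out as a standalone helper, the computation (equivalent to the inline versions used throughout this chapter) looks like this:

```python
import numpy as np
import pandas as pd

def to_log_cpm(counts_df):
    """log2(counts-per-million + 1): library-size normalization for QC plots.

    CPM is fine for visualization and QC; the differential testing below
    uses DESeq2's own size-factor normalization instead.
    """
    cpm = counts_df.divide(counts_df.sum(axis=0), axis=1) * 1e6
    return np.log2(cpm + 1)
```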

python
def plot_library_sizes(counts_df, metadata_df):
    """Bar chart of total counts per sample (library size)."""
    lib_sizes = counts_df.sum(axis=0) / 1e6  # in millions
    
    colors = ['#4C9BE8' if c == 'control' else '#E84C4C'
              for c in metadata_df['condition']]
    
    fig, ax = plt.subplots(figsize=(8, 4))
    bars = ax.bar(lib_sizes.index, lib_sizes.values, color=colors)
    ax.set_ylabel('Library size (million reads)')
    ax.set_title('Library Sizes by Sample')
    ax.tick_params(axis='x', rotation=45)
    
    # Add legend
    from matplotlib.patches import Patch
    legend_elements = [Patch(facecolor='#4C9BE8', label='Control'),
                       Patch(facecolor='#E84C4C', label='Treated')]
    ax.legend(handles=legend_elements)
    
    plt.tight_layout()
    plt.savefig('qc_library_sizes.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    print("\nLibrary sizes (millions of reads):")
    for sample, size in lib_sizes.items():
        print(f"  {sample}: {size:.1f}M")

def plot_count_distributions(counts_df, metadata_df):
    """Log-count distributions per sample — should be similar between samples."""
    # Log-transform (add 1 for zeros)
    log_counts = np.log2(counts_df + 1)
    
    fig, ax = plt.subplots(figsize=(10, 5))
    
    for sample in counts_df.columns:
        condition = metadata_df.loc[sample, 'condition']
        color = '#4C9BE8' if condition == 'control' else '#E84C4C'
        alpha = 0.6
        
        log_counts[sample].plot.density(ax=ax, color=color, alpha=alpha, label=sample)
    
    ax.set_xlabel('log2(count + 1)')
    ax.set_ylabel('Density')
    ax.set_title('Count Distributions by Sample')
    ax.legend(loc='upper right', fontsize=8)
    
    plt.tight_layout()
    plt.savefig('qc_count_distributions.png', dpi=150, bbox_inches='tight')
    plt.show()

def detect_zero_fraction(counts_df):
    """Fraction of genes with zero counts per sample."""
    zero_frac = (counts_df == 0).mean(axis=0)
    print("\nZero count fraction per sample:")
    for sample, frac in zero_frac.items():
        print(f"  {sample}: {frac:.1%}")
    return zero_frac

plot_library_sizes(counts_df, metadata_df)
plot_count_distributions(counts_df, metadata_df)
zero_fracs = detect_zero_fraction(counts_df)
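A complementary check for the outlier samples mentioned above is a pairwise sample-correlation matrix; a minimal sketch (not part of the plots above):

```python
import numpy as np
import pandas as pd

def sample_correlation_matrix(counts_df):
    """Spearman correlation between samples on log2(count + 1).

    Replicates typically correlate very highly; a sample that correlates
    poorly with everything is a candidate outlier.
    """
    log_counts = np.log2(counts_df + 1)
    return log_counts.corr(method='spearman')  # samples x samples

# sns.heatmap(sample_correlation_matrix(counts_df), annot=True, cmap='viridis')
```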

PCA for Sample-Level QC

python
def pca_plot(counts_df, metadata_df, title="PCA of samples"):
    """
    PCA on log-normalized counts.
    
    Critical: PCA must be run on log-transformed, normalized counts.
    Raw counts give misleading PCA due to mean-variance dependence.
    """
    from sklearn.preprocessing import StandardScaler
    from sklearn.decomposition import PCA
    
    # Normalize: log2(CPM + 1)
    cpm = counts_df.divide(counts_df.sum(axis=0), axis=1) * 1e6
    log_cpm = np.log2(cpm + 1)
    
    # Filter to top 2000 most variable genes
    gene_variances = log_cpm.var(axis=1)
    top_genes = gene_variances.nlargest(2000).index
    log_cpm_filtered = log_cpm.loc[top_genes]
    
    # PCA (samples are rows)
    X = log_cpm_filtered.T.values
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)
    
    pca = PCA(n_components=4)
    pcs = pca.fit_transform(X_scaled)
    
    # Plot PC1 vs PC2
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    
    color_map = {'control': '#4C9BE8', 'treated': '#E84C4C'}
    marker_map = {'batch1': 'o', 'batch2': 's'}
    
    sample_names = counts_df.columns.tolist()
    
    for ax_idx, (pc_x, pc_y, xlabel, ylabel) in enumerate([
        (0, 1, f'PC1 ({pca.explained_variance_ratio_[0]:.1%})',
                f'PC2 ({pca.explained_variance_ratio_[1]:.1%})'),
        (0, 2, f'PC1 ({pca.explained_variance_ratio_[0]:.1%})',
                f'PC3 ({pca.explained_variance_ratio_[2]:.1%})')
    ]):
        ax = axes[ax_idx]
        for i, sample in enumerate(sample_names):
            condition = metadata_df.loc[sample, 'condition']
            batch = metadata_df.loc[sample, 'batch']
            ax.scatter(pcs[i, pc_x], pcs[i, pc_y],
                      c=color_map[condition],
                      marker=marker_map[batch],
                      s=120, zorder=5)
            ax.annotate(sample, (pcs[i, pc_x], pcs[i, pc_y]),
                       textcoords="offset points", xytext=(5, 5), fontsize=7)
        
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.axhline(0, color='gray', lw=0.5)
        ax.axvline(0, color='gray', lw=0.5)
    
    axes[0].set_title('PC1 vs PC2')
    axes[1].set_title('PC1 vs PC3')
    
    # Shared legend
    from matplotlib.lines import Line2D
    legend_elements = [
        Line2D([0], [0], marker='o', color='w', markerfacecolor='#4C9BE8',
               markersize=10, label='Control'),
        Line2D([0], [0], marker='o', color='w', markerfacecolor='#E84C4C',
               markersize=10, label='Treated'),
        Line2D([0], [0], marker='o', color='gray', markersize=10, label='Batch 1'),
        Line2D([0], [0], marker='s', color='gray', markersize=10, label='Batch 2'),
    ]
    axes[1].legend(handles=legend_elements, loc='upper right', fontsize=8)
    
    plt.suptitle(title)
    plt.tight_layout()
    plt.savefig('qc_pca.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    print("\nVariance explained by each PC:")
    for i, var in enumerate(pca.explained_variance_ratio_[:4]):
        print(f"  PC{i+1}: {var:.1%}")
    
    return pca

pca_result = pca_plot(counts_df, metadata_df)

Step 3: Pre-filtering

Remove genes with very low counts. These genes have no power to detect differential expression, and they inflate the multiple-testing correction burden.
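To see why the extra tests matter: Benjamini–Hochberg compares the i-th smallest p-value against i·α/m, so every uninformative gene kept in the matrix lowers the bar for all the others. A toy illustration with made-up p-values:

```python
import numpy as np
from statsmodels.stats.multitest import multipletests

pvals = np.array([0.0005, 0.004, 0.012, 0.03, 0.2])

# BH over just the 5 informative tests: 4 discoveries
reject_small, _, _, _ = multipletests(pvals, alpha=0.05, method='fdr_bh')

# Pad with 45 uninformative tests (p = 1): only 1 discovery survives
padded = np.concatenate([pvals, np.ones(45)])
reject_large, _, _, _ = multipletests(padded, alpha=0.05, method='fdr_bh')

print(reject_small.sum(), reject_large.sum())
```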

python
def filter_low_counts(counts_df, min_count=10, min_samples=2):
    """
    Keep genes with at least min_count in at least min_samples.
    
    This is applied before DESeq2 — DESeq2 also does independent
    filtering internally, but manual pre-filtering reduces memory
    and computation for very large count matrices.
    """
    # At least min_count counts in at least min_samples samples
    keep = (counts_df >= min_count).sum(axis=1) >= min_samples
    
    counts_filtered = counts_df.loc[keep]
    
    print(f"\nPre-filtering:")
    print(f"  Before: {counts_df.shape[0]:,} genes")
    print(f"  After:  {counts_filtered.shape[0]:,} genes")
    print(f"  Removed: {counts_df.shape[0] - counts_filtered.shape[0]:,} low-count genes")
    
    return counts_filtered

counts_filtered = filter_low_counts(counts_df, min_count=10, min_samples=2)

Step 4: DESeq2 Analysis with pydeseq2

pydeseq2 is a Python reimplementation of the R DESeq2 package. It implements the same statistical model: a negative binomial GLM with empirical Bayes dispersion shrinkage, followed by a Wald test on each coefficient.
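The Wald test itself is simple once the GLM is fit: the coefficient (the log2 fold change) divided by its standard error is compared against a standard normal. A minimal sketch with made-up numbers for one gene:

```python
import numpy as np
from scipy import stats

# Hypothetical fitted values for one gene
log2fc, se = 1.5, 0.4

wald_stat = log2fc / se                      # the 'stat' column in the results
pvalue = 2 * stats.norm.sf(abs(wald_stat))   # two-sided normal p-value

print(f"stat = {wald_stat:.3f}, p = {pvalue:.2e}")
```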

python
def run_deseq2(counts_df, metadata_df, design_column='condition',
               reference_level='control'):
    """
    Run DESeq2 differential expression analysis.
    
    pydeseq2 expects:
    - counts: samples × genes (transposed from our genes × samples)
    - metadata: samples × covariates
    """
    # pydeseq2 expects samples as rows, genes as columns
    counts_t = counts_df.T
    
    # Ensure metadata and counts have matching indices
    assert all(counts_t.index == metadata_df.index), \
        "Sample order mismatch between counts and metadata"
    
    print(f"\nRunning DESeq2:")
    print(f"  Samples: {counts_t.shape[0]}")
    print(f"  Genes: {counts_t.shape[1]}")
    print(f"  Design: ~{design_column}")
    print(f"  Reference level: {reference_level}")
    
    # Create DESeqDataSet
    inference = DefaultInference(n_cpus=4)
    dds = DeseqDataSet(
        counts=counts_t,
        metadata=metadata_df,
        design_factors=design_column,
        ref_level=[design_column, reference_level],
        refit_cooks=True,
        inference=inference
    )
    
    # Run DESeq2: size factors → dispersions → GLM → Wald test
    dds.deseq2()
    
    print("\nSize factors:")
    # dds.obsm['size_factors'] is a numpy array; pair it with sample names
    for sample, sf in zip(dds.obs_names, dds.obsm['size_factors']):
        print(f"  {sample}: {sf:.3f}")
    
    return dds

dds = run_deseq2(counts_filtered, metadata_df)
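The size factors printed above come from DESeq2's median-of-ratios method: build a geometric-mean pseudo-reference across samples, then take each sample's median ratio to it. A standalone sketch of the idea (not pydeseq2's implementation):

```python
import numpy as np

def median_of_ratios(counts):
    """Size factors via median-of-ratios. counts: genes x samples array."""
    with np.errstate(divide='ignore'):
        log_counts = np.log(counts.astype(float))
    # Geometric mean per gene; genes with any zero count drop out (-inf)
    log_geo_means = log_counts.mean(axis=1)
    finite = np.isfinite(log_geo_means)
    # Each sample's median log-ratio to the pseudo-reference
    return np.exp(np.median(log_counts[finite] - log_geo_means[finite, None], axis=0))

counts = np.array([[10, 20], [30, 60], [5, 10]])
print(median_of_ratios(counts))  # second library is 2x the first
```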

Extract and Format Results

python
def extract_results(dds, contrast=('condition', 'treated', 'control'),
                    alpha=0.05, lfc_threshold=0.0):
    """Extract DESeq2 results with LFC shrinkage."""
    
    stat_res = DeseqStats(
        dds,
        contrast=list(contrast),
        alpha=alpha,
        cooks_filter=True,
        independent_filter=True
    )
    
    # Run the Wald tests, independent filtering, and multiple-testing
    # correction, and build the results_df table
    stat_res.summary()
    
    # Apply LFC shrinkage (apeglm-type shrinkage in pydeseq2)
    stat_res.lfc_shrink(coeff=f"{contrast[0]}_{contrast[1]}_vs_{contrast[2]}")
    
    # Extract results table
    results = stat_res.results_df.copy()
    
    # Summary statistics
    n_sig = (results['padj'] < alpha).sum()
    n_sig_lfc = ((results['padj'] < alpha) & 
                 (results['log2FoldChange'].abs() >= 1)).sum()
    
    print(f"\nDESeq2 Results Summary:")
    print(f"  Total genes tested: {len(results):,}")
    print(f"  Genes with padj < {alpha}: {n_sig:,}")
    print(f"  Genes with padj < {alpha} AND |log2FC| >= 1: {n_sig_lfc:,}")
    
    n_up = ((results['padj'] < alpha) & (results['log2FoldChange'] >= 1)).sum()
    n_down = ((results['padj'] < alpha) & (results['log2FoldChange'] <= -1)).sum()
    print(f"    Upregulated: {n_up:,}")
    print(f"    Downregulated: {n_down:,}")
    
    return results, stat_res

results, stat_res = extract_results(dds)

print("\nResults table (top 10 by adjusted p-value):")
top_hits = results.dropna(subset=['padj']).nsmallest(10, 'padj')
print(top_hits[['baseMean', 'log2FoldChange', 'lfcSE', 'stat', 'pvalue', 'padj']]
      .round(4).to_string())

Step 5: Assessing Performance Against Ground Truth

Since we simulated the data with known DE genes, we can evaluate how well DESeq2 recovered them.

python
def evaluate_detection(results, true_de_genes, padj_threshold=0.05, lfc_threshold=1.0):
    """
    Compute precision, recall, and F1 for DE gene detection.
    
    In real experiments you don't have ground truth — but this validates
    that the statistical model is working correctly.
    """
    results_clean = results.dropna(subset=['padj'])
    
    # Called significant by DESeq2
    called_sig = set(results_clean[
        (results_clean['padj'] < padj_threshold) & 
        (results_clean['log2FoldChange'].abs() >= lfc_threshold)
    ].index)
    
    # True DE genes that were in our filtered count matrix
    true_de_in_matrix = true_de_genes & set(results_clean.index)
    
    # Confusion matrix components
    tp = len(called_sig & true_de_in_matrix)
    fp = len(called_sig - true_de_in_matrix)
    fn = len(true_de_in_matrix - called_sig)
    tn = len(set(results_clean.index) - called_sig - true_de_in_matrix)
    
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    
    print(f"\nPerformance evaluation (padj < {padj_threshold}, |log2FC| >= {lfc_threshold}):")
    print(f"  True DE genes in matrix: {len(true_de_in_matrix):,}")
    print(f"  Called significant:      {len(called_sig):,}")
    print(f"  True positives:          {tp:,}")
    print(f"  False positives:         {fp:,}")
    print(f"  False negatives:         {fn:,}")
    print(f"  Precision: {precision:.3f}")
    print(f"  Recall:    {recall:.3f}")
    print(f"  F1:        {f1:.3f}")
    
    return {'precision': precision, 'recall': recall, 'f1': f1,
            'tp': tp, 'fp': fp, 'fn': fn}

perf = evaluate_detection(results, true_de_genes)

Step 6: Visualization

MA Plot

python
def plot_ma(results, padj_threshold=0.05, title="MA Plot"):
    """
    MA plot: log2 fold change vs. mean expression.
    
    Red points = significant. After LFC shrinkage, low-count genes
    have LFCs shrunk toward zero — they don't dominate the plot.
    """
    results_clean = results.dropna(subset=['padj', 'log2FoldChange'])
    
    sig = results_clean['padj'] < padj_threshold
    
    fig, ax = plt.subplots(figsize=(8, 6))
    
    # Non-significant
    ax.scatter(
        np.log2(results_clean.loc[~sig, 'baseMean'] + 1),
        results_clean.loc[~sig, 'log2FoldChange'],
        alpha=0.3, s=3, color='#AAAAAA', rasterized=True
    )
    
    # Significant
    ax.scatter(
        np.log2(results_clean.loc[sig, 'baseMean'] + 1),
        results_clean.loc[sig, 'log2FoldChange'],
        alpha=0.7, s=8, color='#E84C4C', rasterized=True
    )
    
    ax.axhline(0, color='black', lw=0.8)
    ax.axhline(1, color='blue', lw=0.5, linestyle='--', alpha=0.5)
    ax.axhline(-1, color='blue', lw=0.5, linestyle='--', alpha=0.5)
    
    ax.set_xlabel('log2(Mean normalized count + 1)')
    ax.set_ylabel('log2 Fold Change (treated / control)')
    ax.set_title(title)
    
    n_sig = sig.sum()
    ax.text(0.02, 0.98, f'Significant: {n_sig:,}',
            transform=ax.transAxes, va='top', fontsize=9,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))
    
    plt.tight_layout()
    plt.savefig('ma_plot.png', dpi=150, bbox_inches='tight')
    plt.show()

plot_ma(results)

Volcano Plot

python
def plot_volcano(results, padj_threshold=0.05, lfc_threshold=1.0,
                 n_label=10, title="Volcano Plot"):
    """
    Volcano plot: -log10(p-value) vs. log2FC.
    
    Color coding:
    - Gray: not significant
    - Blue: |log2FC| > threshold but padj >= threshold
    - Red: significant AND |log2FC| > threshold
    """
    results_clean = results.dropna(subset=['padj', 'pvalue', 'log2FoldChange'])
    
    # Categorize genes
    sig_lfc = ((results_clean['padj'] < padj_threshold) & 
               (results_clean['log2FoldChange'].abs() >= lfc_threshold))
    sig_only = ((results_clean['padj'] < padj_threshold) & 
                (results_clean['log2FoldChange'].abs() < lfc_threshold))
    lfc_only = ((results_clean['padj'] >= padj_threshold) & 
                (results_clean['log2FoldChange'].abs() >= lfc_threshold))
    
    # Cap -log10(pvalue) for visualization
    neg_log_p = -np.log10(results_clean['pvalue'].clip(lower=1e-300))
    
    fig, ax = plt.subplots(figsize=(8, 7))
    
    neither = ~sig_lfc & ~sig_only & ~lfc_only
    ax.scatter(results_clean.loc[neither, 'log2FoldChange'],
               neg_log_p[neither], alpha=0.3, s=4, color='#AAAAAA', rasterized=True)
    ax.scatter(results_clean.loc[lfc_only, 'log2FoldChange'],
               neg_log_p[lfc_only], alpha=0.5, s=6, color='#4C9BE8', rasterized=True)
    ax.scatter(results_clean.loc[sig_only, 'log2FoldChange'],
               neg_log_p[sig_only], alpha=0.5, s=6, color='#FFA500', rasterized=True)
    ax.scatter(results_clean.loc[sig_lfc, 'log2FoldChange'],
               neg_log_p[sig_lfc], alpha=0.7, s=8, color='#E84C4C', rasterized=True)
    
    # Threshold lines
    ax.axhline(-np.log10(padj_threshold), color='gray', lw=0.8, linestyle='--')
    ax.axvline(lfc_threshold, color='gray', lw=0.8, linestyle='--')
    ax.axvline(-lfc_threshold, color='gray', lw=0.8, linestyle='--')
    
    # Label top hits
    top = results_clean[sig_lfc].nsmallest(n_label, 'padj')
    for gene, row in top.iterrows():
        ax.annotate(gene, (row['log2FoldChange'], -np.log10(row['pvalue'])),
                   fontsize=6, xytext=(3, 3), textcoords='offset points')
    
    ax.set_xlabel('log2 Fold Change (treated / control)')
    ax.set_ylabel('-log10(p-value)')
    ax.set_title(title)
    
    # Counts in each quadrant
    n_up = (sig_lfc & (results_clean['log2FoldChange'] > 0)).sum()
    n_down = (sig_lfc & (results_clean['log2FoldChange'] < 0)).sum()
    ax.text(0.98, 0.98, f'Up: {n_up}', transform=ax.transAxes,
            ha='right', va='top', color='#E84C4C', fontsize=10, fontweight='bold')
    ax.text(0.02, 0.98, f'Down: {n_down}', transform=ax.transAxes,
            ha='left', va='top', color='#E84C4C', fontsize=10, fontweight='bold')
    
    plt.tight_layout()
    plt.savefig('volcano_plot.png', dpi=150, bbox_inches='tight')
    plt.show()

plot_volcano(results)

Heatmap of Top Differentially Expressed Genes

python
def plot_de_heatmap(counts_df, results, metadata_df,
                    n_genes=50, padj_threshold=0.05):
    """
    Heatmap of top DE genes with hierarchical clustering.
    
    Uses log2(CPM + 1) normalized expression for visualization.
    """
    # Normalize counts
    cpm = counts_df.divide(counts_df.sum(axis=0), axis=1) * 1e6
    log_cpm = np.log2(cpm + 1)
    
    # Select top DE genes by adjusted p-value
    sig = (results['padj'] < padj_threshold) & (results['log2FoldChange'].abs() >= 1)
    sig_genes = results[sig].nsmallest(n_genes, 'padj').index
    sig_genes = [g for g in sig_genes if g in log_cpm.index]
    
    heatmap_data = log_cpm.loc[sig_genes]
    
    # Z-score normalize each gene (row) for visualization
    heatmap_data_z = heatmap_data.subtract(heatmap_data.mean(axis=1), axis=0)
    heatmap_data_z = heatmap_data_z.divide(heatmap_data.std(axis=1), axis=0)
    
    # Column annotation
    condition_colors = {'control': '#4C9BE8', 'treated': '#E84C4C'}
    col_colors = pd.Series(
        [condition_colors[metadata_df.loc[s, 'condition']] for s in heatmap_data_z.columns],
        index=heatmap_data_z.columns
    )
    
    g = sns.clustermap(
        heatmap_data_z,
        col_colors=col_colors,
        col_cluster=True,
        row_cluster=True,
        cmap='RdBu_r',
        vmin=-2, vmax=2,
        figsize=(10, 12),
        xticklabels=True,
        yticklabels=False,
        cbar_kws={'label': 'Z-score'}
    )
    
    g.fig.suptitle(f'Top {len(sig_genes)} DE Genes (Z-score)', y=1.02)
    plt.savefig('de_heatmap.png', dpi=150, bbox_inches='tight')
    plt.show()

plot_de_heatmap(counts_filtered, results, metadata_df)

Step 7: Pathway Enrichment Analysis

Differential expression gives a gene list. Pathway enrichment reveals the biological processes those genes are involved in.

We'll implement two approaches:

  1. Over-Representation Analysis (ORA): hypergeometric test on gene sets
  2. GSEA pre-ranked: Gene Set Enrichment Analysis using the ranked log2FC
python
def run_ora(results, padj_threshold=0.05, lfc_threshold=1.0,
            organism='human'):
    """
    Over-Representation Analysis using GSEApy's enrichr.
    
    Tests whether the DE gene set is enriched for specific pathways
    relative to the background of all tested genes.
    """
    # Significant upregulated genes
    sig_up = results[
        (results['padj'] < padj_threshold) & 
        (results['log2FoldChange'] >= lfc_threshold)
    ].index.tolist()
    
    # Significant downregulated genes
    sig_down = results[
        (results['padj'] < padj_threshold) & 
        (results['log2FoldChange'] <= -lfc_threshold)
    ].index.tolist()
    
    print(f"\nORA input:")
    print(f"  Upregulated genes: {len(sig_up)}")
    print(f"  Downregulated genes: {len(sig_down)}")
    
    # In a real analysis, sig_up/sig_down would be real gene symbols
    # (HGNC names like "TP53", "BRCA1") not simulated gene IDs.
    # Here we demonstrate the API; gene sets would be from MSigDB/KEGG.
    
    # Example with upregulated genes
    # enr_up = gp.enrichr(
    #     gene_list=sig_up,
    #     gene_sets=['KEGG_2021_Human', 'MSigDB_Hallmark_2020'],
    #     organism=organism,
    #     outdir='enrichr_results',
    #     no_plot=True
    # )
    # return enr_up.results
    
    # Since we have simulated gene names, demonstrate the structure:
    print("\n  (In real analysis: pass HGNC gene symbols to gp.enrichr)")
    print("  Gene sets: KEGG_2021_Human, MSigDB_Hallmark_2020, GO_Biological_Process_2023")
    
    return sig_up, sig_down

sig_up_genes, sig_down_genes = run_ora(results)

GSEA Pre-ranked

python
def run_gsea_preranked(results, gene_set_name='MSigDB_Hallmark_2020'):
    """
    GSEA pre-ranked analysis.
    
    Ranks all genes by log2FC (or stat) and tests whether gene sets
    are enriched at the top or bottom of the ranked list.
    
    Advantages over ORA:
    - Uses all genes, not just a binary significant/not cutoff
    - Detects subtle coordinated shifts in gene sets
    - Finds enrichment even when no single gene passes significance threshold
    """
    # Create ranked gene list: use stat (Wald statistic) for ranking
    # stat = log2FC / SE — combines effect size and significance
    ranked = results.dropna(subset=['stat']).sort_values('stat', ascending=False)
    
    # In real analysis: gene names would be HGNC symbols
    # ranked_genes = ranked['stat'].to_dict()  # {gene_name: rank_metric}
    
    print("\nGSEA Pre-ranked setup:")
    print(f"  Ranking metric: Wald statistic (log2FC / SE)")
    print(f"  Genes ranked: {len(ranked):,}")
    print(f"  Top 5 (most upregulated):")
    print(ranked[['log2FoldChange', 'stat', 'padj']].head().round(4).to_string())
    print(f"\n  Bottom 5 (most downregulated):")
    print(ranked[['log2FoldChange', 'stat', 'padj']].tail().round(4).to_string())
    
    # Real GSEA call would be:
    # gsea_res = gp.prerank(
    #     rnk=ranked['stat'],
    #     gene_sets=gene_set_name,
    #     seed=42,
    #     permutation_num=1000,
    #     outdir='gsea_results'
    # )
    # return gsea_res.res2d
    
    print("\n  (In real analysis: pass ranked stats with HGNC gene symbols to gp.prerank)")
    
    return ranked

gsea_ranked = run_gsea_preranked(results)
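For intuition, GSEA's enrichment score is the maximum deviation of a weighted running sum over the ranked list: step up at gene-set members (weighted by |stat|), step down elsewhere. A toy sketch in the spirit of the original method, not GSEApy's implementation:

```python
import numpy as np

def enrichment_score(ranked_genes, ranked_scores, gene_set, p=1.0):
    """Signed maximum deviation of the GSEA running sum."""
    in_set = np.isin(ranked_genes, list(gene_set))
    weights = np.abs(np.asarray(ranked_scores, dtype=float)) ** p
    hit = np.where(in_set, weights, 0.0)
    hit /= hit.sum()                 # increments at set members
    miss = np.where(in_set, 0.0, 1.0)
    miss /= miss.sum()               # decrements at non-members
    running = np.cumsum(hit - miss)
    return running[np.argmax(np.abs(running))]

genes = np.array([f"g{i}" for i in range(10)])   # already ranked, best first
scores = np.linspace(10, 1, 10)                  # hypothetical Wald stats
print(enrichment_score(genes, scores, {"g0", "g1", "g2"}))  # near +1: set sits at the top
```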

Manual Hypergeometric Test

For understanding what ORA does internally:

python
from scipy.stats import hypergeom

def hypergeometric_enrichment(gene_set, de_genes, background_genes):
    """
    Test whether a gene set is enriched among DE genes.
    
    Parameters:
    - gene_set: set of genes in the pathway
    - de_genes: set of differentially expressed genes
    - background_genes: all tested genes (universe)
    
    Returns (p-value, odds ratio, overlap count, expected overlap).
    """
    M = len(background_genes)             # total genes in universe
    n = len(gene_set & background_genes)  # genes in pathway (in universe)
    N = len(de_genes & background_genes)  # DE genes (in universe)
    k = len(gene_set & de_genes)          # overlap
    
    expected = n * N / M if M > 0 else 0.0
    
    if n == 0 or N == 0:
        return 1.0, 1.0, 0, expected
    
    # P(X >= k) where X ~ Hypergeometric(M, n, N)
    pval = hypergeom.sf(k - 1, M, n, N)
    
    # Simple odds ratio of DE vs. non-DE membership in the set
    odds_ratio = (k / (N - k)) / (n / (M - n)) if (N - k) > 0 and (M - n) > 0 else np.inf
    
    return pval, odds_ratio, k, expected

# Example: test a few mock gene sets
background = set(results.dropna(subset=['padj']).index)
de_sig = set(results[
    (results['padj'] < 0.05) & (results['log2FoldChange'].abs() >= 1)
].dropna(subset=['padj']).index)

# Create mock gene sets (random subsets of our genes)
rng = np.random.default_rng(42)
all_genes = list(background)

mock_pathways = {
    'Pathway_A': set(rng.choice(all_genes, size=100, replace=False)),
    'Pathway_B': set(rng.choice(all_genes, size=200, replace=False)),
    # Simulate an "enriched" pathway: mostly from DE genes
    'Pathway_C': set(rng.choice(list(de_sig)[:200], size=80, replace=False)) | 
                 set(rng.choice(list(background - de_sig), size=20, replace=False)),
}

print("\nManual hypergeometric enrichment test:")
print(f"  Background size: {len(background):,}")
print(f"  DE gene set size: {len(de_sig):,}")
print(f"\n  {'Pathway':<15} {'Size':>6} {'Overlap':>8} {'Expected':>9} {'P-value':>12} {'Odds Ratio':>12}")
print("  " + "-" * 70)

pvals = []
for pathway_name, gene_set in mock_pathways.items():
    pval, odds_ratio, k, expected = hypergeometric_enrichment(
        gene_set, de_sig, background
    )
    pvals.append(pval)
    print(f"  {pathway_name:<15} {len(gene_set):>6} {k:>8} {expected:>9.1f} "
          f"{pval:>12.2e} {odds_ratio:>12.2f}")

# Correct for multiple testing
from statsmodels.stats.multitest import multipletests
reject, pvals_adj, _, _ = multipletests(pvals, method='fdr_bh')
print(f"\n  After BH correction:")
for i, (pathway_name, padj) in enumerate(zip(mock_pathways.keys(), pvals_adj)):
    print(f"  {pathway_name}: padj = {padj:.4f} {'*' if padj < 0.05 else ''}")

Step 8: Normalized Expression Visualization

Expression of Top Hits

python
def plot_top_gene_expression(counts_df, results, metadata_df, n_genes=9):
    """
    Box plots of normalized expression for top DE genes.
    
    Shows the actual expression values that drive the LFC estimate.
    """
    # Normalize
    cpm = counts_df.divide(counts_df.sum(axis=0), axis=1) * 1e6
    log_cpm = np.log2(cpm + 1)
    
    # Top DE genes (mix of up and down)
    sig = (results['padj'] < 0.05) & (results['log2FoldChange'].abs() >= 1)
    top_up = results[sig & (results['log2FoldChange'] > 0)].nsmallest(
        n_genes // 2, 'padj').index
    top_down = results[sig & (results['log2FoldChange'] < 0)].nsmallest(
        n_genes // 2, 'padj').index
    
    top_genes = list(top_up) + list(top_down)
    top_genes = [g for g in top_genes if g in log_cpm.index][:n_genes]
    
    n_cols = 3
    n_rows = (len(top_genes) + n_cols - 1) // n_cols
    
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 4 * n_rows))
    axes = axes.flatten()
    
    for idx, gene in enumerate(top_genes):
        ax = axes[idx]
        
        ctrl_expr = log_cpm.loc[gene, metadata_df['condition'] == 'control'].values
        trt_expr = log_cpm.loc[gene, metadata_df['condition'] == 'treated'].values
        
        ax.boxplot([ctrl_expr, trt_expr], labels=['Control', 'Treated'],
                  notch=False, patch_artist=True,
                  boxprops=dict(facecolor='lightblue', alpha=0.7))
        
        # Overlay individual points
        for j, (vals, color) in enumerate([(ctrl_expr, '#4C9BE8'), (trt_expr, '#E84C4C')]):
            ax.scatter([j + 1] * len(vals), vals, color=color, zorder=5, s=40)
        
        lfc = results.loc[gene, 'log2FoldChange']
        padj = results.loc[gene, 'padj']
        ax.set_title(f'{gene}\nlog2FC={lfc:.2f}, padj={padj:.2e}', fontsize=9)
        ax.set_ylabel('log2(CPM + 1)')
    
    # Hide unused subplots
    for idx in range(len(top_genes), len(axes)):
        axes[idx].set_visible(False)
    
    plt.suptitle('Expression of Top Differentially Expressed Genes', y=1.02)
    plt.tight_layout()
    plt.savefig('top_gene_expression.png', dpi=150, bbox_inches='tight')
    plt.show()

plot_top_gene_expression(counts_filtered, results, metadata_df)

Step 9: Export Results

python
def export_results(results, metadata_df, output_prefix='deseq2'):
    """Export analysis results to standard formats."""
    
    # Sort by adjusted p-value
    results_sorted = results.sort_values('padj', na_position='last')
    
    # Full results table
    results_sorted.to_csv(f'{output_prefix}_all_results.tsv', sep='\t')
    print(f"Saved: {output_prefix}_all_results.tsv ({len(results_sorted):,} genes)")
    
    # Significant genes only
    sig = (results_sorted['padj'] < 0.05) & (results_sorted['log2FoldChange'].abs() >= 1)
    sig_results = results_sorted[sig].dropna(subset=['padj'])
    sig_results.to_csv(f'{output_prefix}_significant.tsv', sep='\t')
    print(f"Saved: {output_prefix}_significant.tsv ({len(sig_results):,} genes)")
    
    # Gene lists for enrichment tools (just gene names)
    up = sig_results[sig_results['log2FoldChange'] > 0].index.tolist()
    down = sig_results[sig_results['log2FoldChange'] < 0].index.tolist()
    
    with open(f'{output_prefix}_upregulated_genes.txt', 'w') as f:
        f.write('\n'.join(up))
    with open(f'{output_prefix}_downregulated_genes.txt', 'w') as f:
        f.write('\n'.join(down))
    
    print(f"Saved: gene lists ({len(up)} up, {len(down)} down)")
    
    # Summary statistics
    summary = {
        'total_genes_tested': len(results_sorted),
        'significant_padj05': int((results_sorted['padj'] < 0.05).sum()),
        'significant_padj05_lfc1': int(sig.sum()),
        'upregulated': len(up),
        'downregulated': len(down),
        'median_baseMean_sig': float(sig_results['baseMean'].median()),
        'median_absLFC_sig': float(sig_results['log2FoldChange'].abs().median())
    }
    
    import json
    with open(f'{output_prefix}_summary.json', 'w') as f:
        json.dump(summary, f, indent=2)
    
    print(f"\nAnalysis summary:")
    for key, val in summary.items():
        print(f"  {key}: {val}")

export_results(results, metadata_df)

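As a quick sanity check on the export step, you can read a written TSV back and confirm the significance thresholds still hold. This is a standalone sketch using a toy results table (not the pipeline's actual output):

```python
import os
import tempfile
import pandas as pd

# Toy results table standing in for the DESeq2 output
toy = pd.DataFrame(
    {'log2FoldChange': [2.5, -1.8, 0.3, 1.2],
     'padj': [1e-6, 0.01, 0.8, 0.2]},
    index=['geneA', 'geneB', 'geneC', 'geneD'])

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'significant.tsv')
    # Same thresholds as export_results: padj < 0.05 and |log2FC| >= 1
    sig = toy[(toy['padj'] < 0.05) & (toy['log2FoldChange'].abs() >= 1)]
    sig.to_csv(path, sep='\t')

    # Round-trip: re-read and verify the thresholds were applied
    reloaded = pd.read_csv(path, sep='\t', index_col=0)
    assert (reloaded['padj'] < 0.05).all()
    assert (reloaded['log2FoldChange'].abs() >= 1).all()
    print(f"{len(reloaded)} significant genes round-tripped")  # prints: 2 significant genes round-tripped
```

The same pattern works against the real `deseq2_significant.tsv` file, which is a cheap guard against accidentally exporting the unfiltered table.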
Step 10: Complete Pipeline Function

Putting it all together:

python
def full_deseq2_pipeline(counts_path_or_df, metadata_path_or_df,
                          design_column='condition',
                          reference_level='control',
                          output_dir='deseq2_results'):
    """
    Complete RNA-seq differential expression pipeline.
    
    Parameters:
    - counts_path_or_df: path to TSV or DataFrame (genes × samples)
    - metadata_path_or_df: path to TSV or DataFrame (samples × covariates)
    - design_column: metadata column to test
    - reference_level: reference condition in design_column
    - output_dir: directory for output files
    
    Returns:
    - results DataFrame with all DESeq2 statistics
    """
    import os
    os.makedirs(output_dir, exist_ok=True)
    
    # Load data
    if isinstance(counts_path_or_df, str):
        counts = pd.read_csv(counts_path_or_df, index_col=0, sep='\t')
    else:
        counts = counts_path_or_df.copy()
    
    if isinstance(metadata_path_or_df, str):
        metadata = pd.read_csv(metadata_path_or_df, index_col=0, sep='\t')
    else:
        metadata = metadata_path_or_df.copy()
    
    print("=" * 60)
    print("RNA-seq Differential Expression Pipeline")
    print("=" * 60)
    
    # QC
    print("\n[1/6] Quality Control")
    zero_fracs = detect_zero_fraction(counts)
    
    # Pre-filter
    print("\n[2/6] Pre-filtering low-count genes")
    counts_filt = filter_low_counts(counts, min_count=10, min_samples=2)
    
    # Run DESeq2
    print("\n[3/6] Running DESeq2")
    dds = run_deseq2(counts_filt, metadata, design_column, reference_level)
    
    # Extract results
    print("\n[4/6] Extracting results")
    results, stat_res = extract_results(dds)
    
    # Visualize
    print("\n[5/6] Generating visualizations")
    plot_ma(results, title="MA Plot — treated vs. control")
    plot_volcano(results, title="Volcano Plot — treated vs. control")
    try:
        plot_de_heatmap(counts_filt, results, metadata)
        plot_top_gene_expression(counts_filt, results, metadata)
    except Exception as e:
        print(f"  Warning: visualization failed: {e}")
    
    # Export
    print("\n[6/6] Exporting results")
    prefix = os.path.join(output_dir, f'deseq2_{design_column}')
    export_results(results, metadata, prefix)
    
    print("\n" + "=" * 60)
    print("Pipeline complete.")
    print("=" * 60)
    
    return results

# Run the complete pipeline
final_results = full_deseq2_pipeline(
    counts_df,
    metadata_df,
    design_column='condition',
    reference_level='control',
    output_dir='pipeline_output'
)

Common Pitfalls and How to Avoid Them

python
def check_common_mistakes(counts_df, metadata_df, results):
    """
    Diagnostic checks for common RNA-seq analysis errors.
    
    Checks:
    1. Sample name consistency between counts and metadata
    2. p-value histogram shape (spike near 0, roughly uniform elsewhere)
    3. Fraction of genes removed by independent filtering (padj = NA)
    """
    print("\nDiagnostic checks:")
    
    # 1. Sample name consistency
    counts_samples = set(counts_df.columns)
    metadata_samples = set(metadata_df.index)
    if counts_samples != metadata_samples:
        print("  [FAIL] Sample mismatch between counts and metadata!")
        print(f"    In counts not metadata: {counts_samples - metadata_samples}")
        print(f"    In metadata not counts: {metadata_samples - counts_samples}")
    else:
        print("  [OK] Sample names consistent between counts and metadata")
    
    # 2. p-value histogram
    pvals_clean = results['pvalue'].dropna()
    
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.hist(pvals_clean, bins=50, edgecolor='black', linewidth=0.5, color='steelblue')
    ax.set_xlabel('p-value')
    ax.set_ylabel('Number of genes')
    ax.set_title('P-value Histogram\n(should have spike at 0, flat elsewhere)')
    
    # Check for anti-conservative peak at 0
    frac_small = (pvals_clean < 0.05).mean()
    ax.text(0.5, 0.95, f'Fraction p<0.05: {frac_small:.1%}',
            transform=ax.transAxes, ha='center', va='top')
    
    plt.tight_layout()
    plt.savefig('pvalue_histogram.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    if frac_small < 0.05:
        print(f"  [WARNING] Only {frac_small:.1%} of genes have p < 0.05.")
        print("    This may indicate insufficient power or very few DE genes.")
    elif frac_small > 0.5:
        print(f"  [WARNING] {frac_small:.1%} of genes have p < 0.05.")
        print("    This may indicate batch effects or a sample swap confounding the comparison.")
    else:
        print(f"  [OK] {frac_small:.1%} of genes have p < 0.05 (reasonable signal)")
    
    # 3. NA in results (from independent filtering)
    n_na = results['padj'].isna().sum()
    n_total = len(results)
    print(f"  [INFO] Genes filtered by independent filtering: {n_na:,} / {n_total:,} "
          f"({n_na/n_total:.1%})")
    if n_na / n_total > 0.5:
        print("    [WARNING] >50% of genes filtered — consider less aggressive pre-filtering")

check_common_mistakes(counts_filtered, metadata_df, results)

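The intuition behind the p-value histogram check can be reproduced with simulated data (a standalone sketch, independent of the pipeline objects above): null genes yield uniform p-values, while true effects pile up near zero, so the fraction below 0.05 exceeds the 5% expected under a global null.

```python
import numpy as np

rng = np.random.default_rng(0)

# 9,000 null genes: p-values uniform on [0, 1]
null_p = rng.uniform(0, 1, 9000)
# 1,000 true-effect genes: p-values concentrated near 0 (Beta(0.1, 1) as a toy model)
de_p = rng.beta(0.1, 1, 1000)

pvals = np.concatenate([null_p, de_p])
frac_small = (pvals < 0.05).mean()

# Expected: 5% of nulls plus most DE genes fall below 0.05,
# roughly 0.9 * 0.05 + 0.1 * 0.05**0.1 ≈ 0.12
print(f"Fraction p < 0.05: {frac_small:.1%}")
```

If the histogram instead rises toward 1 (conservative) or is uniformly inflated across all bins, that usually signals a dispersion or batch problem rather than genuine signal.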
Summary: The Complete Workflow

Raw count matrix (genes × samples)
    ↓ filter_low_counts()
Pre-filtered counts
    ↓ pca_plot() — check for outliers, batch effects
    ↓ run_deseq2()
DESeqDataSet with size factors + dispersion estimates
    ↓ extract_results()
Results table: log2FC, SE, stat, pvalue, padj
    ↓ plot_ma(), plot_volcano(), plot_de_heatmap()
Visualizations
    ↓ run_ora() / run_gsea_preranked()
Enriched pathways
    ↓ export_results()
TSV files + gene lists for downstream analysis

The pydeseq2 implementation closely reproduces R DESeq2 results (minor numerical differences can occur between versions), enabling Python-native bioinformatics workflows without an R dependency. For production pipelines, wrap these steps in a Snakemake rule to parallelize across multiple contrasts and re-run automatically when inputs change.
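A Snakemake rule wrapping the pipeline might look like the following sketch. The file paths, module name, and `CONTRASTS` list are placeholders, and the output path follows the `{prefix}_summary.json` naming used by `export_results`:

```python
# Snakefile (sketch) — paths and the CONTRASTS list are hypothetical
CONTRASTS = ["treated_vs_control"]  # one wildcard value per comparison

rule all:
    input:
        expand("results/{contrast}/deseq2_condition_summary.json",
               contrast=CONTRASTS)

rule deseq2_contrast:
    input:
        counts="data/counts.tsv",
        metadata="data/metadata.tsv"
    output:
        "results/{contrast}/deseq2_condition_summary.json"
    run:
        from pipeline import full_deseq2_pipeline  # module name assumed
        full_deseq2_pipeline(
            input.counts, input.metadata,
            design_column="condition",
            reference_level="control",
            output_dir=f"results/{wildcards.contrast}",
        )
```

Because each contrast is an independent rule instance, `snakemake --cores 4` runs them in parallel, and editing the counts or metadata files triggers only the affected re-runs.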