An adversarial autoencoder for bulk RNA-seq batch effect correction. It learns a latent representation that preserves biological signal (optional supervised head) while discouraging batch-specific variation via a gradient reversal adversary. Outputs a batch-corrected expression matrix (logCPM scale) and optional latent embedding plus visual diagnostics.
The encoder and decoder are multi-layer perceptrons with LayerNorm, SiLU activations, and dropout. A gradient reversal layer feeds the latent code to the batch adversary to discourage batch-specific signal, while an optional supervised head maintains biological labels when provided.
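The gradient reversal trick is small enough to show inline. A minimal PyTorch sketch, assuming the standard formulation (the `GradReverseLayer` in `models/ae.py` may differ in detail):

```python
import torch


class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; negates (and scales) gradients in the backward pass."""

    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reversing the gradient means that minimising the adversary's
        # batch-classification loss pushes the encoder to REMOVE
        # batch-predictive signal from the latent code.
        return grad_output.neg() * ctx.lambd, None


class GradReverseLayer(torch.nn.Module):
    def __init__(self, lambd: float = 1.0):
        super().__init__()
        self.lambd = lambd  # updated over training by the lambda schedule

    def forward(self, x):
        return GradReverse.apply(x, self.lambd)
```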
- Counts → library-size normalisation → CPM → log1p
- Optional highly-variable gene (HVG) selection
- Per‑gene z-score standardisation for stable optimisation, later inverse-transformed back to logCPM (see the preprocessing sketch after this list)
- Autoencoder + gradient reversal batch classifier (domain adversarial)
- Optional supervised label head to preserve biology (e.g. condition)
- Flexible adversarial lambda schedules: linear | constant | sigmoid | adaptive
- Multiple reconstruction losses: MSE | MAE | Huber
- Optional mixed precision (AMP) on CUDA
- Early stopping + LR schedulers (plateau / cosine)
- Gradient accumulation for large models / small GPUs
- Weights & Biases logging (offline or online)
- Post-training PCA + boxplot visualisations (before/after correction)
- Latent silhouette diagnostics (batch & label)
- Robust input orientation detection (genes×samples or samples×genes)
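As referenced above, a minimal sketch of the preprocessing pipeline, assuming a samples × genes NumPy array (the script's own implementation may differ, e.g. in HVG selection and epsilon handling):

```python
import numpy as np

def preprocess(counts: np.ndarray):
    """counts: samples x genes raw count matrix.

    Returns the z-scored logCPM matrix plus the per-gene mean/std needed to
    invert the standardisation back to logCPM after correction.
    """
    lib_size = counts.sum(axis=1, keepdims=True)  # library size per sample
    cpm = counts / lib_size * 1e6                 # counts per million
    logcpm = np.log1p(cpm)                        # log1p for variance stabilisation
    mu = logcpm.mean(axis=0)                      # per-gene mean
    sd = logcpm.std(axis=0) + 1e-8                # per-gene std (guard against zero)
    z = (logcpm - mu) / sd                        # standardised AE input
    return z, mu, sd                              # later: logcpm_hat = z_hat * sd + mu
```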
| Path | Description |
| --- | --- |
| `NN_batch_correct.py` | Main training + correction script |
| `models/` | Model components (adversarial AE) and factory |
| `models/ae.py` | `AEBatchCorrector`, `GradReverseLayer`, `ResidualBlock`, `make_mlp` |
| `models/factory.py` | Helper to build AE or VAE+Attention models from CLI args |
| `data/` | Example inputs (`data/bulk_counts.csv`, `data/sample_meta.csv`, `data/test_counts.csv`, `data/test_meta.csv`) |
| `artifacts/outputs/` | Generated demo outputs (`corrected_logcpm.csv`, `latent*.csv`, `pca_*.png`, `logCPM_boxplots.png`) |
| `artifacts/checkpoints/` | Example trained weights (`trained_model.pt`, `model_best.pt`, `best_model.pt`) |
| `assets/` | README assets (logo, PCA panel) |
| `visualise.py` | Standalone PCA/boxplot + architecture diagram utilities |
- Either genes in rows & samples in columns (common) OR samples in rows & genes in columns.
- Index (first column) must be feature IDs (genes) or sample IDs; orientation is auto-detected, or override with `--genes_in_rows`.
Must include at minimum:
- `sample` (configurable via `--sample_col`)
- `batch` (configurable via `--batch_col`)

Optional:

- Biological label column (e.g. `condition`) for supervised preservation (`--label_col`).
Samples present in both files are intersected. If fewer than two overlapping samples are found, an orientation / whitespace rescue pass is attempted before failing.
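A hypothetical sketch of that alignment step (`align_samples` is illustrative; the script's actual rescue logic may differ):

```python
import pandas as pd

def align_samples(counts: pd.DataFrame, meta: pd.DataFrame, sample_col: str = "sample"):
    # Normalise IDs first: stray whitespace is a common cause of "no overlap".
    counts.columns = counts.columns.astype(str).str.strip()
    ids = meta[sample_col].astype(str).str.strip()
    shared = counts.columns.intersection(ids)
    if len(shared) < 2:
        # Rescue attempt: the matrix may be samples x genes, in which case
        # the sample IDs live in the index rather than the columns.
        counts = counts.T
        counts.columns = counts.columns.astype(str).str.strip()
        shared = counts.columns.intersection(ids)
    if len(shared) < 2:
        raise ValueError("fewer than 2 overlapping samples between counts and metadata")
    return counts[shared], meta[ids.isin(shared)]
```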
Create & activate a Python environment (>=3.10 recommended). Then:
```
pip install -r requirements.txt
```

Optional extras:

- Set `WANDB_MODE=offline` to avoid network usage.
- Model code has been extracted from `NN_batch_correct.py` into `models/ae.py` and is re-exported via `models/__init__.py`; the CLI behavior is unchanged.
- A lightweight factory (`models/factory.py`) centralizes model construction for AE vs. VAE+Attention, keeping the training script focused on data and orchestration.
Minimal unsupervised run (only batches):
```
python NN_batch_correct.py `
  --counts data/bulk_counts.csv `
  --metadata data/sample_meta.csv `
  --genes_in_rows `
  --out_corrected artifacts/outputs/corrected_logcpm.csv
```

With supervised label preservation + latent export + model save + visualisations:
```
python NN_batch_correct.py `
  --counts data/bulk_counts.csv `
  --metadata data/sample_meta.csv `
  --genes_in_rows `
  --label_col condition `
  --hvg 5000 `
  --latent_dim 32 `
  --enc_hidden 1024,256 `
  --dec_hidden 256,1024 `
  --adv_hidden 128 `
  --sup_hidden 64 `
  --adv_lambda_schedule adaptive `
  --adv_weight 1.0 `
  --sup_weight 1.0 `
  --epochs 200 `
  --batch_size 64 `
  --dropout 0.1 `
  --recon_loss mse `
  --out_corrected artifacts/outputs/corrected_logcpm.csv `
  --out_latent artifacts/outputs/latent.csv `
  --save_model model_best.pt `
  --generate_viz `
  --viz_hvg_top 2000
```

Enable AMP (CUDA) + cosine scheduler + W&B logging:
```
$env:WANDB_MODE="offline"   # optional
python NN_batch_correct.py `
  --counts data/bulk_counts.csv `
  --metadata data/sample_meta.csv `
  --genes_in_rows `
  --adv_lambda_schedule sigmoid `
  --scheduler cosine `
  --amp `
  --use_wandb `
  --out_corrected artifacts/outputs/corrected_logcpm.csv
```

| Flag | Meaning |
| --- | --- |
| `--counts` / `--metadata` | Input files |
| `--genes_in_rows` | Set if counts file is genes×samples |
| `--hvg N` | Keep top-N variable genes (0 = all) |
| `--label_col` | Enable supervised biology head |
| `--adv_lambda_schedule` | `linear` \| `constant` \| `sigmoid` \| `adaptive`; adaptive dynamically scales GRL lambda toward near-random adversary accuracy |
| `--adv_weight` / `--sup_weight` | Relative loss contribution scalars |
| `--recon_loss` | `mse` \| `mae` \| `huber` |
| `--grad_accum` | Gradient accumulation steps |
| `--scheduler` | `none` \| `plateau` \| `cosine` |
| `--amp` | Mixed precision (GPU) |
| `--expected_batches` | Assert exact number of batches present |
| `--label_values` | Comma list of required labels per batch (quality control) |
| `--out_corrected` | Output corrected matrix (logCPM) |
| `--out_latent` | (Optional) latent embedding CSV |
| `--save_model` | Save trained weights & metadata |
| `--generate_viz` | Trigger PCA + boxplots (uses `visualise.py` functions) |
Run `python NN_batch_correct.py -h` for the full list.
- Corrected matrix: samples × genes (logCPM scale) ready for downstream analysis.
- Latent embedding: `z1..zK` columns (optional).
- Saved model `.pt`: includes state_dict, class labels, gene list.
- PCA plots: pre & post correction (`--generate_viz`).
- Boxplots: logCPM distributions before vs after, per batch.
- Silhouette scores (printed): a lower batch score and a higher label score indicate good disentanglement.
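A minimal sketch of how those diagnostics can be computed with scikit-learn (`latent_silhouettes` is an illustrative name, not the script's API):

```python
import numpy as np
from sklearn.metrics import silhouette_score

def latent_silhouettes(Z: np.ndarray, batches, labels=None) -> dict:
    """Z: samples x latent_dim embedding; batches/labels: per-sample metadata.

    After correction we want silhouette(batch) to drop (batches well mixed)
    while silhouette(label) stays high (biology preserved).
    """
    scores = {"batch": silhouette_score(Z, batches)}
    if labels is not None:
        scores["label"] = silhouette_score(Z, labels)
    return scores
```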
You can run it standalone (e.g. after modifying parameters):
```
python visualise.py `
  --counts data/bulk_counts.csv `
  --metadata data/sample_meta.csv `
  --genes_in_rows `
  --corrected artifacts/outputs/corrected_logcpm.csv `
  --hvg_top 2000
```

It will regenerate PCA plots and boxplots. Without `--corrected`, only "before" plots are produced.
An architecture diagram is generated automatically as `nn_architecture.png` the first time `visualise.py` is executed.
- linear: ramps 0 → adv_weight across first ~1/3 of epochs
- constant: fixed adv_weight
- sigmoid: slow start, sharp middle growth, saturates
- adaptive: adjusts each epoch to keep adversary accuracy near random; prevents over/under powering the encoder
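Illustrative versions of the four schedules, assuming these shapes and step sizes (the script's exact constants may differ):

```python
import math

def adv_lambda(schedule: str, epoch: int, n_epochs: int, adv_weight: float,
               current: float = 0.0, adversary_acc: float | None = None,
               n_batches: int = 2) -> float:
    if schedule == "constant":
        return adv_weight
    if schedule == "linear":
        # Ramp 0 -> adv_weight over the first ~1/3 of training.
        return adv_weight * min(1.0, epoch / max(1, n_epochs // 3))
    if schedule == "sigmoid":
        # DANN-style: slow start, sharp growth mid-training, saturation.
        p = epoch / max(1, n_epochs)
        return adv_weight * (2.0 / (1.0 + math.exp(-10.0 * p)) - 1.0)
    if schedule == "adaptive":
        # Nudge lambda so adversary accuracy hovers near chance (1 / n_batches):
        # above chance -> strengthen reversal; below -> relax it.
        if adversary_acc is None:
            return current
        step = 0.05 * (adversary_acc - 1.0 / n_batches)
        return max(0.0, current + step * adv_weight)
    raise ValueError(f"unknown schedule: {schedule}")
```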
Use --seed (default 42). Seed applies to Python, NumPy, and Torch (CPU & CUDA). For completely deterministic CUDA runs you may need extra backend flags (not set here to retain performance).
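A sketch of the seeding this implies (the script's own helper may differ):

```python
import random

import numpy as np
import torch

def set_seed(seed: int = 42) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)            # seeds the CPU RNG (and CUDA in recent PyTorch)
    torch.cuda.manual_seed_all(seed)   # explicit, in case of multiple GPUs
    # For fully deterministic CUDA runs (at a performance cost) one could add:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
```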
- Inspect and clean metadata (consistent sample IDs)
- Run a baseline correction (unsupervised) & inspect PCA / boxplots
- Add label supervision if biological grouping should be preserved
- Try alternative lambda schedules (adaptive vs linear)
- Tune latent dimension / hidden widths (watch overfitting via val loss)
- Export latent embedding for downstream clustering / differential analyses
Currently the script trains & corrects in one pass. To reuse a saved model on new samples:
- Load the `.pt` file.
- Apply the same preprocessing: library-size normalise → log1p → restrict to the training gene list → z-score using the training statistics (not yet stored separately).
- Run the encoder → decoder forward pass, then invert the standardisation to obtain corrected logCPM.
(Enhancement idea: save scaler mean/std in checkpoint for direct reuse.)
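Putting those steps together, a hypothetical sketch; the `.encoder` / `.decoder` attributes and the availability of training `mu`/`sd` are assumptions to verify against the actual checkpoint:

```python
import numpy as np
import torch

def correct_new_samples(model, counts_new: np.ndarray,
                        mu: np.ndarray, sd: np.ndarray) -> np.ndarray:
    """counts_new: samples x training-genes raw counts, with columns already
    restricted and ordered to match the checkpoint's gene list.

    Assumes the model exposes .encoder / .decoder submodules and that mu/sd
    are the per-gene logCPM statistics from training (not yet in the checkpoint).
    """
    # Same preprocessing as training: library-size normalise -> CPM -> log1p -> z-score.
    cpm = counts_new / counts_new.sum(axis=1, keepdims=True) * 1e6
    z_in = (np.log1p(cpm) - mu) / sd

    model.eval()
    with torch.no_grad():
        latent = model.encoder(torch.as_tensor(z_in, dtype=torch.float32))
        recon = model.decoder(latent).numpy()

    return recon * sd + mu  # invert the standardisation back to logCPM
```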
See requirements.txt (PyTorch CPU version shown; install CUDA variant as needed).
- Shape / overlap errors: confirm sample IDs exactly match between counts & metadata (case, whitespace).
- Few overlapping samples: verify orientation; try adding `--genes_in_rows`.
- Adversary dominates (batch accuracy ~1.0, poor reconstruction): try `--adv_lambda_schedule adaptive` or lower `--adv_weight`.
- Biological separation lost: add / tune `--label_col` and increase `--sup_weight` moderately.
- Slow training: reduce hidden sizes or HVG count; enable `--amp` on GPU.
MIT (see header in source). Cite this repository if you use it in a publication.
Inspired by domain adversarial training paradigms (Ganin & Lempitsky) adapted for bulk RNA-seq batch correction.
Happy correcting!

