
ObsCure MNIST

Missing those classic handwritten digits.

Banner image: post office scene from "The Tin Drum" (Volker Schlöndorff, 1979).

The origins: in the late 1980s, the US Census Bureau was interested in automatic digitization of handwritten census forms. Over time MNIST became the de facto starting point for evaluating new machine learning architectures and training techniques. More on Wikipedia

This repo contains the source of a submission to the Kaggle MNIST competition. This project grew out of a few productive rainy afternoons spent refining a compact, ResNet-inspired PyTorch model with MixUp/CutMix, GhostBatchNorm, Lookahead, AMP, and SWA.

Quick start

  1. Install requirements:
     pip install -r requirements.txt
  2. Train:
     python3 ObsCure_MNIST.py

Components

  • Model: WideSmallResNet (BasicBlock) with GhostBatchNorm2d (see the first sketch after this list).
  • Augmentation: strong/light transforms, MixUp and CutMix utilities (see the second sketch after this list).
  • Optimizers: SGD base optimizer wrapped by a Lookahead implementation.
  • Mixed precision: torch.amp GradScaler + autocast used in training.
  • SWA: torch.optim.swa_utils AveragedModel and SWALR used in later epochs.
  • Utilities: ensemble evaluation, test-time augmentation (TTA), BN update for SWA.
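GhostBatchNorm is the least standard of these pieces, so here is a minimal sketch of the idea, assuming the repo's GhostBatchNorm2d follows the usual recipe of normalizing fixed-size "ghost" chunks of a large batch (the actual layer in ObsCure_MNIST.py may differ in its details):

```python
import torch
import torch.nn as nn

class GhostBatchNorm2d(nn.Module):
    """BatchNorm2d applied over virtual 'ghost' sub-batches of a larger batch."""

    def __init__(self, num_features: int, ghost_batch_size: int = 32):
        super().__init__()
        self.ghost_batch_size = ghost_batch_size  # GHOST_BATCH in the configuration below
        self.bn = nn.BatchNorm2d(num_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # During training, normalize each ghost sub-batch with its own statistics,
        # which adds noise similar to training with a genuinely small batch size.
        if self.training and x.size(0) > self.ghost_batch_size:
            chunks = x.split(self.ghost_batch_size, dim=0)
            return torch.cat([self.bn(chunk) for chunk in chunks], dim=0)
        # At eval time (or for small batches) fall back to plain BatchNorm behavior.
        return self.bn(x)
```

The point is to keep the regularizing noise of small-batch BatchNorm statistics while still training at the large BATCH_SIZE used for throughput.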
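The MixUp/CutMix utilities can likewise be sketched as batch-level helpers. The function names and sampling details here are illustrative rather than the repo's exact code; both return the mixed images, the two label sets, and the mixing coefficient used to interpolate the loss:

```python
import numpy as np
import torch

def mixup_data(x, y, alpha=0.091):
    """Blend random pairs of images: x_mix = lam * x + (1 - lam) * x[perm]."""
    lam = float(np.random.beta(alpha, alpha)) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1.0 - lam) * x[index]
    return mixed_x, y, y[index], lam

def cutmix_data(x, y, beta=0.35):
    """Paste a random rectangle from a shuffled batch; lam = fraction of pixels kept."""
    lam = float(np.random.beta(beta, beta))
    index = torch.randperm(x.size(0), device=x.device)
    h, w = x.size(2), x.size(3)
    cut_h, cut_w = int(h * (1.0 - lam) ** 0.5), int(w * (1.0 - lam) ** 0.5)
    cy, cx = np.random.randint(h), np.random.randint(w)
    y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, h)
    x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, w)
    x = x.clone()
    x[:, :, y1:y2, x1:x2] = x[index, :, y1:y2, x1:x2]
    lam = 1.0 - (y2 - y1) * (x2 - x1) / (h * w)  # recompute from the clipped box
    return x, y, y[index], lam

def mixed_loss(criterion, preds, y_a, y_b, lam):
    """Loss interpolated between the two label sets."""
    return lam * criterion(preds, y_a) + (1.0 - lam) * criterion(preds, y_b)
```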

Architecture

Architecture diagram (image).

Configuration

| Parameter | Value | Purpose / Explanation |
| --- | --- | --- |
| DEVICE | torch.device("cuda" if torch.cuda.is_available() else "cpu") | Where tensors and the model run: GPU if available, otherwise CPU. |
| SEED | 42 | Random seed for reproducibility (controls RNG for torch, numpy, etc.). |
| BATCH_SIZE | 256 | Number of samples per training batch. Larger batches increase throughput but use more memory. |
| GHOST_BATCH | 32 | Mini-batch size used inside Ghost Batch Normalization to simulate smaller-batch statistics within a large BATCH_SIZE. |
| GHOST_BN_UPDATE_BATCH | 512 | Batch size used when updating Ghost BatchNorm running statistics (a larger aggregate for more stable updates). |
| NUM_CLASSES | 10 | Number of target classes (MNIST digits 0–9). |
| INITIAL_EPOCHS | 2 | Initial-phase epochs (e.g., warmup or base training). |
| EXTRA_EPOCHS | 2 | Additional training epochs (e.g., fine-tuning or further training). |
| TOTAL_EPOCHS | INITIAL_EPOCHS + EXTRA_EPOCHS (4) | Total number of training epochs. |
| RESUME | False | Whether to resume training from a checkpoint. |
| CHECKPOINT_PATH | "checkpoint_epoch100.pth" | Checkpoint file to load when RESUME is True. |
| MIXPROB | 0.102 | Overall probability of applying a mix augmentation (MixUp or CutMix) to a batch. |
| MIXUP_ALPHA | 0.091 | Alpha parameter of the Beta distribution used to sample the MixUp interpolation coefficient. |
| CUTMIX_BETA | 0.35 | Beta parameter of the Beta distribution used to sample CutMix area ratios. |
| USE_CUTMIX_PROB | 0.8 | Given that a mix augmentation is applied, probability of choosing CutMix over MixUp (0.8 → 80% CutMix, 20% MixUp). |
| FINAL_FRAC | 0.25 | Fraction of training near the end reserved for special schedules/behavior (commonly a final LR fraction or the final-epochs fraction for SWA). |
| SWA_START | int(TOTAL_EPOCHS * 0.80) → 3 | Epoch at which Stochastic Weight Averaging (SWA) starts. With TOTAL_EPOCHS=4, SWA starts at epoch 3. |
| BASE_LR | 0.01 | Base learning rate for the optimizer/scheduler. |
| RESUME_LR | 5e-4 | Learning rate used when resuming training from a checkpoint. |
| ETA_MIN | 1e-6 | Minimum learning rate for cosine/annealing schedulers. |
| MOMENTUM | 0.9 | Momentum term for the SGD optimizer. |
| WEIGHT_DECAY | 1.8e-5 | L2 regularization coefficient applied to the weights. |
| TTA_RUNS | 5 | Number of test-time augmentation runs averaged for evaluation. |
| SAVE_PREFIX | f"mnist_seed{SEED}" → "mnist_seed42" | Prefix used when saving models/checkpoints/outputs. |
| NUM_WORKERS | 4 | Number of subprocesses for data loading (DataLoader num_workers). |
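To show how these values plug into training, below is a rough skeleton combining SGD, torch.amp mixed precision, and SWA as the components section describes. WideSmallResNet, the Lookahead wrapper, and train_loader stand in for repo-defined code, and the MixUp/CutMix branch, checkpointing, and LR scheduling details are omitted, so treat this as a sketch rather than the actual ObsCure_MNIST.py loop (it also assumes a recent PyTorch with the torch.amp API):

```python
import torch
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TOTAL_EPOCHS, SWA_START = 4, 3                          # values from the table above
BASE_LR, MOMENTUM, WEIGHT_DECAY = 0.01, 0.9, 1.8e-5

model = WideSmallResNet(num_classes=10).to(DEVICE)      # repo-defined model (assumed signature)
optimizer = torch.optim.SGD(model.parameters(), lr=BASE_LR,
                            momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
# In the repo, this SGD optimizer is additionally wrapped by a Lookahead implementation.
scaler = torch.amp.GradScaler(DEVICE.type)              # mixed-precision loss scaling
swa_model = AveragedModel(model)                        # running average of the weights
swa_scheduler = SWALR(optimizer, swa_lr=BASE_LR)

for epoch in range(TOTAL_EPOCHS):
    model.train()
    for images, labels in train_loader:                 # repo-defined DataLoader (assumed)
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast(DEVICE.type):           # autocast forward pass
            loss = torch.nn.functional.cross_entropy(model(images), labels)
        scaler.scale(loss).backward()                   # scaled backward, then step
        scaler.step(optimizer)
        scaler.update()
    if epoch >= SWA_START:                              # average weights in the final epochs
        swa_model.update_parameters(model)
        swa_scheduler.step()

update_bn(train_loader, swa_model, device=DEVICE)       # refresh BN stats for the SWA model
```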

Hyperparameter optimization was performed with Optuna.
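The study itself is not part of the repo; a minimal sketch of how such a search could be set up with Optuna, using an illustrative search space and a hypothetical train_and_evaluate helper that runs a short training and returns validation accuracy:

```python
import optuna

def objective(trial: optuna.Trial) -> float:
    # Illustrative search space loosely mirroring the configuration table above.
    params = {
        "mixprob": trial.suggest_float("MIXPROB", 0.05, 0.5),
        "mixup_alpha": trial.suggest_float("MIXUP_ALPHA", 0.05, 0.4),
        "cutmix_beta": trial.suggest_float("CUTMIX_BETA", 0.1, 1.0),
        "base_lr": trial.suggest_float("BASE_LR", 1e-3, 1e-1, log=True),
        "weight_decay": trial.suggest_float("WEIGHT_DECAY", 1e-6, 1e-3, log=True),
    }
    # Hypothetical helper: trains briefly with these values and returns validation accuracy.
    return train_and_evaluate(**params)

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=50)
print(study.best_params)
```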

Output

  • Best checkpoints: PREFIX_ckpt_epoch{N}.pth and PREFIX_swa_epoch{N}.pth
  • History JSON: results/PREFIX_history.json
  • Plots: results/PREFIX_train_loss.png and results/PREFIX_acc.png

Training results

Training loss and accuracy plots (image).

Best overall checkpoint:

train loss 0.0675, train accuracy 97.188%
validation accuracy 99.78%

Note

Ranked 7th out of 1,000+ submissions (within the top 4 best scores) on the Kaggle MNIST leaderboard, with 99.94% accuracy on the platform's validation dataset.


Kaggle Expert badge (image).

License

The source code is provided under the CC0 license. See the LICENSE file for details.
