3 changes: 2 additions & 1 deletion .claude-plugin/marketplace.json
@@ -130,7 +130,8 @@
"./10-optimization/flash-attention",
"./10-optimization/gguf",
"./10-optimization/gptq",
"./10-optimization/hqq"
"./10-optimization/hqq",
"./10-optimization/ml-training-recipes"
]
},
{
319 changes: 319 additions & 0 deletions 10-optimization/ml-training-recipes/SKILL.md
@@ -0,0 +1,319 @@
---
name: ml-training-recipes
description: Battle-tested PyTorch training recipes for all domains — LLMs, vision, diffusion, medical imaging, protein/drug discovery, spatial omics, genomics. Covers training loops, optimizer selection (AdamW, Muon), LR scheduling, mixed precision, debugging, and systematic experimentation. Use when training or fine-tuning neural networks, debugging loss spikes or OOM, choosing architectures, or optimizing GPU throughput.
version: 1.0.0
author: dailycafi
license: MIT
tags: [PyTorch, Training, Optimization, LLM, Vision, Diffusion, Biomedical, Muon, AdamW, Debugging]
dependencies: [torch>=2.0.0]
---

# ML Training Recipes

Battle-tested patterns for PyTorch training across domains. Drawn from production codebases
(Karpathy's autoresearch/nanochat, torchvision, HuggingFace) and modern training practice.

## Reference files (read when needed)

- `references/architecture.md` — Transformer/LLM architecture code patterns, weight init
- `references/optimizers.md` — Muon, AdamW hybrid, per-group LR, compiled optimizer steps
- `references/domain-specific.md` — Vision, diffusion, contrastive, distributed, checkpointing, data loading
- `references/scaling-and-selection.md` — Scaling laws, compute budget tables, decision trees, DGX Spark
- `references/biomedical.md` — Drug discovery, protein models, medical imaging, genomics, clinical NLP
- `references/experiment-loop.md` — Autonomous experiment loop (autoresearch keep/discard/revert)

---

## Architecture Selection

Pick the right model by **data type** and **data scale**:

| Data Type | < 10K samples | 10K-100K | > 100K |
|-----------|--------------|----------|--------|
| **Images** | Pretrained CNN + fine-tune | Fine-tune ViT or CNN | ViT from scratch |
| **Text (gen)** | Few-shot prompting | Fine-tune GPT/LLaMA (LoRA) | Pretrain from scratch |
| **Tabular** | XGBoost/LightGBM | Still XGBoost | Neural viable |
| **Audio** | Pretrained Whisper | Fine-tune AST | Train from scratch |
| **Molecules** | Pretrained GNN | Fine-tune molecular LM | Train GNN from scratch |
| **Proteins** | ESM-2 embeddings + head | Fine-tune ESM-2 | Train protein LM |
| **Medical img** | Pretrained CNN | nnU-Net (auto-config) | Swin-UNETR / MedSAM |

**Key principle**: architecture matters less than training recipe at equal compute. A well-tuned
ResNet beats a poorly-tuned ViT (ref: "ResNet Strikes Back", Wightman 2021).

For biomedical domains, see `references/biomedical.md`.
For sequence model selection and compute planning, see `references/scaling-and-selection.md`.

---

## Scaling Laws

### Chinchilla rule (Hoffmann et al., 2022)

Compute-optimal training: **~20 tokens per parameter**.

| Model Size | Compute-Optimal (~20 tokens/param) | Inference-Optimal (~100 tokens/param) |
|-----------|----------------|--------------------------|
| 125M | 2.5B tokens | 12.5B tokens |
| 1B | 20B tokens | 100B tokens |
| 7B | 140B tokens | 700B tokens |

**FLOPs ≈ 6 × N × D** (N=params, D=tokens). Data repetition limit: ~4 epochs before diminishing returns.
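
As a quick sanity check on the table and the FLOPs rule, the arithmetic can be sketched in a few lines. `chinchilla_budget` is a hypothetical helper name, and the 20 tokens-per-parameter default is the rule of thumb stated above, not a fitted constant:

```python
def chinchilla_budget(n_params: float, tokens_per_param: float = 20.0) -> tuple[float, float]:
    """Return (training tokens, training FLOPs) for a model with n_params parameters."""
    tokens = tokens_per_param * n_params
    flops = 6.0 * n_params * tokens  # FLOPs ≈ 6 × N × D
    return tokens, flops


tokens, flops = chinchilla_budget(1e9)  # 1B-parameter model
print(f"{tokens:.3g} tokens, {flops:.3g} FLOPs")
```

Passing `tokens_per_param=100` gives the inference-optimal column instead.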

---

## Training Loop

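```python
import gc, math, time, torch

# Assumes model, optimizer, train_loader, update_lr, progress, and done are defined elsewhere.
torch.manual_seed(42)
torch.set_float32_matmul_precision("high")  # TF32 on Ampere+
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)

# total_batch_size is counted in tokens; batch_size * seq_len is tokens per micro-batch
grad_accum_steps = total_batch_size // (batch_size * seq_len)
step = 0
x, y = next(train_loader)  # prefetch the first batch

while not done:  # caller sets `done` from a step or time budget
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            loss = model(x, y)
        (loss / grad_accum_steps).backward()  # scale so accumulated grads average
        x, y = next(train_loader)  # load the next batch while the GPU works

    update_lr(optimizer, progress)
    optimizer.step()
    model.zero_grad(set_to_none=True)  # frees memory vs zeroing

    loss_val = loss.item()  # loss of the last micro-batch
    if not math.isfinite(loss_val) or loss_val > 100:  # fast-fail on NaN or divergence
        print("FAIL: loss exploded")
        raise SystemExit(1)

    torch.cuda.synchronize()
    dt = time.time() - t0  # wall-clock time of this step, for budgets and MFU
    if step == 0:
        gc.collect(); gc.freeze(); gc.disable()  # avoid ~500ms GC stalls
    step += 1
```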

### Key principles

- **Gradient clipping**: `torch.nn.utils.clip_grad_norm_(params, 1.0)` is near-universal for Transformers.
  Exception: the Muon optimizer normalizes updates via orthogonalization, so clipping is optional.
- **Tensor Core alignment**: keep batch size and hidden dims at multiples of 8 (bf16), or 64 on A100.
- **Time-based budgets** make experiments comparable across hardware.
- **`torch.backends.cudnn.benchmark = True`** for fixed-size vision inputs.
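
For the Tensor Core alignment rule, one common approach is to round awkward sizes up to the next multiple. A minimal sketch follows; `round_up` is a hypothetical helper, and the GPT-2 vocabulary size of 50257 is used only as an illustration:

```python
def round_up(value: int, multiple: int) -> int:
    """Smallest multiple of `multiple` that is >= value."""
    return ((value + multiple - 1) // multiple) * multiple


padded_vocab = round_up(50257, 64)  # pad the embedding table, not the tokenizer
hidden_dim = round_up(1000, 8)      # already aligned values are returned unchanged
print(padded_vocab, hidden_dim)
```

The padded embedding rows never appear as targets, so they only cost a little memory.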

---

## Optimizer Configuration

Modern LLM training uses different optimizers per parameter group:

| Parameter Type | Optimizer | LR (base) | Weight Decay |
|---------------|-----------|-----------|--------------|
| 2D weight matrices | Muon | 0.04 | 0.2 |
| Token embeddings | AdamW | 0.6 × scale | 0.0 |
| Unembedding (lm_head) | AdamW | 0.004 × scale | 0.0 |
| Per-layer scalars | AdamW | 0.005 × scale | 0.0 |

**LR scaling by dimension**: `lr * (d_model / 768) ** -0.5` (the `scale` factor in the table above) keeps dynamics stable across model sizes.

### Rules of thumb

- Embeddings need higher LR (sparse updates). Never weight-decay embeddings.
- Weight decay scheduling: linearly decay WD to 0 over training.
- AdamW settings: β1=0.9, β2=0.95, eps=1e-10 (rather than the PyTorch default of 1e-8; the smaller eps prevents stale updates in bf16).

For Muon details (polar express orthogonalization, NorMuon), see `references/optimizers.md`.
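
The dimension-based LR scaling and the linear weight-decay schedule can be sketched as two small functions. Both names (`scaled_lr`, `weight_decay_at`) are hypothetical, and `progress` is assumed to run from 0 at the start of training to 1 at the end:

```python
def scaled_lr(base_lr: float, d_model: int, ref_dim: int = 768) -> float:
    """Scale a base LR by (d_model / ref_dim) ** -0.5."""
    return base_lr * (d_model / ref_dim) ** -0.5


def weight_decay_at(progress: float, base_wd: float) -> float:
    """Linearly decay weight decay from base_wd (progress=0) to 0 (progress=1)."""
    progress = min(max(progress, 0.0), 1.0)  # clamp so overshoot never goes negative
    return base_wd * (1.0 - progress)


print(scaled_lr(0.004, d_model=3072))       # 4x wider than the reference: LR halves
print(weight_decay_at(0.5, base_wd=0.2))    # halfway through training
```

In a real loop the results would be written into the matching optimizer param groups each step.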

---

## Learning Rate Scheduling

### Time-based (autoresearch style)

```python
def get_lr_multiplier(progress):  # progress = elapsed_time / time_budget
    if progress < warmup_ratio:
        return progress / warmup_ratio
    elif progress < 1.0 - warmdown_ratio:
        return 1.0
    else:
        cooldown = (1.0 - progress) / warmdown_ratio
        return cooldown + (1 - cooldown) * final_lr_frac
```

### Cosine decay

```python
import math

def get_lr(step, total_steps, max_lr, min_lr, warmup_steps):
    if step < warmup_steps:
        return max_lr * step / warmup_steps  # linear warmup from 0
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
```

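**WSD (Warmup-Stable-Decay)**: gaining traction. The LR stays constant after warmup and decays only in a final phase, so a run can be resumed or extended from any stable-phase checkpoint.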

### Guidance

- **Warmup**: 1-5% of training. Zero warmup is valid with Muon (autoresearch uses `WARMUP_RATIO=0.0`).
- **Warmdown**: 30-50% of training in LR decay. Matters more than warmup for final quality.
- **Final LR**: 0 or ~10% of peak. Zero is simpler.
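
The training loop earlier calls `update_lr(optimizer, progress)` without defining it. One possible shape for it is sketched below. This is an assumption about how such a helper could look, not the autoresearch implementation; the constants are illustrative, and the `initial_lr` key is a convention chosen here to remember each group's peak LR:

```python
from types import SimpleNamespace

# Illustrative values, not recommendations from any specific codebase.
WARMUP_RATIO, WARMDOWN_RATIO, FINAL_LR_FRAC = 0.0, 0.5, 0.0


def lr_multiplier(progress: float) -> float:
    """Linear warmup, constant plateau, then linear warmdown to FINAL_LR_FRAC."""
    if progress < WARMUP_RATIO:
        return progress / WARMUP_RATIO
    if progress < 1.0 - WARMDOWN_RATIO:
        return 1.0
    cooldown = max(1.0 - progress, 0.0) / WARMDOWN_RATIO
    return cooldown + (1.0 - cooldown) * FINAL_LR_FRAC


def update_lr(optimizer, progress: float) -> float:
    """Scale every param group's peak LR by the schedule multiplier."""
    mult = lr_multiplier(progress)
    for group in optimizer.param_groups:
        group.setdefault("initial_lr", group["lr"])  # remember the peak LR once
        group["lr"] = group["initial_lr"] * mult
    return mult


# Stand-in for a real optimizer: anything exposing `param_groups` works here.
fake_opt = SimpleNamespace(param_groups=[{"lr": 0.04}, {"lr": 0.004}])
for p in (0.25, 0.75, 1.0):
    update_lr(fake_opt, p)
    print(p, [g["lr"] for g in fake_opt.param_groups])
```

Because the multiplier is applied to a stored peak value rather than to the current LR, calling `update_lr` repeatedly does not compound.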

---

## Mixed Precision & Compilation

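```python
import os
# Must be set before importing torch. Older PyTorch releases read only the
# PYTORCH_CUDA_ALLOC_CONF spelling of this variable.
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"

import torch
torch.set_float32_matmul_precision("high")  # allow TF32 for fp32 matmuls
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
model = torch.compile(model, dynamic=False)  # `model` is assumed to exist already
```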

- **bf16** (Ampere+): same exponent as fp32, no loss scaling needed. Preferred over fp16.
- **fp16**: needs GradScaler. Use only on V100 or older.
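- `dynamic=False` makes the compiler specialize on static shapes, which allows the most aggressive optimization. Add `fullgraph=True` if the model has no graph breaks; it raises an error when one occurs.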
- First steps are slow (JIT) — exclude from timing.

---

## Memory & Performance

### Meta device init (large models)

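```python
with torch.device("meta"):
    model = GPT(config)  # parameters have shapes but no storage: zero memory
model.to_empty(device="cuda")  # allocate uninitialized storage on the GPU
model.init_weights()  # model-defined method; must initialize every parameter
```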

### MFU (Model FLOPs Utilization)

```python
achieved_flops = model_flops_per_token * batch_tokens / step_time
mfu = achieved_flops / gpu_peak_flops
# H100 SXM: 989.5 TFLOPS | A100: 312 | RTX 4090: 165
```

Good targets: >30% decent, >40% good, >50% excellent (single-GPU).
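
A worked example of the MFU formula, assuming the `6 × N` FLOPs-per-token approximation from the scaling-laws section (it ignores attention FLOPs). The run numbers are hypothetical and `mfu` is an illustrative function name:

```python
def mfu(n_params: float, tokens_per_step: int, step_time_s: float, peak_flops: float) -> float:
    """Model FLOPs Utilization as a fraction of the GPU's peak FLOP/s."""
    flops_per_token = 6 * n_params  # forward + backward, attention FLOPs ignored
    achieved = flops_per_token * tokens_per_step / step_time_s
    return achieved / peak_flops


# Hypothetical run: 1B params, 524,288 tokens per step, 8 s per step on one H100 SXM.
utilization = mfu(1e9, 524_288, 8.0, peak_flops=989.5e12)
print(f"MFU = {utilization:.1%}")
```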

### OOM solutions (in order)

1. Reduce `DEVICE_BATCH_SIZE`, increase `grad_accum_steps`
2. `PYTORCH_ALLOC_CONF=expandable_segments:True`
3. `model.zero_grad(set_to_none=True)`
4. Meta device init → `to_empty`
5. Activation checkpointing: `torch.utils.checkpoint.checkpoint()`
6. 8-bit optimizer (bitsandbytes): ~30% savings on optimizer states
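
Step 1 works because the optimizer sees the same number of tokens per update as long as the product stays fixed. A small sketch of that bookkeeping, with a hypothetical helper name and an illustrative token budget:

```python
def grad_accum_steps(total_batch_tokens: int, device_batch_size: int, seq_len: int) -> int:
    """Micro-steps needed so each optimizer step still sees total_batch_tokens."""
    tokens_per_micro_batch = device_batch_size * seq_len
    if total_batch_tokens % tokens_per_micro_batch:
        raise ValueError("total_batch_tokens must be divisible by device_batch_size * seq_len")
    return total_batch_tokens // tokens_per_micro_batch


TOTAL_BATCH_TOKENS = 2**19  # 524,288 tokens per optimizer step (illustrative)
for device_batch_size in (32, 16, 8):
    steps = grad_accum_steps(TOTAL_BATCH_TOKENS, device_batch_size, seq_len=2048)
    print(device_batch_size, steps)
```

Halving the device batch size doubles the accumulation steps, trading peak activation memory for wall-clock time.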

---

## Hyperparameter Search

### Priority order (tune first → last)

1. **Learning rate** — most impactful. Always tune first.
2. **Batch size** — largest that fits. Speed knob, not quality knob.
3. **Weight decay** — 0.01-0.1 for AdamW.
4. **Warmup steps** — 1-5% of training.
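
Since learning rate is tuned first, a log-spaced sweep is a reasonable starting point. The helper below is a hypothetical sketch; the range of 1e-4 to 1e-2 is illustrative and should be centered on a known default for the model family:

```python
def log_grid(low: float, high: float, num: int) -> list[float]:
    """Return `num` values evenly spaced in log space, including both endpoints."""
    if num < 2:
        raise ValueError("num must be at least 2")
    ratio = (high / low) ** (1.0 / (num - 1))
    return [low * ratio**i for i in range(num)]


lr_candidates = log_grid(1e-4, 1e-2, 5)
print([f"{lr:.2e}" for lr in lr_candidates])
```

Each candidate is a constant factor apart, so the sweep covers two orders of magnitude with five short runs.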

### The 2025 default recipe

| Setting | Value |
|---------|-------|
| Optimizer | AdamW (β1=0.9, β2=0.95, eps=1e-10) |
| Weight decay | 0.1 |
| LR schedule | Cosine decay or WSD |
| Peak LR | 3e-4 (scale down for larger models) |
| Precision | bf16 |
| Grad clipping | max_norm=1.0 |
| Normalization | RMSNorm (pre-norm) |
| Activation | SwiGLU |
| Position encoding | RoPE |
| Attention | Flash Attention, optionally GQA |

---

## Debugging Checklist

### Karpathy's recipe (still canonical)

1. **Become one with the data** — visualize, check distributions, verify labels
2. **Get end-to-end running first** — verify on a trivial case
3. **Overfit one batch** — if you can't, you have a bug
4. **Then regularize** — add regularization only after overfitting works
5. **Tune hyperparameters** — start with known defaults

### Loss exploding / NaN

1. Reduce LR (3-10× smaller)
2. Add gradient clipping: `clip_grad_norm_(params, 1.0)`
3. Check for inf/nan in inputs
4. Add logit soft capping: `softcap * tanh(logits / softcap)`
5. Add QK-norm in attention
6. Verify weight init (zero-init output projections?)
7. Check loss reduction with gradient accumulation (`loss / grad_accum_steps`)
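
The soft-capping formula in step 4 bounds logits smoothly instead of hard-clipping them. A framework-free sketch of the same expression, where the cap value of 15 is illustrative; in a model the identical formula would be applied to a tensor:

```python
import math


def soft_cap(logits: list[float], softcap: float = 15.0) -> list[float]:
    """Squash each logit into (-softcap, softcap) via softcap * tanh(x / softcap)."""
    return [softcap * math.tanh(x / softcap) for x in logits]


capped = soft_cap([-1000.0, -1.0, 0.0, 1.0, 1000.0])
print([round(v, 3) for v in capped])
```

Small logits pass through almost unchanged while extreme values saturate at the cap, which keeps the softmax input bounded.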

### Slow training / Low MFU

1. Verify `torch.compile` is active
2. Check `torch.set_float32_matmul_precision("high")`
3. Pin memory + non_blocking transfers
4. Profile with `torch.profiler`
5. GC stalls? `gc.freeze(); gc.disable()`
6. Tensor Core alignment: dims multiples of 8/64

### Loss plateau / Slow convergence

1. LR too low — try 2-5× larger
2. Warmup too long
3. Weight decay too high
4. Verify LR schedule is actually applied (print each step)
5. Model too small for task

### Silent failures

1. **Data leakage** between train/val
2. **Wrong preprocessing at inference** — augmentation mismatch
3. **Label errors** — use cleanlab to detect
4. **Shuffling bugs** — correlated batches
5. **Tokenizer mismatch** with pretrained model
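
Train/val leakage is cheap to test for on text data by hashing normalized examples. A minimal sketch: the normalization (lowercase, collapsed whitespace) is an assumption and should match whatever "duplicate" means for the dataset at hand:

```python
import hashlib


def fingerprint(example: str) -> str:
    """Hash an example after lowercasing and collapsing whitespace."""
    normalized = " ".join(example.lower().split())
    return hashlib.sha256(normalized.encode("utf-8")).hexdigest()


def find_leaks(train: list[str], val: list[str]) -> list[str]:
    """Return validation examples whose fingerprint also occurs in the training set."""
    train_hashes = {fingerprint(t) for t in train}
    return [v for v in val if fingerprint(v) in train_hashes]


leaks = find_leaks(["The cat sat.", "Hello world"], ["hello   WORLD", "unseen example"])
print(leaks)
```

This catches exact and trivially reformatted duplicates only; near-duplicates need fuzzier matching.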

### What to monitor

- **Gradient norms** — spike precedes loss spike
- **Per-layer activation stats** — reveals exploding/vanishing
- **Dead neurons**: more than 50% of ReLU outputs stuck at zero indicates the dying-ReLU problem
- **Learning rate** — verify schedule applied (common silent bug)
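
Gradient-norm monitoring can be as simple as comparing each step's norm to a running median. The class below is a hypothetical sketch with illustrative thresholds; in PyTorch, the value returned by `torch.nn.utils.clip_grad_norm_` (the total norm before clipping) is a convenient input:

```python
from collections import deque
from statistics import median


class GradNormMonitor:
    """Flag gradient norms that jump well above the recent median."""

    def __init__(self, window: int = 100, spike_factor: float = 3.0, min_history: int = 10):
        self.history = deque(maxlen=window)
        self.spike_factor = spike_factor
        self.min_history = min_history

    def update(self, grad_norm: float) -> bool:
        """Record grad_norm and return True if it looks like a spike."""
        is_spike = (
            len(self.history) >= self.min_history
            and grad_norm > self.spike_factor * median(self.history)
        )
        self.history.append(grad_norm)
        return is_spike


monitor = GradNormMonitor()
flags = [monitor.update(1.0) for _ in range(20)] + [monitor.update(5.0)]
print(flags[-1])
```

A flagged step is a good moment to dump per-layer statistics or save a checkpoint before the loss follows.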

---

## Experiment Management

Track experiments in TSV for easy comparison:

```
commit val_bpb memory_gb status description
a1b2c3d 0.9979 44.0 keep baseline
b2c3d4e 0.9932 44.2 keep increase matrix LR to 0.04
c3d4e5f 1.0050 44.0 discard switch to GeLU (worse)
```

**Simplicity criterion**: all else equal, simpler is better. Removing something and getting equal
results is a great outcome. For systematic agent-driven experimentation, see `references/experiment-loop.md`.
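
A log in this format is easy to query with the standard library. The sketch below assumes tab-separated columns named as in the example above and picks the best kept run by `val_bpb`:

```python
import csv
import io

# Same rows as the example log, with explicit tab separators.
LOG = (
    "commit\tval_bpb\tmemory_gb\tstatus\tdescription\n"
    "a1b2c3d\t0.9979\t44.0\tkeep\tbaseline\n"
    "b2c3d4e\t0.9932\t44.2\tkeep\tincrease matrix LR to 0.04\n"
    "c3d4e5f\t1.0050\t44.0\tdiscard\tswitch to GeLU (worse)\n"
)

rows = list(csv.DictReader(io.StringIO(LOG), delimiter="\t"))
kept = [r for r in rows if r["status"] == "keep"]
best = min(kept, key=lambda r: float(r["val_bpb"]))  # lower BPB is better
print(best["commit"], best["val_bpb"])  # b2c3d4e 0.9932
```

For a real log, replace `io.StringIO(LOG)` with an open file handle.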

### Evaluation metrics by domain

| Domain | Primary Metric | Notes |
|--------|---------------|-------|
| LLM | BPB (bits per byte) | Vocab-size-independent |
| Classification | Accuracy / F1 | Macro-F1 for imbalanced |
| Segmentation | mIoU / Dice | Per-class IoU reveals weak spots |
| Generation | FID | Needs >10k samples |
| Regression | RMSE / MAE | Log-transform skewed targets |
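
BPB can be derived from the usual cross-entropy loss, which is measured in nats per token, by converting to bits and normalizing by the byte length of the evaluated text. A sketch with a hypothetical helper name and made-up evaluation numbers:

```python
import math


def bits_per_byte(total_nats: float, total_bytes: int) -> float:
    """Convert summed cross-entropy (nats) over a text into bits per UTF-8 byte."""
    return total_nats / (math.log(2) * total_bytes)


# Hypothetical eval: mean loss 2.8 nats/token over 1,000 tokens spanning 4,200 bytes.
bpb = bits_per_byte(total_nats=2.8 * 1000, total_bytes=4200)
print(f"{bpb:.4f}")
```

Because the denominator is bytes rather than tokens, two models with different tokenizers can be compared on the same text.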