A minimal, hackable pre-training stack for GPT-style language models. This project provides a clean, production-ready foundation for training large-scale transformer models from scratch with distributed training support.
## Features

**Modular GPT Architecture**: Flexible transformer implementation with support for:
- Grouped Query Attention (GQA)
- Mixture of Experts (MoE)
- Customizable attention, MLP, and normalization layers
- Flash Attention optimization support

**Distributed Training**:
- ZeRO-1 optimizer state partitioning for memory efficiency
- DistributedDataParallel (DDP) for multi-GPU training
- Gradient accumulation for large effective batch sizes

**Training Optimizations**:
- Mixed precision training (BFloat16)
- Gradient clipping
- Cosine learning rate schedule with warmup
- Automatic checkpoint resumption with full state recovery

**Developer-Friendly**:
- Comprehensive profiling utilities
- Model FLOPs Utilization (MFU) tracking
- Mock data mode for rapid debugging
- Minimal dependencies
## Project Structure

```
.
├── model/                            # Model architecture
│   ├── config.py                     # GPTConfig dataclass
│   ├── gpt.py                        # GPT model implementation
│   ├── modules/                      # Modular components
│   │   ├── attn.py                   # Attention mechanisms
│   │   ├── mlp.py                    # MLP and MoE layers
│   │   ├── norm.py                   # Normalization layers
│   │   └── emb.py                    # Embedding layers
│   └── ops/                          # Custom operations
│       ├── flashattn.py              # Flash Attention integration
│       └── grouped_gemm.py           # Grouped GEMM for MoE
│
├── distributed/                      # Distributed training components
│   └── zero1/
│       └── distributed_optimizer.py  # ZeRO-1 implementation
│
├── utils/                            # Utility functions
│   ├── model.py                      # Model utilities (param counting, etc.)
│   ├── training.py                   # Training schedule helpers
│   └── profile.py                    # Profiling and MFU computation
│
├── scripts/                          # Training scripts
│   ├── debug_gpt_0.25b/              # 0.25B model config
│   └── debug_gpt_0.3b_a0.17b/        # 0.3B model config
│
└── train.py                          # Main training script
```
## Requirements

- Python 3.10+
- PyTorch 2.0+ with CUDA/NCCL support
- transformers (for tokenizer only)
- tqdm
- numpy
## Quick Start
### 1. Single Node Training
**Using training scripts (recommended):**
```bash
# Train 0.25B dense model (8 GPUs)
bash scripts/debug_gpt_0.25b/pretrain.sh
# Train 0.3B MoE model (8 GPUs)
bash scripts/debug_gpt_0.3b_a0.17b/pretrain.sh
```

**Direct command for quick testing:**

```bash
torchrun --nproc_per_node=8 train.py \
--exp_name debug_test \
--use_mock_data \
--mock_data_num_samples 1280 \
--total_batch_size 524288 \
--B 8 \
--T 4096 \
--max_epochs 1 \
    --debug
```

### 2. Multi-Node Training

All training scripts support multi-node training via environment variables:

```bash
# Node 0 (master, IP: 192.168.1.100)
NUM_NODES=2 NODE_RANK=0 MASTER_ADDR=192.168.1.100 \
bash scripts/debug_gpt_0.25b/pretrain.sh
# Node 1 (worker)
NUM_NODES=2 NODE_RANK=1 MASTER_ADDR=192.168.1.100 \
bash scripts/debug_gpt_0.25b/pretrain.sh
```

**For SLURM clusters:**

```bash
sbatch scripts/slurm_multinode.sh
```

**For SSH-based automated launch:**

```bash
# Edit scripts/launch_multinode_ssh.sh to configure node IPs
bash scripts/launch_multinode_ssh.sh
```

📖 See `scripts/README.md` and `scripts/MULTI_NODE_GUIDE.md` for comprehensive multi-node training guides.
### 3. Training on Your Own Data

Prepare your dataset and modify the `_init_dataset` method in `train.py`:

```python
def _init_dataset(self, config: TrainerConfig):
    # Replace CustomDataset with your dataset implementation
    from your_data_module import YourDataset
    self.train_dataset = YourDataset(
        dataset_path=config.dataset_path,
        split="train"
    )
```

## Configuration

### Model Configuration

The model is configured via the `GPTConfig` dataclass in `model/config.py`:

```python
@dataclass
class GPTConfig:
    block_size: int = 4096            # Maximum sequence length
    vocab_size: int = 50304           # Vocabulary size
    num_layer: int = 32               # Number of transformer layers
    num_attention_heads: int = 128    # Number of attention heads
    num_key_value_heads: int = 8      # Number of KV heads (GQA)
    hidden_size: int = 1024           # Hidden dimension
    intermediate_size: int = 4096     # FFN intermediate size
    dropout: float = 0.0              # Dropout rate
    tied_lm_head: bool = True         # Tie input/output embeddings

    # Mixture of Experts (optional)
    use_moe: bool = False             # Enable MoE
    num_experts: int = 128            # Total number of experts
    num_experts_per_tok: int = 8      # Active experts per token
    moe_intermediate_size: int = 256  # Expert FFN size
```
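For reference, a model can be built straight from this config. A minimal sketch, assuming `model/gpt.py` exposes a `GPT` class that takes a `GPTConfig` (the import paths follow the project layout above; the class name is an assumption):

```python
# Hypothetical usage sketch: the GPT class name and constructor signature are
# assumed and may differ from the actual code in model/gpt.py.
from model.config import GPTConfig
from model.gpt import GPT

config = GPTConfig(
    num_layer=12,              # small model for a quick smoke test
    num_attention_heads=32,
    num_key_value_heads=4,
    hidden_size=1024,
    intermediate_size=4096,
)
model = GPT(config)
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")
```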
### Training Configuration

Key parameters in `train.py` (`TrainerConfig`):

```python
@dataclass
class TrainerConfig:
    exp_name: str = "gpt"            # Experiment name
    total_batch_size: int = 524288   # Total tokens per step
    B: int = 8                       # Micro batch size per device
    T: int = 4096                    # Sequence length
    max_lr: float = 4e-3             # Maximum learning rate
    min_lr: float = 3e-5             # Minimum learning rate
    weight_decay: float = 0.1        # AdamW weight decay
    grad_clip_value: float = 1.0     # Gradient clipping threshold
    warmup_steps: int = 1000         # LR warmup steps
    max_epochs: int = 1              # Training epochs
    save_every_steps: int = 5000     # Checkpoint frequency
    use_compile: bool = False        # PyTorch 2.0 compilation
```

## Checkpointing

The trainer automatically saves and resumes from checkpoints, preserving:
- Model weights (`*_model.pt`)
- Optimizer states (`*_opt/` directory)
- Training metadata (`*_meta.pt`): step counter, RNG state, dataloader position
Simply restart the training command to resume from the latest checkpoint.
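The exact save/resume logic lives in `train.py`; the following is only an illustrative sketch of the `*_model.pt` / `*_meta.pt` convention described above (the optimizer shard directory is omitted for brevity):

```python
import torch

# Illustrative sketch of the checkpoint convention above; the real logic in
# train.py may differ (e.g. it also stores optimizer shards and dataloader state).
def save_checkpoint(prefix, model, step):
    torch.save(model.state_dict(), f"{prefix}_model.pt")
    torch.save({"step": step, "rng_state": torch.get_rng_state()},
               f"{prefix}_meta.pt")

def load_checkpoint(prefix, model):
    model.load_state_dict(torch.load(f"{prefix}_model.pt", map_location="cpu"))
    meta = torch.load(f"{prefix}_meta.pt", map_location="cpu")
    torch.set_rng_state(meta["rng_state"])
    return meta["step"]                      # resume from this step
```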
## Training Details

### ZeRO-1 Optimizer

Memory-efficient optimizer state partitioning (see the sketch after this list):
- Optimizer states are sharded across GPUs
- Model parameters remain replicated
- Automatic gradient synchronization and parameter broadcasting
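The real implementation is in `distributed/zero1/distributed_optimizer.py`; conceptually it behaves like this simplified sketch, assuming the default process group is initialized and gradients have already been averaged across ranks (e.g. by DDP):

```python
import torch
import torch.distributed as dist

# Conceptual ZeRO-1 sketch (not the project's DistributedOptimizer): each rank
# keeps AdamW state only for its own slice of the parameters, steps that slice,
# then broadcasts the updated weights so every rank holds the full model again.
class ShardedAdamW:
    def __init__(self, params, lr=1e-3):
        self.params = list(params)
        rank, self.world = dist.get_rank(), dist.get_world_size()
        self.owned = self.params[rank::self.world]      # this rank's shard
        self.inner = torch.optim.AdamW(self.owned, lr=lr)

    def step(self):
        self.inner.step()                               # update local shard only
        for owner in range(self.world):                 # replicate updated params
            for p in self.params[owner::self.world]:
                dist.broadcast(p.data, src=owner)
```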
### Gradient Accumulation

Gradient accumulation steps are computed automatically from the batch configuration:

```
grad_accum_steps = total_batch_size / (B × T × num_gpus)
```
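For example, with the default settings above on a single 8-GPU node the arithmetic works out as follows:

```python
# Gradient accumulation arithmetic for the default settings.
total_batch_size = 524288          # target tokens per optimizer step
B, T, num_gpus = 8, 4096, 8        # micro batch size, sequence length, world size

tokens_per_micro_step = B * T * num_gpus
assert total_batch_size % tokens_per_micro_step == 0
grad_accum_steps = total_batch_size // tokens_per_micro_step
print(grad_accum_steps)            # 524288 // (8 * 4096 * 8) = 2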
### Learning Rate Schedule

The trainer implements cosine annealing with linear warmup (sketched below):

- Linear warmup: 0 → `max_lr` over `warmup_steps`
- Cosine decay: `max_lr` → `min_lr` over the remaining steps
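A minimal sketch of that shape (the project's schedule helpers live in `utils/training.py` and may differ in detail; the function name below is illustrative):

```python
import math

# Minimal warmup-plus-cosine schedule sketch.
def get_lr(step, max_lr, min_lr, warmup_steps, max_steps):
    if step < warmup_steps:                          # linear warmup: 0 -> max_lr
        return max_lr * (step + 1) / warmup_steps
    if step >= max_steps:                            # after decay: hold min_lr
        return min_lr
    progress = (step - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * progress))   # decays from 1 to 0
    return min_lr + coeff * (max_lr - min_lr)
```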
### MFU Tracking

Model FLOPs Utilization (MFU) is tracked in real time as a measure of hardware efficiency:

```
MFU = (Achieved FLOPs per second) / (Peak hardware FLOPs per second)
```
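A back-of-the-envelope version of this, using the common `6 × params × tokens/s` approximation for training FLOPs (the project's `utils/profile.py` may use a more detailed model, and the peak-FLOPs number below is only an example):

```python
# Rough MFU estimate: achieved training FLOP/s divided by hardware peak FLOP/s.
def estimate_mfu(num_params, tokens_per_second, peak_flops_per_second):
    achieved = 6 * num_params * tokens_per_second   # ~6 FLOPs per param per token
    return achieved / peak_flops_per_second

# Example: 0.25B parameters at 200k tokens/s on a 989 TFLOP/s (BF16) device.
print(f"{estimate_mfu(0.25e9, 2.0e5, 989e12):.1%}")   # ≈ 30.3%
```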
### Profiling

Enable the PyTorch profiler for performance analysis:

```bash
python train.py \
    --use_profiler \
    --steps_to_profile 15 20
```

This generates a Chrome trace file at `<log_dir>/rank0_trace.json` that can be viewed in `chrome://tracing`.
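Under the hood this is the standard `torch.profiler` flow; a minimal standalone sketch that produces a similar Chrome trace (the schedule values below are illustrative, not the trainer's exact settings):

```python
from torch.profiler import ProfilerActivity, profile, schedule

# Minimal torch.profiler sketch that exports a Chrome trace; the trainer's
# integration in train.py may wire this up differently.
def run_with_profiler(step_fn, num_steps=20, trace_path="rank0_trace.json"):
    prof_schedule = schedule(wait=10, warmup=2, active=5, repeat=1)
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        schedule=prof_schedule,
        on_trace_ready=lambda p: p.export_chrome_trace(trace_path),
    ) as prof:
        for _ in range(num_steps):
            step_fn()        # one training step
            prof.step()      # advance the profiler schedule
```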
## Example Model Configurations

Model size can be adjusted via command-line flags to `train.py`. For example:

**Small:**

```bash
--num_layer 12 \
--num_attention_heads 32 \
--num_key_value_heads 4 \
--hidden_size 1024 \
--intermediate_size 4096
```

**Medium:**

```bash
--num_layer 24 \
--num_attention_heads 64 \
--num_key_value_heads 8 \
--hidden_size 2048 \
--intermediate_size 8192
```

**Large:**

```bash
--num_layer 32 \
--num_attention_heads 128 \
--num_key_value_heads 16 \
--hidden_size 4096 \
--intermediate_size 16384
```

## Customization

### Custom Dataset

Implement your dataset class and modify `_init_dataset` in `train.py`:

```python
class YourDataset(Dataset):
    def __getitem__(self, idx):
        return {
            "input_ids": torch.tensor(...),  # shape: (seq_len,)
            "labels": torch.tensor(...)      # shape: (seq_len,)
        }
```

### Custom Model Components

Modify components in `model/modules/`:
- `attn.py`: Implement custom attention mechanisms
- `mlp.py`: Add new feedforward architectures
- `norm.py`: Experiment with normalization strategies
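For example, a drop-in RMSNorm variant could be added to `norm.py`. A sketch, assuming normalization modules only need to map `(batch, seq, hidden)` tensors to the same shape (the existing layers' exact interface may differ):

```python
import torch
import torch.nn as nn

# Example custom normalization layer (RMSNorm) that could live in
# model/modules/norm.py.
class RMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize by the root-mean-square over the hidden dimension.
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return self.weight * (x * rms)
```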
### Custom Optimizer

Replace AdamW in `_init_optimizer`:

```python
def _init_optimizer(self, config: TrainerConfig):
    self.optimizer = YourOptimizer(
        self.raw_model.parameters(),
        lr=config.max_lr
    )
    self.optimizer = DistributedOptimizer(
        optimizer=self.optimizer,
        process_group=self.dp_group,
    )
```

## Logging

Training logs are saved to:
```
<log_dir>/<exp_name>_<config_hash>/log.txt
```

Log format:

```
<step> train <loss>
<step> val <val_loss>
```

Example:

```
0 train 10.8234
100 train 8.4521
250 val 8.3012
```
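Since the format is plain text, the loss curves are easy to pull out for plotting; a small parsing sketch:

```python
from collections import defaultdict

# Parse the plain-text training log into (step, loss) pairs per split.
def parse_log(path):
    curves = defaultdict(list)      # {"train": [(step, loss), ...], "val": [...]}
    with open(path) as f:
        for line in f:
            parts = line.split()
            if len(parts) == 3:
                step, split, loss = parts
                curves[split].append((int(step), float(loss)))
    return curves

curves = parse_log("log.txt")
print(curves["train"][:3])
```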
## Performance Tips

- **Enable compilation**: Add `--use_compile` for PyTorch 2.0+ (20-30% speedup)
- **Tune batch size**: Maximize `B` per GPU to improve throughput
- **Use Flash Attention**: Ensure Flash Attention is available for faster attention
- **Gradient checkpointing**: Implement in `model/gpt.py` for larger models (see the sketch after this list)
- **Mixed precision**: BFloat16 is enabled by default (better than FP16 for training)
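Gradient checkpointing is not wired in by default; one way to add it in `model/gpt.py` is to wrap each block's forward with `torch.utils.checkpoint`. A sketch, assuming the model keeps its transformer blocks in an `nn.ModuleList` (the attribute layout is hypothetical):

```python
from torch.utils.checkpoint import checkpoint

# Sketch: recompute each transformer block's activations during the backward
# pass instead of storing them, trading extra compute for lower memory.
def forward_with_checkpointing(blocks, hidden_states):
    for block in blocks:                     # blocks: nn.ModuleList of layers
        hidden_states = checkpoint(block, hidden_states, use_reentrant=False)
    return hidden_states
```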
## Troubleshooting

**Out of memory:**
- Reduce `B` (micro batch size)
- Enable gradient checkpointing
- Use larger `grad_accum_steps` by reducing `B`

**Low throughput:**
- Ensure Flash Attention is installed
- Enable `--use_compile`
- Check MFU percentage (should be >30% for efficient training)
- Increase `B` to better utilize the GPU

**Checkpoint saving issues:**
- Ensure all processes have write access to `log_dir`
- Check disk space for optimizer state storage
## Citation

If you use this code in your research, please cite:

```bibtex
@software{train_large_model_from_scratch,
  title  = {Train Large Model from Scratch},
  author = {Liangyu Wang},
  year   = {2025},
  url    = {https://github.com/liangyuwang/train-large-model-from-scratch}
}
```

## License

This project is licensed under the terms specified in the LICENSE file.
## Acknowledgements

This implementation draws inspiration from:
- nanoGPT by Andrej Karpathy
- Megatron-LM by NVIDIA
- DeepSpeed ZeRO optimization
## Contributing

Contributions are welcome! Please feel free to submit issues or pull requests.
Note: This is a minimal training stack designed for educational purposes and rapid prototyping. For production-scale training, consider using frameworks like DeepSpeed, Megatron-LM, or Composer.