🚀 The feature, motivation and pitch
Problem Statement
When using LoRA, PEFT, or other parameter-efficient fine-tuning methods, the base model's normalization layer weights are typically frozen (`requires_grad=False`). However, Liger's norm ops currently compute gradients for these frozen parameters unconditionally, resulting in:
- Wasted computation: The backward pass computes `dW`/`dB` even when they will be discarded
- Unnecessary memory allocation: Temporary buffers for gradient accumulation are allocated but never used
- Suboptimal training throughput: Especially noticeable at large hidden sizes (e.g., 8K-32K in modern LLMs)
This is particularly relevant as LoRA/PEFT adoption has become the de facto standard for fine-tuning large language models.
Affected Operations
- `RMSNorm`
- `FusedAddRMSNorm`
- `LayerNorm`
- `GroupNorm`
- `PolyNorm`
Proposed Solution
Leverage PyTorch's `ctx.needs_input_grad` in the backward pass to conditionally skip:
- Weight/bias gradient computation in the Triton kernel (`compute_dW`/`compute_dB` flags)
- Temporary buffer allocation for gradient accumulation
This approach:
- Requires no public API changes
- Is fully backward compatible (unfrozen weights work exactly as before)
- Automatically benefits all existing LoRA/PEFT users without code changes
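A minimal eager-mode sketch of the gating (plain PyTorch, not the fused Triton path; `RMSNormFn` and its backward formulation are illustrative, not Liger's actual code):

```python
import torch

class RMSNormFn(torch.autograd.Function):
    """Eager-mode RMSNorm sketch showing ctx.needs_input_grad gating."""

    @staticmethod
    def forward(ctx, x, weight, eps=1e-6):
        rstd = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
        ctx.save_for_backward(x, weight, rstd)
        return x * rstd * weight

    @staticmethod
    def backward(ctx, dy):
        x, weight, rstd = ctx.saved_tensors
        x_hat = x * rstd
        dx = dW = None
        if ctx.needs_input_grad[0]:
            wdy = weight * dy
            # Standard RMSNorm input gradient:
            # dx = rstd * (w*dy - x_hat * mean(w*dy * x_hat))
            dx = rstd * (wdy - x_hat * (wdy * x_hat).mean(-1, keepdim=True))
        # needs_input_grad[1] is False when the norm weight is frozen
        # (requires_grad=False), so the dW reduction is skipped entirely.
        if ctx.needs_input_grad[1]:
            dW = (dy * x_hat).sum(dim=tuple(range(dy.dim() - 1)))
        return dx, dW, None
```

With `weight.requires_grad = False`, autograd reports `ctx.needs_input_grad[1] == False` and the row reduction for `dW` never runs; the fused kernels would apply the same check to skip both the computation and its temporary buffers.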
Benchmark Results
Environment: RTX 3090, bf16, M=2048 (batch × seq_len)
RMSNorm Only (freeze_weight=True)
| Hidden Size | Backward Speedup (time reduction) | Full (fwd+bwd) Speedup (time reduction) |
|---|---|---|
| H=1024 | 1.25× (−20.1%) | 1.12× (−10.3%) |
| H=2048 | 1.15× (−12.8%) | 1.09× (−8.3%) |
| H=4096 | 1.11× (−10.1%) | 1.05× (−4.7%) |
| H=8192 | 1.07× (−6.2%) | 1.04× (−4.2%) |
| H=16384 | 1.37× (−27.1%) | 1.22× (−18.1%) |
| H=32768 | 3.12× (−67.9%) | 2.41× (−58.5%) |
The speedup increases significantly at larger hidden sizes because the `dW` reduction (summing partial gradients across SMs) becomes the dominant cost.
Mixed Workload: RMSNorm + LoRA Linear (freeze_norm_weight=True)
| Hidden Size | Backward | Full |
|---|---|---|
| H=1024–32768 | 1.00×–1.05× | 1.00×–1.04× |
In realistic LoRA scenarios, the linear layers dominate runtime, so the norm optimization provides modest but consistent gains.
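This is just Amdahl's law: only the norm's share of step time is accelerated. A quick sanity check (the 10% norm-time share below is an assumed illustrative figure, not a measurement):

```python
def end_to_end_speedup(accelerated_fraction: float, local_speedup: float) -> float:
    """Amdahl's law: overall speedup when only part of the runtime is accelerated."""
    return 1.0 / ((1.0 - accelerated_fraction) + accelerated_fraction / local_speedup)

# If norms take ~10% of a LoRA training step and their backward gets 1.25x
# faster, the whole step improves by only ~2%, consistent with the
# 1.00x-1.05x range in the mixed-workload table.
print(f"{end_to_end_speedup(0.10, 1.25):.3f}")  # → 1.020
```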
Implementation Details
Internal API changes (not public-facing):

```python
rms_norm_backward(..., compute_dW: bool)
fused_add_rms_norm_backward(..., compute_dW: bool)
layer_norm_backward(..., compute_dW: bool, compute_dB: bool)
group_norm_backward(..., compute_dW: bool, compute_dB: bool)
poly_norm_backward(..., compute_dW: bool, compute_dB: bool)
```

Kernel changes:
- Added `compute_dW`/`compute_dB` as `tl.constexpr` parameters, enabling Triton to eliminate dead code at compile time
- Skip buffer allocation when gradients are not needed
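The constexpr gating looks roughly like this (an illustrative Triton device-code sketch with assumed names and layout, not the actual kernel; shown for shape only since it cannot run on the host):

```python
import triton
import triton.language as tl

@triton.jit
def _rms_norm_bwd_sketch(dY, X, W, RSTD, DX, DW_partial, n_cols,
                         COMPUTE_DW: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols
    dy = tl.load(dY + row * n_cols + cols, mask=mask, other=0.0)
    x = tl.load(X + row * n_cols + cols, mask=mask, other=0.0)
    w = tl.load(W + cols, mask=mask, other=0.0)
    rstd = tl.load(RSTD + row)
    x_hat = x * rstd
    wdy = w * dy
    dx = rstd * (wdy - x_hat * tl.sum(wdy * x_hat, axis=0) / n_cols)
    tl.store(DX + row * n_cols + cols, dx, mask=mask)
    # COMPUTE_DW is tl.constexpr, so Triton specializes the kernel per value:
    # when it is False, this branch and its stores are compiled away and the
    # host side never allocates DW_partial.
    if COMPUTE_DW:
        tl.store(DW_partial + row * n_cols + cols, dy * x_hat, mask=mask)
```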
Why This Matters
- Growing LoRA/PEFT adoption: Most LLM fine-tuning now uses parameter-efficient methods
- Larger models = bigger impact: Modern LLMs use hidden sizes of 4K–16K+, where this optimization shines
- Zero user effort: Existing code automatically benefits
- Memory savings: Reduced temporary buffer allocation helps with tight GPU memory budgets
Reproduction
```bash
# Run benchmarks
PYTHONPATH=$(pwd)/src python benchmark/scripts/benchmark_rms_norm.py --overwrite
PYTHONPATH=$(pwd)/src python benchmark/scripts/benchmark_rms_norm_mixed.py --overwrite
```
Alternatives
No response
Additional context
No response