
Skip dW/dB computation in norm ops when weight/bias is frozen (LoRA/PEFT optimization) #1067

@yukiu00

Description


🚀 The feature, motivation and pitch

Problem Statement

When using LoRA, PEFT, or other parameter-efficient fine-tuning methods, the base model's normalization layer weights are typically frozen (`requires_grad=False`). However, Liger's norm ops currently compute gradients for these frozen parameters unconditionally, resulting in:

  1. Wasted computation: The backward pass computes dW/dB even when they will be discarded
  2. Unnecessary memory allocation: Temporary buffers for gradient accumulation are allocated but never used
  3. Suboptimal training throughput: Especially noticeable at large hidden sizes (e.g., 8K-32K in modern LLMs)

This is particularly relevant as LoRA/PEFT adoption has become the de facto standard for fine-tuning large language models.

Affected Operations

  • RMSNorm
  • FusedAddRMSNorm
  • LayerNorm
  • GroupNorm
  • PolyNorm

Proposed Solution

Leverage PyTorch's `ctx.needs_input_grad` in the backward pass to conditionally skip:

  1. Weight/bias gradient computation in the Triton kernel (`compute_dW`, `compute_dB` flags)
  2. Temporary buffer allocation for gradient accumulation

This approach:

  • Requires no public API changes
  • Is fully backward compatible (unfrozen weights work exactly as before)
  • Automatically benefits all existing LoRA/PEFT users without code changes
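At the autograd level, the proposal amounts to the following pattern. This is a minimal PyTorch sketch, not Liger's actual implementation (which lives in Triton kernels); the class name `RMSNormSketch` is illustrative:

```python
import torch

class RMSNormSketch(torch.autograd.Function):
    """Minimal sketch of the proposed gating (illustrative, not Liger's code)."""

    @staticmethod
    def forward(ctx, x, weight, eps=1e-6):
        rstd = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
        ctx.save_for_backward(x, weight, rstd)
        return x * rstd * weight

    @staticmethod
    def backward(ctx, dy):
        x, weight, rstd = ctx.saved_tensors
        # needs_input_grad[1] is False when weight.requires_grad=False (LoRA/PEFT)
        compute_dW = ctx.needs_input_grad[1]
        x_hat = x * rstd
        wdy = dy * weight
        # dX for RMSNorm: rstd * (w*dy - x_hat * mean(w*dy * x_hat))
        dx = rstd * (wdy - x_hat * (wdy * x_hat).mean(-1, keepdim=True))
        # Skip the dW reduction (and its accumulation buffer) for frozen weights
        dw = (dy * x_hat).sum(dim=0) if compute_dW else None
        return dx, dw, None
```

When the norm weight is frozen, `ctx.needs_input_grad[1]` is `False` and the `dW` reduction is never issued; trainable weights take exactly the same path as before.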

Benchmark Results

Environment: RTX 3090, bf16, M=2048 (batch × seq_len)

RMSNorm Only (freeze_weight=True)

| Hidden Size | Backward Speedup | Full (fwd+bwd) Speedup |
|-------------|------------------|------------------------|
| H=1024      | 1.25× (−20.1%)   | 1.12× (−10.3%)         |
| H=2048      | 1.15× (−12.8%)   | 1.09× (−8.3%)          |
| H=4096      | 1.11× (−10.1%)   | 1.05× (−4.7%)          |
| H=8192      | 1.07× (−6.2%)    | 1.04× (−4.2%)          |
| H=16384     | 1.37× (−27.1%)   | 1.22× (−18.1%)         |
| H=32768     | 3.12× (−67.9%)   | 2.41× (−58.5%)         |

The speedup increases significantly at larger hidden sizes because the dW reduction (summing partial gradients across SMs) becomes the dominant cost.

Mixed Workload: RMSNorm + LoRA Linear (freeze_norm_weight=True)

| Hidden Size   | Backward Speedup | Full (fwd+bwd) Speedup |
|---------------|------------------|------------------------|
| H=1024–32768  | 1.00×–1.05×      | 1.00×–1.04×            |

In realistic LoRA scenarios, the linear layers dominate runtime, so the norm optimization provides modest but consistent gains.

Implementation Details

Internal API changes (not public-facing):

```python
rms_norm_backward(..., compute_dW: bool)
fused_add_rms_norm_backward(..., compute_dW: bool)
layer_norm_backward(..., compute_dW: bool, compute_dB: bool)
group_norm_backward(..., compute_dW: bool, compute_dB: bool)
poly_norm_backward(..., compute_dW: bool, compute_dB: bool)
```

Kernel changes:

  • Added `compute_dW`/`compute_dB` as `tl.constexpr` parameters, enabling Triton to eliminate dead code at compile time
  • Skip buffer allocation when gradients are not needed
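The host-side gating can be sketched in plain Python/NumPy (names such as `rms_norm_backward_host` and `n_programs` are illustrative, not Liger's actual internals): the partial-dW buffer used for the cross-SM reduction is allocated, filled, and reduced only when `compute_dW` is set.

```python
import numpy as np

def rms_norm_backward_host(dy, x, rstd, n_programs, compute_dW):
    """Sketch: allocate the partial-dW accumulation buffer only when needed."""
    x_hat = x * rstd
    # One row-slice per "program" (kernel instance); each writes its partial dW.
    dw_partial = np.zeros((n_programs, x.shape[-1])) if compute_dW else None
    for pid, idx in enumerate(np.array_split(np.arange(x.shape[0]), n_programs)):
        if compute_dW:  # becomes compile-time dead code with tl.constexpr
            dw_partial[pid] = (dy[idx] * x_hat[idx]).sum(axis=0)
    # Second stage: reduce partials across programs -- the cost that dominates
    # at large hidden sizes -- skipped entirely for frozen weights.
    return dw_partial.sum(axis=0) if compute_dW else None
```

In the real kernel the `if compute_dW` branch is a `tl.constexpr` condition, so Triton specializes the compiled kernel per flag value rather than branching at runtime.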

Why This Matters

  1. Growing LoRA/PEFT adoption: Most LLM fine-tuning now uses parameter-efficient methods
  2. Larger models = bigger impact: Modern LLMs use hidden sizes of 4K–16K+, where this optimization shines
  3. Zero user effort: Existing code automatically benefits
  4. Memory savings: Reduced temporary buffer allocation helps with tight GPU memory budgets

Reproduction

```bash
# Run benchmarks
PYTHONPATH=$(pwd)/src python benchmark/scripts/benchmark_rms_norm.py --overwrite
PYTHONPATH=$(pwd)/src python benchmark/scripts/benchmark_rms_norm_mixed.py --overwrite
```

