🚀 The feature, motivation and pitch
Problem Statement
When using LoRA, PEFT, or other parameter-efficient fine-tuning methods, the base model's normalization layer weights are typically frozen (`requires_grad=False`). However, Liger's norm ops currently compute gradients for these frozen parameters unconditionally, resulting in:
- Wasted computation: the backward pass computes `dW`/`dB` even when they will be discarded
- Unnecessary memory allocation: temporary buffers for gradient accumulation are allocated but never used
- Suboptimal training throughput: especially noticeable at large hidden sizes (e.g., 8K–32K in modern LLMs)
This is particularly relevant as LoRA/PEFT adoption has become the de facto standard for fine-tuning large language models.
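To make the setup concrete, here is a minimal PyTorch sketch of the frozen-norm configuration that LoRA/PEFT produces (plain `torch.nn.LayerNorm` stands in for a Liger norm op; the module name is illustrative):

```python
import torch

# Sketch of the LoRA/PEFT setup: the base model's norm weights are frozen,
# so any dW/dB a backward pass produces for them would be discarded.
norm = torch.nn.LayerNorm(4096)
norm.weight.requires_grad_(False)
norm.bias.requires_grad_(False)

x = torch.randn(2, 4096, requires_grad=True)
norm(x).sum().backward()

# No gradients were requested (or stored) for the frozen parameters;
# only the activation gradient dX is actually needed downstream.
assert norm.weight.grad is None and norm.bias.grad is None
assert x.grad is not None
```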
Affected Operations
- RMSNorm
- FusedAddRMSNorm
- LayerNorm
- GroupNorm
- PolyNorm
Proposed Solution
Leverage PyTorch's `ctx.needs_input_grad` in the backward pass to conditionally skip:
- Weight/bias gradient computation in the Triton kernel (`compute_dW`/`compute_dB` flags)
- Temporary buffer allocation for gradient accumulation
This approach:
- Requires no public API changes
- Is fully backward compatible (unfrozen weights work exactly as before)
- Automatically benefits all existing LoRA/PEFT users without code changes
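As a sketch of the mechanism (illustrative names, not Liger's actual internals), a custom autograd Function can consult `ctx.needs_input_grad` to skip the weight-gradient reduction entirely when the weight is frozen:

```python
import torch

class FrozenAwareRMSNorm(torch.autograd.Function):
    """Sketch: an RMSNorm that skips dW when the weight is frozen."""

    @staticmethod
    def forward(ctx, x, weight, eps):
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
        ctx.save_for_backward(x, weight, inv_rms)
        return x * inv_rms * weight

    @staticmethod
    def backward(ctx, grad_out):
        x, weight, inv_rms = ctx.saved_tensors
        # needs_input_grad[1] is False when weight.requires_grad=False,
        # e.g. under LoRA/PEFT where base norm weights are frozen.
        compute_dW = ctx.needs_input_grad[1]

        x_hat = x * inv_rms
        wdy = grad_out * weight
        dx = inv_rms * (wdy - x_hat * (wdy * x_hat).mean(-1, keepdim=True))

        # Skip both the reduction and its temporary buffer when frozen.
        dW = (grad_out * x_hat).sum(dim=0) if compute_dW else None
        return dx, dW, None
```

With an unfrozen weight this behaves exactly like a regular RMSNorm backward; with a frozen weight, `backward` returns `None` for `dW` without ever materializing the reduction.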
Benchmark Results
Environment: RTX 3090, bf16, M=2048 (batch × seq_len)
RMSNorm Only (freeze_weight=True)
| Hidden Size | Backward Speedup | Full (fwd+bwd) Speedup |
|---|---|---|
| H=1024 | 1.25× (−20.1%) | 1.12× (−10.3%) |
| H=2048 | 1.15× (−12.8%) | 1.09× (−8.3%) |
| H=4096 | 1.11× (−10.1%) | 1.05× (−4.7%) |
| H=8192 | 1.07× (−6.2%) | 1.04× (−4.2%) |
| H=16384 | 1.37× (−27.1%) | 1.22× (−18.1%) |
| H=32768 | 3.12× (−67.9%) | 2.41× (−58.5%) |
The speedup increases significantly at larger hidden sizes because the `dW` reduction (summing partial gradients across SMs) becomes the dominant cost.
Mixed Workload: RMSNorm + LoRA Linear (freeze_norm_weight=True)
| Hidden Size | Backward | Full |
|---|---|---|
| H=1024–32768 | 1.00×–1.05× | 1.00×–1.04× |
In realistic LoRA scenarios, the linear layers dominate runtime, so the norm optimization provides modest but consistent gains.
Implementation Details
Internal API changes (not public-facing):
- `rms_norm_backward(..., compute_dW: bool)`
- `fused_add_rms_norm_backward(..., compute_dW: bool)`
- `layer_norm_backward(..., compute_dW: bool, compute_dB: bool)`
- `group_norm_backward(..., compute_dW: bool, compute_dB: bool)`
- `poly_norm_backward(..., compute_dW: bool, compute_dB: bool)`
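A hypothetical host-side wrapper illustrates the shape of the internal signature change: the `dW` buffer is only materialized when the flag is set (a plain `torch` sum stands in for the Triton per-SM partial buffers and cross-SM reduction):

```python
import torch

def rms_norm_backward(dy, x, weight, inv_rms, compute_dW: bool):
    """Illustrative stand-in for the internal backward entry point.

    When compute_dW is False (frozen weight), neither the partial-gradient
    buffer nor the reduction over rows is allocated or executed.
    """
    x_hat = x * inv_rms
    wdy = dy * weight
    # dX is always needed so backprop keeps flowing to earlier layers.
    dx = inv_rms * (wdy - x_hat * (wdy * x_hat).mean(-1, keepdim=True))

    dW = None
    if compute_dW:
        # In the real kernel this is a per-SM partial buffer followed by a
        # cross-SM reduction; a single sum stands in for both here.
        dW = (dy * x_hat).sum(dim=0)
    return dx, dW
```

The caller (the autograd `backward`) simply forwards `ctx.needs_input_grad[...]` as `compute_dW`, so `dx` is bit-identical in both modes and only the frozen-weight work is elided.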
Kernel changes:
- Added `compute_dW`/`compute_dB` as `tl.constexpr` parameters, enabling Triton to eliminate dead code at compile time
- Skip buffer allocation when gradients are not needed
Why This Matters
- Growing LoRA/PEFT adoption: Most LLM fine-tuning now uses parameter-efficient methods
- Larger models = bigger impact: Modern LLMs use hidden sizes of 4K–16K+, where this optimization shines
- Zero user effort: Existing code automatically benefits
- Memory savings: Reduced temporary buffer allocation helps with tight GPU memory budgets
Reproduction
```shell
# Run benchmarks
PYTHONPATH=$(pwd)/src python benchmark/scripts/benchmark_rms_norm.py --overwrite
PYTHONPATH=$(pwd)/src python benchmark/scripts/benchmark_rms_norm_mixed.py --overwrite
```
Alternatives
No response
Additional context
No response