Skip to content

Commit 1c013e2

Browse files
[Feature] Use Liger's Relu_Squared kernel for Nemotron models (#1176)
## Summary <!--- This is a required section; please describe the main purpose of this proposed code change. ---> Use relu_squared in nemotron. This PR is generated using the liger-autopatch skill and tests the changes in #1177 . <!--- ## Details This is an optional section; is there anything specific that reviewers should be aware of? ---> Class patching and instance patching of relu_squared function. ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> <!-- Replace BLANK with your device type. For example, A100-80G-PCIe Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - Hardware Type: H100 - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence --------- Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent f16c9f7 commit 1c013e2

File tree

3 files changed: +22 −6 lines changed

3 files changed: +22 −6 lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ loss.backward()
253253
| Ministral | `liger_kernel.transformers.apply_liger_kernel_to_ministral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
254254
| Mistral | `liger_kernel.transformers.apply_liger_kernel_to_mistral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
255255
| Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
256-
| Nemotron | `liger_kernel.transformers.apply_liger_kernel_to_nemotron` | CrossEntropyLoss, FusedLinearCrossEntropy |
256+
| Nemotron | `liger_kernel.transformers.apply_liger_kernel_to_nemotron` | ReLUSquared, CrossEntropyLoss, FusedLinearCrossEntropy |
257257
| Pixtral | `liger_kernel.transformers.apply_liger_kernel_to_pixtral` | RoPE, RMSNorm, SwiGLU|
258258
| Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
259259
| Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |

src/liger_kernel/transformers/monkey_patch.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
3030
from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_forward
3131
from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
32+
from liger_kernel.transformers.relu_squared import LigerReLUSquared
3233
from liger_kernel.transformers.rms_norm import LigerRMSNorm
3334
from liger_kernel.transformers.rope import liger_rotary_pos_emb
3435
from liger_kernel.transformers.rope import liger_rotary_pos_emb_vision
@@ -748,6 +749,7 @@ def apply_liger_kernel_to_mistral(
748749

749750

750751
def apply_liger_kernel_to_nemotron(
752+
relu_squared: bool = True,
751753
cross_entropy: bool = False,
752754
fused_linear_cross_entropy: bool = True,
753755
model: PreTrainedModel = None,
@@ -756,12 +758,12 @@ def apply_liger_kernel_to_nemotron(
756758
"""
757759
Apply Liger kernels to replace original implementation in HuggingFace Nemotron models.
758760
759-
Note: Nemotron uses a non-gated MLP (squared ReLU) and NemotronLayerNorm1P (LayerNorm with +1 offset),
760-
which are not currently supported by Liger kernels. RoPE is also not patched because Nemotron uses
761-
partial rotary embeddings (partial_rotary_factor=0.5) which the Liger RoPE kernel does not support.
762-
Only cross entropy optimizations are applied.
761+
Note: NemotronLayerNorm1P (LayerNorm with +1 offset) is not currently supported by Liger kernels.
762+
RoPE is also not patched because Nemotron uses partial rotary embeddings
763+
(partial_rotary_factor=0.5) which the Liger RoPE kernel does not support.
763764
764765
Args:
766+
relu_squared (bool): Whether to apply Liger's ReLU squared activation. Default is True.
765767
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
766768
fused_linear_cross_entropy (bool):
767769
Whether to apply Liger's fused linear cross entropy loss. Default is True.
@@ -776,6 +778,9 @@ def apply_liger_kernel_to_nemotron(
776778

777779
from transformers.models.nemotron import modeling_nemotron
778780

781+
if relu_squared:
782+
modeling_nemotron.ACT2FN["relu2"] = LigerReLUSquared
783+
779784
if cross_entropy:
780785
modeling_nemotron.CrossEntropyLoss = LigerCrossEntropyLoss
781786
if fused_linear_cross_entropy:
@@ -784,6 +789,11 @@ def apply_liger_kernel_to_nemotron(
784789
else:
785790
modeling_nemotron.NemotronForCausalLM.forward = nemotron_lce_forward
786791

792+
if model is not None:
793+
for decoder_layer in model.model.layers:
794+
if relu_squared:
795+
decoder_layer.mlp.act_fn = LigerReLUSquared()
796+
787797

788798
def apply_liger_kernel_to_mixtral(
789799
rope: bool = True,

test/transformers/test_monkey_patch.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3275,6 +3275,7 @@ def test_apply_liger_kernel_to_instance_for_hunyuan_v1_dense():
32753275
@pytest.mark.skipif(not is_nemotron_available(), reason="nemotron not available")
32763276
def test_apply_liger_kernel_to_instance_for_nemotron():
32773277
from liger_kernel.transformers.model.nemotron import lce_forward as nemotron_lce_forward
3278+
from liger_kernel.transformers.relu_squared import LigerReLUSquared
32783279

32793280
# Ensure any monkey patching is cleaned up for subsequent tests
32803281
with patch("transformers.models.nemotron.modeling_nemotron"):
@@ -3292,14 +3293,19 @@ def test_apply_liger_kernel_to_instance_for_nemotron():
32923293

32933294
# Check that model instance variables are not yet patched with Liger modules
32943295
assert inspect.getsource(dummy_model_instance.forward) != inspect.getsource(nemotron_lce_forward)
3296+
for decoder_layer in dummy_model_instance.model.layers:
3297+
assert not isinstance(decoder_layer.mlp.act_fn, LigerReLUSquared)
32953298

32963299
# Test applying kernels to the model instance
3297-
# Nemotron only supports rope and fused_linear_cross_entropy patching
32983300
_apply_liger_kernel_to_instance(model=dummy_model_instance)
32993301

33003302
# Check that the model's forward was correctly patched
33013303
assert inspect.getsource(dummy_model_instance.forward) == inspect.getsource(nemotron_lce_forward)
33023304

3305+
# Check that the activation function was correctly patched
3306+
for decoder_layer in dummy_model_instance.model.layers:
3307+
assert isinstance(decoder_layer.mlp.act_fn, LigerReLUSquared)
3308+
33033309
try:
33043310
print(dummy_model_instance)
33053311
except Exception as e:

0 commit comments

Comments (0)