
Commit eac9ce6

Add Liger Kernel support for Nemotron models (#1165)

This PR was generated by Claude Code using the skill in #1167.

## Summary

- Adds monkey patching support for NVIDIA Nemotron models via `apply_liger_kernel_to_nemotron`
- Supports FusedLinearCrossEntropy (the default) and CrossEntropyLoss optimizations
- RoPE, MLP, and LayerNorm patching are intentionally excluded due to Nemotron's unique architecture

## Testing Done

- Instance patching test passes
- All 4 convergence tests pass on an H100 GPU (bf16/fp32 × FLCE/with_logits)
- Lint passes

- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
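For readers who want to try the new entry point, here is a minimal usage sketch mirroring the other `apply_liger_kernel_to_*` functions; the checkpoint name is the one from the docstring example below and is assumed to be available:

```python
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_nemotron

# Patch the HuggingFace Nemotron classes before loading the model.
# FusedLinearCrossEntropy is the default; pass cross_entropy=True and
# fused_linear_cross_entropy=False to use LigerCrossEntropyLoss instead.
apply_liger_kernel_to_nemotron()

model = AutoModelForCausalLM.from_pretrained("nvidia/nemotron-3-8b-base-4k-hf")
```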
1 parent f1b7e47 commit eac9ce6

File tree: 10 files changed, +431 −0 lines changed

README.md

Lines changed: 1 addition & 0 deletions
@@ -252,6 +252,7 @@ loss.backward()
 | LLaMA 3.2-Vision | `liger_kernel.transformers.apply_liger_kernel_to_mllama` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Mistral | `liger_kernel.transformers.apply_liger_kernel_to_mistral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+| Nemotron | `liger_kernel.transformers.apply_liger_kernel_to_nemotron` | CrossEntropyLoss, FusedLinearCrossEntropy |
 | Pixtral | `liger_kernel.transformers.apply_liger_kernel_to_pixtral` | RoPE, RMSNorm, SwiGLU |
 | Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |

src/liger_kernel/transformers/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -56,6 +56,7 @@
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_nemotron  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo2  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo3  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_paligemma  # noqa: F401
@@ -126,6 +127,7 @@ def __getattr__(name: str):
     "apply_liger_kernel_to_llama4",
     "apply_liger_kernel_to_mistral",
     "apply_liger_kernel_to_mixtral",
+    "apply_liger_kernel_to_nemotron",
     "apply_liger_kernel_to_mllama",
     "apply_liger_kernel_to_olmo2",
     "apply_liger_kernel_to_olmo3",
@@ -210,6 +212,7 @@ def __getattr__(name: str):
     "apply_liger_kernel_to_llama4",
     "apply_liger_kernel_to_mistral",
     "apply_liger_kernel_to_mixtral",
+    "apply_liger_kernel_to_nemotron",
     "apply_liger_kernel_to_mllama",
     "apply_liger_kernel_to_olmo2",
     "apply_liger_kernel_to_olmo3",
src/liger_kernel/transformers/model/nemotron.py

Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
from typing import TYPE_CHECKING
from typing import Optional
from typing import Tuple
from typing import Union

import torch

from liger_kernel.transformers.model.llama import lce_maybe_trainable_lm_head
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast

if TYPE_CHECKING:
    from transformers.cache_utils import Cache


def lce_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional["Cache"] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    skip_logits: Optional[bool] = None,
    **kwargs,
) -> Union[Tuple, LigerCausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them
            only for that token can save memory, which becomes quite significant for long sequences or large
            vocabulary sizes. If a `torch.Tensor`, it must be 1D, corresponding to the indices to keep in the
            sequence length dimension. This is useful when using packed tensor format (single dimension for
            batch and sequence length).

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, NemotronForCausalLM

    >>> model = NemotronForCausalLM.from_pretrained("nvidia/nemotron-3-8b-base-4k-hf")
    >>> tokenizer = AutoTokenizer.from_pretrained("nvidia/nemotron-3-8b-base-4k-hf")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs.last_hidden_state
    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    kept_hidden_states = hidden_states[:, slice_indices, :]

    shift_labels = kwargs.pop("shift_labels", None)
    logits = None
    loss = None
    token_accuracy = None
    predicted_tokens = None

    # if in training mode, don't materialize logits
    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")

    if skip_logits is None:
        # By default, if in training mode, don't materialize logits
        skip_logits = self.training and (labels is not None or shift_labels is not None)

    # Compute loss
    if skip_logits:
        result = lce_maybe_trainable_lm_head(
            self,
            hidden_states=kept_hidden_states,
            hidden_size=self.config.hidden_size,
            labels=labels,
            shift_labels=shift_labels,
            **kwargs,
        )
        loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
    else:
        logits = self.lm_head(kept_hidden_states)
        if labels is not None or shift_labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                shift_labels=shift_labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        output = ((loss,) + output) if loss is not None else output
        output = output + (token_accuracy,) if token_accuracy is not None else output
        output = output + (predicted_tokens,) if predicted_tokens is not None else output
        return output

    # Return custom output class with token_accuracy field
    return LigerCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        token_accuracy=token_accuracy,
        predicted_tokens=predicted_tokens,
    )
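The key branch above is `skip_logits`: in training mode with labels present, `lce_maybe_trainable_lm_head` fuses the `lm_head` projection with the cross entropy loss and never materializes full-vocabulary logits. A minimal sketch of the observable behavior, assuming a hypothetical `model` that is a `NemotronForCausalLM` already patched by this PR:

```python
import torch

# Hypothetical (batch, seq_len) token ids; any LongTensor in vocab range works.
input_ids = torch.randint(0, 32000, (2, 128))

model.train()
out = model(input_ids=input_ids, labels=input_ids)
assert out.loss is not None and out.logits is None  # fused path: no logits materialized

model.eval()
out = model(input_ids=input_ids, labels=input_ids)
assert out.logits is not None  # lm_head runs; loss_function computes the loss
```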

src/liger_kernel/transformers/monkey_patch.py

Lines changed: 40 additions & 0 deletions
@@ -23,6 +23,7 @@
 from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward
 from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
 from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
+from liger_kernel.transformers.model.nemotron import lce_forward as nemotron_lce_forward
 from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
 from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
 from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_forward
@@ -682,6 +683,44 @@ def apply_liger_kernel_to_mistral
         _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
+def apply_liger_kernel_to_nemotron(
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    model: PreTrainedModel = None,
+    **kwargs,
+) -> None:
+    """
+    Apply Liger kernels to replace the original implementation in HuggingFace Nemotron models.
+
+    Note: Nemotron uses a non-gated MLP (squared ReLU) and NemotronLayerNorm1P (LayerNorm with a +1 offset),
+    which are not currently supported by Liger kernels. RoPE is also not patched because Nemotron uses
+    partial rotary embeddings (partial_rotary_factor=0.5), which the Liger RoPE kernel does not support.
+    Only cross entropy optimizations are applied.
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more
+            memory efficient.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.nemotron import modeling_nemotron
+
+    if cross_entropy:
+        modeling_nemotron.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(nemotron_lce_forward, model)
+        else:
+            modeling_nemotron.NemotronForCausalLM.forward = nemotron_lce_forward
+
+
 def apply_liger_kernel_to_mixtral(
     rope: bool = True,
     cross_entropy: bool = False,
@@ -3083,6 +3122,7 @@ def __init__(self, hidden_size, eps=1e-6, **kwargs):
     "mllama_text_model": apply_liger_kernel_to_mllama,
     "mistral": apply_liger_kernel_to_mistral,
     "mixtral": apply_liger_kernel_to_mixtral,
+    "nemotron": apply_liger_kernel_to_nemotron,
     "olmo2": apply_liger_kernel_to_olmo2,
     "pixtral": apply_liger_kernel_to_pixtral,
     "olmo3": apply_liger_kernel_to_olmo3,

test/convergence/bf16/test_mini_models.py

Lines changed: 49 additions & 0 deletions
@@ -43,6 +43,7 @@
 from liger_kernel.transformers import apply_liger_kernel_to_mistral
 from liger_kernel.transformers import apply_liger_kernel_to_mixtral
 from liger_kernel.transformers import apply_liger_kernel_to_mllama
+from liger_kernel.transformers import apply_liger_kernel_to_nemotron
 from liger_kernel.transformers import apply_liger_kernel_to_olmo2
 from liger_kernel.transformers import apply_liger_kernel_to_olmo3
 from liger_kernel.transformers import apply_liger_kernel_to_phi3
@@ -83,6 +84,7 @@
 from test.utils import revert_liger_kernel_to_mistral
 from test.utils import revert_liger_kernel_to_mixtral
 from test.utils import revert_liger_kernel_to_mllama
+from test.utils import revert_liger_kernel_to_nemotron
 from test.utils import revert_liger_kernel_to_olmo2
 from test.utils import revert_liger_kernel_to_olmo3
 from test.utils import revert_liger_kernel_to_phi3
@@ -332,6 +334,14 @@
 except ImportError:
     EXAONE4_AVAILABLE = False
 
+try:
+    from transformers.models.nemotron.configuration_nemotron import NemotronConfig
+    from transformers.models.nemotron.modeling_nemotron import NemotronForCausalLM
+
+    NEMOTRON_AVAILABLE = True
+except ImportError:
+    NEMOTRON_AVAILABLE = False
+
 
 device = infer_device()
 
@@ -1559,6 +1569,29 @@
         ),
     )
 
+if NEMOTRON_AVAILABLE:
+    MINI_MODEL_SETUPS["mini_nemotron"] = MiniModelConfig(
+        liger_kernel_patch_func=apply_liger_kernel_to_nemotron,
+        liger_kernel_patch_revert_func=revert_liger_kernel_to_nemotron,
+        model_class=NemotronForCausalLM,
+        mini_model_config=NemotronConfig(
+            attention_bias=False,
+            attention_dropout=0.0,
+            bos_token_id=1,
+            eos_token_id=2,
+            hidden_act="relu2",
+            hidden_size=1024,
+            initializer_range=0.02,
+            intermediate_size=2048,
+            max_position_embeddings=8192,
+            num_attention_heads=8,
+            num_hidden_layers=4,
+            num_key_value_heads=2,
+            norm_eps=1e-5,
+            vocab_size=32000,
+        ),
+    )
+
 
 def create_model(model_name="mini_llama4"):
     """
@@ -2274,6 +2307,22 @@ def run_mini_model(
             ),
         ],
     ),
+    pytest.param(
+        "mini_nemotron",
+        32,
+        1e-5,
+        torch.bfloat16,
+        1e-2,
+        5e-2,
+        1e-1,
+        1e-2,
+        1e-2,
+        1e-2,
+        marks=[
+            pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
+            pytest.mark.skipif(not NEMOTRON_AVAILABLE, reason="Nemotron not available"),
+        ],
+    ),
 ],
 )
 def test_mini_model(
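The mini configuration above is small enough to smoke-test locally on CPU with random weights; a minimal, purely illustrative sketch reusing the config values from the diff:

```python
import torch
from transformers.models.nemotron.configuration_nemotron import NemotronConfig
from transformers.models.nemotron.modeling_nemotron import NemotronForCausalLM

config = NemotronConfig(
    hidden_act="relu2",  # squared ReLU: the non-gated MLP noted in the patch docstring
    hidden_size=1024,
    intermediate_size=2048,
    num_attention_heads=8,
    num_hidden_layers=4,
    num_key_value_heads=2,
    norm_eps=1e-5,
    vocab_size=32000,
)
model = NemotronForCausalLM(config)

ids = torch.randint(0, config.vocab_size, (1, 16))
out = model(input_ids=ids, labels=ids)
print(out.loss)  # a finite loss on random weights confirms the wiring
```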
