Skip to content

Commit 0438fa8

Browse files
authored
Add patches for Qwen3_5ForConditionalGeneration to support multimodal. (#1150)
## Summary Add support for Qwen3.5 multimodal patches as a follow-up to #1123. Implement extra `lce_forward` for `Qwen3_5ForConditionalGeneration` and add corresponding tests. ### Details Modified files: - `model/qwen3_5.py`: add `lce_forward_for_multimodal` for `Qwen3_5ForConditionalGeneration`. - `model/output_classes.py`: add `LigerQwen3_5CausalLMOutputWithPast` as the return type for qwen3_5. - `test/utils.py`: `revert_liger_kernel_to_qwen3_5` to support `conditional_generation` for qwen3_5. - `test_monkey_patch.py`: add `test_apply_liger_kernel_to_instance_for_qwen3_5_for_conditional_generation`. - `bf16/test_mini_models_multimodal.py`: add `MINI_MODEL_SETUPS["mini_qwen3_5"]` to convergence test. - `fp32/test_mini_models_multimodal.py`: skipped ## Testing Done - Hardware Type: NVIDIA A100 - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence
1 parent 781083b commit 0438fa8

File tree

7 files changed

+486
-16
lines changed

7 files changed

+486
-16
lines changed

src/liger_kernel/transformers/model/output_classes.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,13 @@
7373
except Exception:
7474
_Qwen3VLMoeCausalLMOutputWithPast = None
7575

76+
try:
77+
from transformers.models.qwen3_5.modeling_qwen3_5 import (
78+
Qwen3_5CausalLMOutputWithPast as _Qwen3_5CausalLMOutputWithPast,
79+
)
80+
except Exception:
81+
_Qwen3_5CausalLMOutputWithPast = None
82+
7683

7784
@dataclass
7885
class LigerCausalLMOutputWithPast(CausalLMOutputWithPast):
@@ -156,3 +163,11 @@ class LigerQwen3VLCausalLMOutputWithPast(_Qwen3VLCausalLMOutputWithPast):
156163
class LigerQwen3VLMoeCausalLMOutputWithPast(_Qwen3VLMoeCausalLMOutputWithPast):
157164
token_accuracy: Optional[torch.FloatTensor] = None
158165
predicted_tokens: Optional[torch.LongTensor] = None
166+
167+
# Define the Liger output class only when the upstream base class was importable,
# mirroring the guarded definitions for the other optional model families.
if _Qwen3_5CausalLMOutputWithPast is not None:

    @dataclass
    class LigerQwen3_5CausalLMOutputWithPast(_Qwen3_5CausalLMOutputWithPast):
        """Qwen3.5 CausalLM output extended with Liger fused-loss extras."""

        # Token-level accuracy from the fused loss path, when computed.
        token_accuracy: Optional[torch.FloatTensor] = None
        # Predicted token ids from the fused loss path, when requested.
        predicted_tokens: Optional[torch.LongTensor] = None

src/liger_kernel/transformers/model/qwen3_5.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
88
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
99
from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
10+
from liger_kernel.transformers.model.output_classes import LigerQwen3_5CausalLMOutputWithPast
1011

1112

1213
def lce_forward(
@@ -119,3 +120,137 @@ def lce_forward(
119120
token_accuracy=token_accuracy,
120121
predicted_tokens=predicted_tokens,
121122
)
123+
124+
125+
def lce_forward_for_multimodal(
126+
self,
127+
input_ids: Optional[torch.LongTensor] = None,
128+
attention_mask: Optional[torch.Tensor] = None,
129+
position_ids: Optional[torch.LongTensor] = None,
130+
past_key_values: Optional[List[torch.FloatTensor]] = None,
131+
inputs_embeds: Optional[torch.FloatTensor] = None,
132+
labels: Optional[torch.LongTensor] = None,
133+
pixel_values: Optional[torch.Tensor] = None,
134+
pixel_values_videos: Optional[torch.FloatTensor] = None,
135+
image_grid_thw: Optional[torch.LongTensor] = None,
136+
video_grid_thw: Optional[torch.LongTensor] = None,
137+
mm_token_type_ids: Optional[torch.IntTensor] = None,
138+
logits_to_keep: Union[int, torch.Tensor] = 0,
139+
skip_logits: Optional[bool] = None,
140+
**kwargs,
141+
) -> Union[tuple, LigerQwen3_5CausalLMOutputWithPast]:
142+
r"""
143+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
144+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
145+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
146+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
147+
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
148+
The temporal, height and width of feature shape of each image in LLM.
149+
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
150+
The temporal, height and width of feature shape of each video in LLM.
151+
152+
Example:
153+
154+
```python
155+
>>> from transformers import AutoProcessor, Qwen3_5ForConditionalGeneration
156+
157+
>>> model = Qwen3_5ForConditionalGeneration.from_pretrained("Qwen/Qwen3-VL-8B-Instruct")
158+
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-8B-Instruct")
159+
160+
>>> messages = [
161+
{
162+
"role": "user",
163+
"content": [
164+
{
165+
"type": "image",
166+
"image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
167+
},
168+
{"type": "text", "text": "Describe the image."},
169+
],
170+
}
171+
]
172+
173+
>>> inputs = processor.apply_chat_template(
174+
messages,
175+
tokenize=True,
176+
add_generation_prompt=True,
177+
return_dict=True,
178+
return_tensors="pt"
179+
)
180+
181+
>>> # Generate
182+
>>> generated_ids = model.generate(**inputs, max_new_tokens=1024)
183+
>>> generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
184+
>>> output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
185+
>>> print(output_text)
186+
```
187+
"""
188+
return_dict = kwargs.pop("return_dict", None)
189+
if return_dict is None:
190+
return_dict = self.config.use_return_dict
191+
192+
outputs = self.model(
193+
input_ids=input_ids,
194+
pixel_values=pixel_values,
195+
pixel_values_videos=pixel_values_videos,
196+
image_grid_thw=image_grid_thw,
197+
video_grid_thw=video_grid_thw,
198+
position_ids=position_ids,
199+
attention_mask=attention_mask,
200+
past_key_values=past_key_values,
201+
inputs_embeds=inputs_embeds,
202+
mm_token_type_ids=mm_token_type_ids,
203+
**kwargs,
204+
)
205+
206+
hidden_states = outputs[0]
207+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
208+
kept_hidden_states = hidden_states[:, slice_indices, :]
209+
210+
shift_labels = kwargs.pop("shift_labels", None)
211+
logits = None
212+
loss = None
213+
token_accuracy = None
214+
predicted_tokens = None
215+
216+
if skip_logits is None:
217+
skip_logits = self.training and (labels is not None or shift_labels is not None)
218+
219+
if skip_logits:
220+
result = LigerForCausalLMLoss(
221+
hidden_states=kept_hidden_states,
222+
lm_head_weight=self.lm_head.weight,
223+
labels=labels,
224+
shift_labels=shift_labels,
225+
hidden_size=self.config.text_config.hidden_size,
226+
**kwargs,
227+
)
228+
loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
229+
else:
230+
logits = self.lm_head(kept_hidden_states)
231+
if labels is not None or shift_labels is not None:
232+
loss = self.loss_function(
233+
logits=logits,
234+
labels=labels,
235+
shift_labels=shift_labels,
236+
vocab_size=self.config.text_config.vocab_size,
237+
**kwargs,
238+
)
239+
240+
if not return_dict:
241+
output = (logits,) + outputs[1:]
242+
output = ((loss,) + output) if loss is not None else output
243+
output = output + (token_accuracy,) if token_accuracy is not None else output
244+
output = output + (predicted_tokens,) if predicted_tokens is not None else output
245+
return output
246+
247+
return LigerQwen3_5CausalLMOutputWithPast(
248+
loss=loss,
249+
logits=logits,
250+
past_key_values=outputs.past_key_values,
251+
hidden_states=outputs.hidden_states,
252+
attentions=outputs.attentions,
253+
rope_deltas=outputs.rope_deltas,
254+
token_accuracy=token_accuracy,
255+
predicted_tokens=predicted_tokens,
256+
)

src/liger_kernel/transformers/monkey_patch.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2711,47 +2711,65 @@ def apply_liger_kernel_to_qwen3_5(
27112711
from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM
27122712
from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5TextModel
27132713

2714+
try:
2715+
from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForConditionalGeneration
2716+
except ImportError:
2717+
Qwen3_5ForConditionalGeneration = None
2718+
27142719
from liger_kernel.transformers.model.qwen3_5 import lce_forward as qwen3_5_lce_forward
2720+
from liger_kernel.transformers.model.qwen3_5 import lce_forward_for_multimodal as qwen3_5_lce_forward_for_multimodal
2721+
from liger_kernel.transformers.monkey_patch import _patch_rms_norm_module
2722+
from liger_kernel.transformers.monkey_patch import _patch_swiglu_module
27152723
from liger_kernel.transformers.rms_norm import LigerRMSNormForQwen3Next
27162724
from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
27172725

27182726
if rope:
27192727
raise NotImplementedError("liger_rotary_pos_emb is not available for Qwen3_5 models.")
2728+
27202729
if rms_norm:
27212730
modeling_qwen3_5.Qwen3_5RMSNorm = LigerRMSNormForQwen3Next
2731+
27222732
if cross_entropy:
27232733
from transformers.loss.loss_utils import nn
27242734

2735+
from liger_kernel.transformers.cross_entropy import liger_cross_entropy
2736+
27252737
nn.functional.cross_entropy = liger_cross_entropy
2738+
27262739
if fused_linear_cross_entropy:
27272740
if model is not None:
27282741
if isinstance(model, Qwen3_5ForCausalLM):
27292742
model.forward = MethodType(qwen3_5_lce_forward, model)
2743+
elif isinstance(model, Qwen3_5ForConditionalGeneration):
2744+
model.forward = MethodType(qwen3_5_lce_forward_for_multimodal, model)
27302745
else:
27312746
raise TypeError(
2732-
f"fused_linear_cross_entropy is only applicable on Qwen3_5ForCausalLM. Got: {type(model)}"
2747+
f"fused_linear_cross_entropy is only applicable on Qwen3_5ForCausalLM or Qwen3_5ForConditionalGeneration. Got: {type(model)}"
27332748
)
27342749
else:
27352750
modeling_qwen3_5.Qwen3_5ForCausalLM.forward = qwen3_5_lce_forward
2751+
if Qwen3_5ForConditionalGeneration is not None:
2752+
modeling_qwen3_5.Qwen3_5ForConditionalGeneration.forward = qwen3_5_lce_forward_for_multimodal
2753+
27362754
if swiglu:
27372755
modeling_qwen3_5.Qwen3_5MLP = LigerQwen3MoeSwiGLUMLP
27382756

27392757
if model is not None:
27402758
if isinstance(model, (Qwen3_5ForCausalLM, Qwen3_5TextModel)):
2741-
base_model: Qwen3_5TextModel = getattr(model, model.base_model_prefix, model)
2759+
text_model: Qwen3_5TextModel = getattr(model, model.base_model_prefix, model)
2760+
elif Qwen3_5ForConditionalGeneration is not None and isinstance(model, Qwen3_5ForConditionalGeneration):
2761+
text_model = model.model.language_model
27422762
else:
2743-
raise TypeError(
2744-
f"Unsupported qwen3_5 model type. `model` must be `Qwen3_5ForCausalLM` or `Qwen3_5TextModel`. Got: {type(model)}"
2745-
)
2763+
raise TypeError(f"Unsupported qwen3_5 model type. Got: {type(model)}")
27462764

27472765
_patch_rms_norm_module_for_qwen3_5 = partial(
27482766
_patch_rms_norm_module, offset=1.0, casting_mode="gemma", in_place=False
27492767
)
27502768

27512769
if rms_norm:
2752-
_patch_rms_norm_module_for_qwen3_5(base_model.norm)
2770+
_patch_rms_norm_module_for_qwen3_5(text_model.norm)
27532771

2754-
for decoder_layer in base_model.layers:
2772+
for decoder_layer in text_model.layers:
27552773
if rms_norm:
27562774
_patch_rms_norm_module_for_qwen3_5(decoder_layer.input_layernorm)
27572775
_patch_rms_norm_module_for_qwen3_5(decoder_layer.post_attention_layernorm)

0 commit comments

Comments
 (0)