
Commit 8d83a09

Merge branch 'main' into tcc/from_config

2 parents 21ec50a + 6059dfb

7 files changed: +120 −102 lines

setup.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ def get_default_dependencies():
             "torch>=2.6.0",
         ]
     elif platform == "npu":
-        return ["torch_npu==2.6.0", "triton-ascend"]
+        return ["torch_npu==2.7.1", "triton-ascend"]


 def get_optional_dependencies():

src/liger_kernel/chunked_loss/cosine_similarity_loss.py

Lines changed: 7 additions & 1 deletion
@@ -9,7 +9,13 @@

 class LigerFusedLinearCosineSimilarityFunction(LigerFusedLinearDistillationBase):
     @staticmethod
-    def distillation_loss_fn(student_logits, teacher_logits, beta=1.0):
+    def distillation_loss_fn(
+        student_logits,
+        teacher_logits,
+        target=None,
+        ignore_index=None,
+        beta=1.0,
+    ):
         """
         Compute Cosine loss (Cosine Similarity Loss).
         Args:
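
For orientation, here is a minimal standalone sketch of a cosine-similarity distillation loss with the widened signature. The L2 normalization mirrors the reference implementation in the test file further down; the masking via target/ignore_index is an assumption about how the new parameters are meant to be used, not the kernel's actual code:

    import torch
    import torch.nn.functional as F

    def cosine_distillation_loss(student_logits, teacher_logits, target=None, ignore_index=None, beta=1.0):
        # L2-normalize both logit vectors and penalize 1 - cosine similarity.
        student_norm = F.normalize(student_logits, p=2, dim=-1)
        teacher_norm = F.normalize(teacher_logits, p=2, dim=-1)
        loss = beta * (1.0 - (student_norm * teacher_norm).sum(dim=-1))
        if target is not None and ignore_index is not None:
            # Assumed semantics: drop positions whose label equals ignore_index.
            loss = loss[target != ignore_index]
        return loss.sum()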

src/liger_kernel/chunked_loss/fused_linear_distillation.py

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,8 @@ class LigerFusedLinearDistillationBase(torch.autograd.Function):
     def distillation_loss_fn(
         student_logits,
         teacher_logits,
+        target=None,
+        ignore_index=None,
     ):
         """
         Compute distillation loss.
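
The two new keyword arguments widen the contract that every subclass's loss function implements, so chunk-level callers can pass label information through uniformly. A hedged sketch of a subclass conforming to the new signature (the subclass and its KL loss are hypothetical, for illustration only):

    import torch.nn.functional as F
    from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase

    class MyDistillationFunction(LigerFusedLinearDistillationBase):  # hypothetical subclass
        @staticmethod
        def distillation_loss_fn(student_logits, teacher_logits, target=None, ignore_index=None):
            # Soft-target KL divergence; target/ignore_index are accepted so the
            # base class can forward them, even if this particular loss ignores them.
            return F.kl_div(
                F.log_softmax(student_logits, dim=-1),
                F.softmax(teacher_logits, dim=-1),
                reduction="sum",
            )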

src/liger_kernel/transformers/monkey_patch.py

Lines changed: 38 additions & 33 deletions
@@ -430,7 +430,7 @@ def apply_liger_kernel_to_llava(
                     f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
                     f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
                 )
-            text_kwargs["model"] = model.language_model
+            text_kwargs["model"] = model.model.language_model
             text_liger_fn(**text_kwargs)
         elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
             logger.warning(f"{text_model_name} is not supported by Liger kernel.")
@@ -445,7 +445,7 @@ def apply_liger_kernel_to_llava(
                     f"These parameters are not supported by {vision_model_name}. Enter the remaining {list(vision_kwargs.keys())} except for {list(remain_params)}\n"
                     f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
                 )
-            vision_kwargs["model"] = model.vision_tower
+            vision_kwargs["model"] = model.model.vision_tower
             vision_liger_fn(**vision_kwargs)
         elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
             logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
@@ -615,8 +615,8 @@ def apply_liger_kernel_to_mllama(
         # instance variables that reference already-instantiated modules

         if isinstance(model, MllamaForConditionalGeneration):
-            language_model: MllamaForCausalLM = model.language_model
-            vision_model: MllamaVisionModel = model.vision_model
+            language_model: MllamaForCausalLM = model.model.language_model
+            vision_model: MllamaVisionModel = model.model.vision_model
             if isinstance(language_model, MllamaForCausalLM):
                 text_model: MllamaTextModel = language_model.model
             else:
@@ -1118,8 +1118,8 @@ def apply_liger_kernel_to_gemma3(
         # instance variables that reference already-instantiated modules

         if isinstance(model, Gemma3ForConditionalGeneration):
-            if isinstance(model.vision_tower, SiglipVisionModel):
-                vision_tower = model.vision_tower
+            if isinstance(model.model.vision_tower, SiglipVisionModel):
+                vision_tower = model.model.vision_tower

                 _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)

@@ -1132,15 +1132,15 @@ def apply_liger_kernel_to_gemma3(
                 raise TypeError("The vision tower must be SiglipVisionModel")

             if rms_norm:
-                _patch_rms_norm_module_for_gemma3(model.multi_modal_projector.mm_soft_emb_norm)
+                _patch_rms_norm_module_for_gemma3(model.model.multi_modal_projector.mm_soft_emb_norm)

             apply_liger_kernel_to_gemma3_text(
                 rope=rope,
                 cross_entropy=False,
                 fused_linear_cross_entropy=False,
                 rms_norm=rms_norm,
                 geglu=geglu,
-                model=model.language_model,
+                model=model.model.language_model,
             )

         else:
@@ -1228,7 +1228,7 @@ def apply_liger_kernel_to_paligemma(
     if not isinstance(model, PaliGemmaForConditionalGeneration):
         raise TypeError("model have to be of type PaliGemmaForConditionalGeneration")

-    vision_tower: SiglipVisionModel = model.vision_tower
+    vision_tower: SiglipVisionModel = model.model.vision_tower

     _patch_layer_norm_module(vision_tower.vision_model.post_layernorm)

@@ -1238,7 +1238,7 @@ def apply_liger_kernel_to_paligemma(
         _patch_layer_norm_module(layer.layer_norm1)
         _patch_layer_norm_module(layer.layer_norm2)

-    language_model = model.language_model
+    language_model = model.model.language_model

     if isinstance(language_model, (GemmaForCausalLM, GemmaModel)):
         apply_liger_kernel_to_gemma(
@@ -1593,11 +1593,10 @@ def apply_liger_kernel_to_qwen2_vl(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-
-        if isinstance(model, (Qwen2VLForConditionalGeneration, Qwen2VLModel)):
-            # Note: language_model and visual properties can be accessed throught conditional class for BC.
-            # Not sure if it is subject to changes in the future.
-            # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1698
+        if isinstance(model, Qwen2VLForConditionalGeneration):
+            text_model: Qwen2VLTextModel = model.model.language_model
+            vision_model: Qwen2VisionTransformerPretrainedModel = model.model.visual
+        elif isinstance(model, Qwen2VLModel):
             text_model: Qwen2VLTextModel = model.language_model
             vision_model: Qwen2VisionTransformerPretrainedModel = model.visual
         elif isinstance(model, Qwen2VLTextModel):
@@ -1684,11 +1683,10 @@ def apply_liger_kernel_to_qwen2_5_vl(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-
-        if isinstance(model, (Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLModel)):
-            # Note: language_model and visual properties can be accessed throught conditional class for BC.
-            # Not sure if it is subject to changes in the future.
-            # Reference: https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1823
+        if isinstance(model, Qwen2_5_VLForConditionalGeneration):
+            text_model: Qwen2_5_VLTextModel = model.model.language_model
+            vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.model.visual
+        elif isinstance(model, Qwen2_5_VLModel):
             text_model: Qwen2_5_VLTextModel = model.language_model
             vision_model: Qwen2_5_VisionTransformerPretrainedModel = model.visual
         elif isinstance(model, Qwen2_5_VLTextModel):
@@ -1702,7 +1700,7 @@ def apply_liger_kernel_to_qwen2_5_vl(

         if vision_model is not None:
             # Patch Qwen2_5_VisionTransformerPretrainedModel
-            for vision_block in model.visual.blocks:
+            for vision_block in vision_model.blocks:
                 if rms_norm:
                     _patch_rms_norm_module(vision_block.norm1)
                     _patch_rms_norm_module(vision_block.norm2)
@@ -1771,7 +1769,9 @@ def apply_liger_kernel_to_qwen3_vl(
         modeling_qwen3_vl.Qwen3VLForConditionalGeneration.forward = qwen3_vl_lce_forward

     if model is not None and rms_norm:
-        if isinstance(model, (Qwen3VLForConditionalGeneration, Qwen3VLModel)):
+        if isinstance(model, Qwen3VLForConditionalGeneration):
+            text_model: Qwen3VLTextModel = model.model.language_model
+        elif isinstance(model, Qwen3VLModel):
             text_model: Qwen3VLTextModel = model.language_model
         elif isinstance(model, Qwen3VLTextModel):
             text_model = model
@@ -1846,7 +1846,9 @@ def apply_liger_kernel_to_qwen3_vl_moe(
         modeling_qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration.forward = qwen3_vl_moe_lce_forward

     if model is not None and rms_norm:
-        if isinstance(model, (Qwen3VLMoeForConditionalGeneration, Qwen3VLMoeModel)):
+        if isinstance(model, Qwen3VLMoeForConditionalGeneration):
+            text_model: Qwen3VLMoeTextModel = model.model.language_model
+        elif isinstance(model, Qwen3VLMoeModel):
             text_model: Qwen3VLMoeTextModel = model.language_model
         elif isinstance(model, Qwen3VLMoeTextModel):
             text_model = model
@@ -2191,10 +2193,10 @@ def apply_liger_kernel_to_glm4v(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-        if isinstance(model, (Glm4vForConditionalGeneration, Glm4vModel)):
-            # Note: language_model and visual properties can be accessed throught conditional class for BC.
-            # Not sure if it is subject to changes in the future.
-            # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modeling_glm4v.py#L1305
+        if isinstance(model, Glm4vForConditionalGeneration):
+            text_model: Glm4vTextModel = model.model.language_model
+            vision_model: Glm4vVisionModel = model.model.visual
+        elif isinstance(model, Glm4vModel):
             text_model: Glm4vTextModel = model.language_model
             vision_model: Glm4vVisionModel = model.visual
         elif isinstance(model, Glm4vTextModel):
@@ -2281,10 +2283,11 @@ def apply_liger_kernel_to_glm4v_moe(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-        if isinstance(model, (Glm4vMoeForConditionalGeneration, Glm4vMoeModel)):
-            # Note: language_model and visual properties can be accessed throught conditional class for BC.
-            # Not sure if it is subject to changes in the future.
-            # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py#L337
+        if isinstance(model, Glm4vMoeForConditionalGeneration):
+            text_model: Glm4vMoeTextModel = model.model.language_model
+            vision_model: Glm4vMoeVisionModel = model.model.visual
+            Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
+        elif isinstance(model, Glm4vMoeModel):
             text_model: Glm4vMoeTextModel = model.language_model
             vision_model: Glm4vMoeVisionModel = model.visual
             Glm4vMoeTextMoE = modeling_glm4v_moe.Glm4vMoeTextMoE
@@ -2387,8 +2390,10 @@ def apply_liger_kernel_to_internvl(
     if model is not None:
         # The model instance already exists, so we need to additionally patch the
         # instance variables that reference already-instantiated modules
-        if isinstance(model, (InternVLForConditionalGeneration, InternVLModel)):
-            # NOTE: language_model and visual properties can be accessed throught conditional class.
+        if isinstance(model, InternVLForConditionalGeneration):
+            text_model = model.model.language_model
+            vision_model: InternVLVisionModel = model.model.vision_tower
+        elif isinstance(model, InternVLModel):
             text_model = model.language_model
             vision_model: InternVLVisionModel = model.vision_tower
         else:
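
All of these hunks apply the same fix: the *ForConditionalGeneration wrappers now reach their text and vision submodules through an inner .model attribute (which appears to reflect the submodule layout in newer transformers releases), while the bare *Model classes still expose language_model and visual/vision_tower directly. The combined isinstance check is therefore split into one branch per attribute path. A generic sketch of the pattern (the helper name is hypothetical; the real code inlines this per model family):

    def _resolve_submodules(model, wrapper_cls, base_cls, text_cls):
        # Hypothetical helper illustrating the per-class attribute resolution above.
        if isinstance(model, wrapper_cls):
            # Wrapper classes nest the multimodal model one level down.
            return model.model.language_model, model.model.visual
        elif isinstance(model, base_cls):
            return model.language_model, model.visual
        elif isinstance(model, text_cls):
            # Text-only model: no vision tower to patch.
            return model, None
        raise TypeError(f"Unsupported model type: {type(model)}")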

test/chunked_loss/test_cosine_loss.py

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ def __init__(
             temperature=temperature,
         )

-    def distillation_loss(self, student_logits, teacher_logits, beta=1.0):
+    def distillation_loss(self, student_logits, teacher_logits, target=None, ignore_index=None, beta=1.0, **kwargs):
         # Compute normalized logits
         print(f"student_logits.shape: {student_logits.shape}")
         student_norm = F.normalize(student_logits, p=2, dim=-1)
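
Since the updated reference signature accepts the new keywords and swallows any future ones via **kwargs, existing call sites keep working unchanged. A minimal standalone sketch of that compatibility (the function body mirrors the normalization step above; the mean reduction and -100 ignore_index are assumptions borrowed from common conventions, not the test's actual values):

    import torch
    import torch.nn.functional as F

    def distillation_loss(student_logits, teacher_logits, target=None, ignore_index=None, beta=1.0, **kwargs):
        # Cosine distance between L2-normalized student and teacher logits.
        student_norm = F.normalize(student_logits, p=2, dim=-1)
        teacher_norm = F.normalize(teacher_logits, p=2, dim=-1)
        return beta * (1.0 - (student_norm * teacher_norm).sum(dim=-1)).mean()

    s, t = torch.randn(4, 128), torch.randn(4, 128)
    labels = torch.randint(0, 128, (4,))

    distillation_loss(s, t)                                    # old call style still works
    distillation_loss(s, t, target=labels, ignore_index=-100)  # new keywords are accepted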
