
Fix CSM depth decoder generate: preserve forward signature on wrapper#590

Open
danielhanchen wants to merge 1 commit into main from fix/csm-depth-decoder-signature
Conversation

@danielhanchen

Summary

  • CsmGenerationMixin._sample calls self.depth_decoder.generate(backbone_last_hidden_state=...), which goes through GenerationMixin._validate_model_kwargs. That validator reads inspect.signature(self.forward).parameters and rejects any kwarg not listed there.

  • The depth-decoder patch in temporary_patches/misc.py wraps the inner forward in a (self, *args, **kwargs) shell so check_args_kwargs accepts the removal of output_attentions / output_hidden_states / cache_position across transformers versions. After that wrap, inspect.signature on the wrapper only reports (self, *args, **kwargs), so _validate_model_kwargs no longer sees backbone_last_hidden_state and raises:

    ValueError: The following model_kwargs are not used by the model: ['backbone_last_hidden_state']
    
  • Capture the original class signature before wrapping and assign it to forward.__signature__. Python's inspect.signature honours __signature__ unconditionally, so transformers sees the real named parameters while check_args_kwargs still sees a (*args, **kwargs) passthrough at runtime. The captured signature tracks whatever transformers version is installed, so the fix works on both 4.57.x and 5.x.

  • Apply the same treatment to patch_CsmForConditionalGeneration_forward and to the MinistralAttention / PixtralAttention wrappers for consistency.
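The fix described above can be sketched in isolation. The class below is a hypothetical stand-in for the real decoder patched in temporary_patches/misc.py; only the capture-wrap-assign pattern is the point:

```python
import inspect

# Hypothetical stand-in for the real class patched in temporary_patches/misc.py.
class DepthDecoder:
    def forward(self, input_ids=None, backbone_last_hidden_state=None):
        return input_ids, backbone_last_hidden_state

# 1. Capture the class signature BEFORE wrapping.
_original_forward_signature = inspect.signature(DepthDecoder.forward)

# 2. Wrap the body in a (*args, **kwargs) shell, as the patch does.
_full_forward = DepthDecoder.forward
def forward(self, *args, **kwargs):
    return _full_forward(self, *args, **kwargs)

# 3. Expose the captured signature; inspect.signature honours __signature__.
forward.__signature__ = _original_forward_signature

DepthDecoder.forward = forward
print(list(inspect.signature(DepthDecoder.forward).parameters))
# ['self', 'input_ids', 'backbone_last_hidden_state']
```

Validators that introspect the signature now see the named parameters, while the runtime call path is still the permissive passthrough.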

Repro (before fix)

import unsloth
import torch
from transformers import CsmForConditionalGeneration, AutoProcessor

model_id = "unsloth/csm-1b"
processor = AutoProcessor.from_pretrained(model_id)
model = CsmForConditionalGeneration.from_pretrained(model_id, dtype=torch.bfloat16, device_map="cuda:0")

conversation = [{"role": "0", "content": [{"type": "text", "text": "Hello there"}]}]
inputs = processor.apply_chat_template(conversation, tokenize=True, return_dict=True).to(model.device)
model.generate(**inputs, output_audio=True, max_new_tokens=8)
# ValueError: The following model_kwargs are not used by the model: ['backbone_last_hidden_state']

After fix

depth_decoder.forward params: ['input_ids', 'backbone_last_hidden_state', ...]
PASS: signature includes backbone_last_hidden_state
Output type: <class 'list'>
  [0] shape=torch.Size([13440]) dtype=torch.bfloat16
SUCCESS

Verified on transformers 4.57.6 with torch 2.10.0+rocm7.1 on MI300X.

Test plan

  • CsmForConditionalGeneration.generate(output_audio=True) produces a non-empty audio tensor
  • inspect.signature(model.depth_decoder.forward) lists backbone_last_hidden_state
  • inspect.signature(model.forward) lists input_ids, input_values, input_values_cutoffs
  • MinistralAttention.forward and PixtralAttention.forward patches still apply cleanly after adding __signature__
  • Checked upstream transformers 5.5.0: CsmDepthDecoderForCausalLM.forward still declares backbone_last_hidden_state, so the captured signature works there too
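The two signature checks in the test plan reduce to the following mechanism, shown here with a synthetic signature built by hand rather than the real model's:

```python
import inspect

def forward(self, *args, **kwargs):
    """Bare wrapper: its signature alone exposes no named parameters."""
    return kwargs

# Before attaching __signature__, the named parameter is invisible.
assert "backbone_last_hidden_state" not in inspect.signature(forward).parameters

# Attach an explicit signature (synthetic, for illustration only).
forward.__signature__ = inspect.Signature([
    inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD),
    inspect.Parameter("backbone_last_hidden_state",
                      inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None),
])

# A validator-style check now passes; the runtime body is unchanged.
assert "backbone_last_hidden_state" in inspect.signature(forward).parameters
print(forward(None, backbone_last_hidden_state=3))  # {'backbone_last_hidden_state': 3}
```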


@gemini-code-assist (bot) left a comment


Code Review

This pull request updates several temporary patches for Ministral, Pixtral, and other models to preserve original function signatures on patched forward methods. By utilizing inspect.signature and assigning it to the wrapper's signature attribute, the changes ensure that transformers' keyword argument validation correctly identifies parameters during generation. Feedback suggests adding explanatory comments to the CsmForConditionalGeneration patch for consistency and explicitly setting the match_level to 'relaxed' in the Pixtral patch to align with other implementations.

Comment on lines +375 to +379
_original_forward_signature = inspect.signature(target_cls.forward)
_full_forward = forward
def forward(self, *args, **kwargs):
    return _full_forward(self, *args, **kwargs)
forward.__signature__ = _original_forward_signature

medium

For consistency with the other patches in this PR (such as patch_CsmDepthDecoderForCausalLM_forward and patch_MinistralAttention), it would be beneficial to add a brief comment explaining why the original signature is being preserved on the wrapper. This helps future maintainers understand that this is necessary for transformers' kwarg validation during generation.

Suggested change
_original_forward_signature = inspect.signature(target_cls.forward)
_full_forward = forward
def forward(self, *args, **kwargs):
    return _full_forward(self, *args, **kwargs)
forward.__signature__ = _original_forward_signature
# Preserve the original signature on the wrapper so inspect.signature
# (used by transformers' _validate_model_kwargs among others) still sees
# the real named parameters.
_original_forward_signature = inspect.signature(target_cls.forward)
_full_forward = forward
def forward(self, *args, **kwargs):
    return _full_forward(self, *args, **kwargs)
forward.__signature__ = _original_forward_signature

Comment on lines 106 to 108
"forward",
forward,
)

medium

For consistency with the MinistralAttention and CsmDepthDecoderForCausalLM patches, consider explicitly setting match_level="relaxed" here. While patch_function's internal logic for *args, **kwargs wrappers might allow this even in strict mode, being explicit matches the pattern used in the other files modified in this PR.

    patch_function(
        target_cls,
        "forward",
        forward,
        match_level="relaxed",
    )
References
  1. It is acceptable to use fragile string-matching for code patching if it is consistent with the existing codebase's architecture and a more robust solution would require a large-scale refactor. This avoids introducing new types of fragility.
