Fix CSM depth decoder generate: preserve forward signature on wrapper#590
danielhanchen wants to merge 1 commit into main from
Conversation
`CsmGenerationMixin._sample` calls `self.depth_decoder.generate(backbone_last_hidden_state=...)`, which goes through `GenerationMixin._validate_model_kwargs`. That validator reads `inspect.signature(self.forward).parameters` and rejects any kwarg not listed there.

The depth-decoder patch in `temporary_patches/misc.py` wraps the inner forward in a `(self, *args, **kwargs)` shell so `check_args_kwargs` accepts the removal of `output_attentions`/`output_hidden_states`/`cache_position` across transformers versions. After that wrap, `inspect.signature` on the wrapper only reports `(self, *args, **kwargs)`, so `_validate_model_kwargs` no longer sees `backbone_last_hidden_state` and raises:

    ValueError: The following model_kwargs are not used by the model: ['backbone_last_hidden_state']

Fix: capture the original class signature before wrapping and assign it to `forward.__signature__`. Python's `inspect.signature` honours `__signature__` unconditionally, so transformers sees the real named parameters while `check_args_kwargs` still sees a `(*args, **kwargs)` passthrough at runtime. The captured signature tracks whatever transformers version is installed, so the fix works on both 4.57.x and 5.x.

Apply the same treatment to `patch_CsmForConditionalGeneration_forward` and to the `MinistralAttention`/`PixtralAttention` wrappers for consistency.
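The failure and the fix can be reproduced in isolation with a toy class. `validate_model_kwargs` and `DepthDecoder` below are hypothetical, simplified stand-ins for the core check in `GenerationMixin._validate_model_kwargs` and for the real depth-decoder class, not the actual transformers code:

```python
import inspect

def validate_model_kwargs(forward, model_kwargs):
    # Simplified stand-in for the core of transformers'
    # GenerationMixin._validate_model_kwargs: reject any kwarg that is
    # not a named parameter of forward.
    accepted = set(inspect.signature(forward).parameters)
    unused = [k for k in model_kwargs if k not in accepted]
    if unused:
        raise ValueError(
            f"The following model_kwargs are not used by the model: {unused}"
        )

class DepthDecoder:  # toy stand-in for the real depth decoder
    def forward(self, input_ids=None, backbone_last_hidden_state=None):
        return input_ids

# Capture the class signature BEFORE wrapping, as the fix does.
_original_forward_signature = inspect.signature(DepthDecoder.forward)
_full_forward = DepthDecoder.forward

def forward(self, *args, **kwargs):
    return _full_forward(self, *args, **kwargs)

# Without __signature__ the wrapper only exposes (self, *args, **kwargs),
# so validation rejects backbone_last_hidden_state.
try:
    validate_model_kwargs(forward, {"backbone_last_hidden_state": None})
    raise AssertionError("expected ValueError")
except ValueError:
    pass

# With __signature__ set, inspect.signature reports the real parameters
# and the same validation passes.
forward.__signature__ = _original_forward_signature
validate_model_kwargs(forward, {"backbone_last_hidden_state": None})
```

Note that `__signature__` takes precedence over everything else `inspect.signature` would otherwise consult, which is why the wrapper can keep its permissive runtime `(*args, **kwargs)` shape.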
Code Review
This pull request updates several temporary patches for Ministral, Pixtral, and other models to preserve original function signatures on patched forward methods. By capturing the result of `inspect.signature` and assigning it to the wrapper's `__signature__` attribute, the changes ensure that transformers' keyword-argument validation correctly identifies parameters during generation. Feedback suggests adding explanatory comments to the `CsmForConditionalGeneration` patch for consistency and explicitly setting `match_level="relaxed"` in the Pixtral patch to align with the other implementations.
```python
_original_forward_signature = inspect.signature(target_cls.forward)
_full_forward = forward
def forward(self, *args, **kwargs):
    return _full_forward(self, *args, **kwargs)
forward.__signature__ = _original_forward_signature
```
For consistency with the other patches in this PR (such as `patch_CsmDepthDecoderForCausalLM_forward` and `patch_MinistralAttention`), it would be beneficial to add a brief comment explaining why the original signature is being preserved on the wrapper. This helps future maintainers understand that this is necessary for transformers' kwarg validation during generation.
Suggested change:

```python
# Preserve the original signature on the wrapper so inspect.signature
# (used by transformers' _validate_model_kwargs among others) still sees
# the real named parameters.
_original_forward_signature = inspect.signature(target_cls.forward)
_full_forward = forward
def forward(self, *args, **kwargs):
    return _full_forward(self, *args, **kwargs)
forward.__signature__ = _original_forward_signature
```
```python
    "forward",
    forward,
)
```
For consistency with the `MinistralAttention` and `CsmDepthDecoderForCausalLM` patches, consider explicitly setting `match_level="relaxed"` here. While `patch_function`'s internal logic for `(*args, **kwargs)` wrappers might allow this even in strict mode, being explicit matches the pattern used in the other files modified in this PR.
```python
patch_function(
    target_cls,
    "forward",
    forward,
    match_level="relaxed",
)
```

References
- It is acceptable to use fragile string-matching for code patching if it is consistent with the existing codebase's architecture and a more robust solution would require a large-scale refactor. This avoids introducing new types of fragility.
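To make the `match_level` discussion concrete, here is a hypothetical sketch of a patcher in the spirit of `patch_function`; the real unsloth implementation also does source-level string matching of the original function, which is omitted here, and the class names below are toys, not the real API:

```python
import inspect

def patch_function(target_cls, name, replacement, match_level="strict"):
    # Hypothetical sketch: in "strict" mode, require the replacement to
    # keep the original signature unless it is a (*args, **kwargs)
    # passthrough; "relaxed" mode skips that check.
    original = getattr(target_cls, name)
    orig_sig = inspect.signature(original)
    new_sig = inspect.signature(replacement)
    passthrough = any(
        p.kind is inspect.Parameter.VAR_KEYWORD
        for p in new_sig.parameters.values()
    )
    if match_level == "strict" and not passthrough and new_sig != orig_sig:
        raise RuntimeError(f"{name}: signature mismatch; refusing to patch")
    # Preserve the original signature so inspect.signature on the
    # installed wrapper still reports the real named parameters.
    replacement.__signature__ = orig_sig
    setattr(target_cls, name, replacement)

class Model:  # toy class standing in for a patched transformers class
    def forward(self, x, scale=1):
        return x * scale

_inner = Model.forward  # capture before patching to avoid recursion

def forward(self, *args, **kwargs):
    return _inner(self, *args, **kwargs)

patch_function(Model, "forward", forward, match_level="relaxed")
```

After patching, the bound method still reports the original parameters (`inspect` drops `self` from `__signature__` for bound methods) while accepting any runtime arguments the inner function does.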
Repro (before fix)

    ValueError: The following model_kwargs are not used by the model: ['backbone_last_hidden_state']

After fix
Verified on transformers 4.57.6 with torch 2.10.0+rocm7.1 on MI300X.
Test plan
- `CsmForConditionalGeneration.generate(output_audio=True)` produces a non-empty audio tensor
- `inspect.signature(model.depth_decoder.forward)` lists `backbone_last_hidden_state`
- `inspect.signature(model.forward)` lists `input_ids`, `input_values`, `input_values_cutoffs`
- `MinistralAttention.forward` and `PixtralAttention.forward` patches still apply cleanly after adding `__signature__`
- `CsmDepthDecoderForCausalLM.forward` still declares `backbone_last_hidden_state`, so the captured signature works there too