Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions unsloth_zoo/temporary_patches/ministral.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import torch
import inspect
from typing import Optional, Callable
from .common import TEMPORARY_PATCHES
from .utils import (
Expand Down Expand Up @@ -90,13 +91,19 @@ def forward(
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights

# Wrap so check_args_kwargs accepts removed params (e.g. cache_position in v5)
# Wrap so check_args_kwargs accepts removed params (e.g. cache_position in v5).
# Preserve the original signature on the wrapper so inspect.signature
# (used by transformers._validate_model_kwargs among others) still sees
# the real named parameters.
target_cls = transformers.models.ministral.modeling_ministral.MinistralAttention
_original_forward_signature = inspect.signature(target_cls.forward)
_full_forward = forward
def forward(self, *args, **kwargs):
return _full_forward(self, *args, **kwargs)
forward.__signature__ = _original_forward_signature

patch_function(
transformers.models.ministral.modeling_ministral.MinistralAttention,
target_cls,
"forward",
forward,
match_level="relaxed",
Expand Down
9 changes: 8 additions & 1 deletion unsloth_zoo/temporary_patches/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,10 +232,15 @@ def forward(
pass

# Wrap with (self, *args, **kwargs) so check_args_kwargs accepts any
# removed params (output_attentions, output_hidden_states, cache_position)
# removed params (output_attentions, output_hidden_states, cache_position).
# Copy the original class signature onto the wrapper so
# transformers._validate_model_kwargs (used by generate) still sees
# the real named parameters like backbone_last_hidden_state.
_original_forward_signature = inspect.signature(target_cls.forward)
_full_forward = forward
def forward(self, *args, **kwargs):
return _full_forward(self, *args, **kwargs)
forward.__signature__ = _original_forward_signature
patch_function(target_cls, "forward", forward, match_level="relaxed")
pass
TEMPORARY_PATCHES.append(patch_CsmDepthDecoderForCausalLM_forward)
Expand Down Expand Up @@ -367,9 +372,11 @@ def forward(
})
pass

_original_forward_signature = inspect.signature(target_cls.forward)
_full_forward = forward
def forward(self, *args, **kwargs):
return _full_forward(self, *args, **kwargs)
forward.__signature__ = _original_forward_signature
Comment on lines +375 to +379
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For consistency with the other patches in this PR (such as patch_CsmDepthDecoderForCausalLM_forward and patch_MinistralAttention), it would be beneficial to add a brief comment explaining why the original signature is being preserved on the wrapper. This helps future maintainers understand that this is necessary for transformers' kwarg validation during generation.

Suggested change
_original_forward_signature = inspect.signature(target_cls.forward)
_full_forward = forward
def forward(self, *args, **kwargs):
return _full_forward(self, *args, **kwargs)
forward.__signature__ = _original_forward_signature
# Preserve the original signature on the wrapper so inspect.signature
# (used by transformers._validate_model_kwargs among others) still sees
# the real named parameters.
_original_forward_signature = inspect.signature(target_cls.forward)
_full_forward = forward
def forward(self, *args, **kwargs):
return _full_forward(self, *args, **kwargs)
forward.__signature__ = _original_forward_signature

patch_function(target_cls, "forward", forward, match_level="relaxed")
pass
TEMPORARY_PATCHES.append(patch_CsmForConditionalGeneration_forward)
Expand Down
13 changes: 10 additions & 3 deletions unsloth_zoo/temporary_patches/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import torch
import torch.nn as nn
import inspect
from typing import Optional, Tuple
from .common import TEMPORARY_PATCHES
from .utils import (
Expand Down Expand Up @@ -83,19 +84,25 @@ def forward(
attn_output = self.o_proj(attn_output)
return attn_output, None

# Wrap so check_args_kwargs accepts removed params (e.g. output_attentions in v5)
# Wrap so check_args_kwargs accepts removed params (e.g. output_attentions in v5).
# Preserve the original signature on the wrapper so inspect.signature
# (used by transformers._validate_model_kwargs among others) still sees
# the real named parameters.
target_cls = transformers.models.pixtral.modeling_pixtral.PixtralAttention
_original_forward_signature = inspect.signature(target_cls.forward)
_full_forward = forward
def forward(self, *args, **kwargs):
return _full_forward(self, *args, **kwargs)
forward.__signature__ = _original_forward_signature

patch_function(
transformers.models.pixtral.modeling_pixtral.PixtralAttention,
target_cls,
"__init__",
__init__,
)

patch_function(
transformers.models.pixtral.modeling_pixtral.PixtralAttention,
target_cls,
"forward",
forward,
)
Comment on lines 106 to 108
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For consistency with the MinistralAttention and CsmDepthDecoderForCausalLM patches, consider explicitly setting match_level="relaxed" here. While patch_function's internal logic for *args, **kwargs wrappers might allow this even in strict mode, being explicit matches the pattern used in the other files modified in this PR.

    patch_function(
        target_cls,
        "forward",
        forward,
        match_level="relaxed",
    )
References
  1. It is acceptable to use fragile string-matching for code patching if it is consistent with the existing codebase's architecture and a more robust solution would require a large-scale refactor. This avoids introducing new types of fragility.

Expand Down
Loading