Fix transformers 5.x compat: GRPO token_type_ids, gpt_oss BlockMask, compiler decorators #511
danielhanchen wants to merge 3 commits into main
…rence

Three fixes for transformers 5.2.0 compatibility:

1. compiler.py: Strip @merge_with_config_defaults and @capture_outputs decorators from compiled forward functions. These decorators use func.__code__.co_varnames to find positional args, but when stacked the wrapper's (self, *args, **kwargs) is seen instead of the original function's parameter list, causing "got multiple values for argument" errors (e.g. use_cache in Gemma3N).

2. rl_replacements.py: Pass token_type_ids through grpo_accumulated_loss to model forward calls. Gemma3's create_causal_mask_mapping requires token_type_ids during training to build bidirectional attention masks for image tokens. The data pipeline (TRL + Unsloth data collator) already provides token_type_ids in the inputs dict, but it was not extracted or forwarded to the model.

3. gpt_oss.py: Avoid creating BlockMask during inference when _attn_implementation is "flex_attention". The GptOssModel.forward creates attention masks via create_causal_mask, which dispatches to flex_attention_mask and returns BlockMask objects. However, Unsloth uses eager attention for inference, which cannot handle BlockMask (TypeError: unsupported operand type(s) for +=). The fix checks _attn_implementation directly in the forward function rather than relying on the module-level wrapper (which was bypassed due to closure variable capture). Also fixes the input_embeds -> inputs_embeds rename in the masking kwargs.
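The first fix depends on how a wrapper hides the wrapped function's parameter names. The following is a minimal sketch of that effect, not the transformers implementation: `passthrough` and `Toy` are hypothetical stand-ins for the real decorators and model class.

```python
import functools
import inspect


def passthrough(func):
    """Hypothetical stand-in for a decorator that wraps forward()."""

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        return func(self, *args, **kwargs)

    return wrapper


class Toy:
    @passthrough
    def forward(self, input_ids, use_cache=None):
        return input_ids, use_cache


decorated = Toy.forward

# functools.wraps copies metadata such as __name__, but not __code__,
# so the code object still describes the wrapper's own signature.
code = decorated.__code__
wrapper_names = code.co_varnames[: code.co_argcount]

# inspect.signature follows __wrapped__ and recovers the original names.
original_names = tuple(inspect.signature(decorated).parameters)

print(wrapper_names)
print(original_names)
```

Any code that maps positional arguments to names via `co_varnames` on an already-decorated function will therefore never see `use_cache`, which is consistent with the "multiple values" error described above.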
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: f62fdbe98b
    attention_mask = {
        "full_attention": None,
        "sliding_attention": None,
    }
Preserve padding mask when bypassing flex BlockMask
When _attn_implementation is "flex_attention", this branch replaces any provided attention_mask with None for both mask types, so later decoder calls receive no padding mask at all. In batched inference/generation with padded prompts, pad positions can become attendable because the user mask is discarded, which changes logits and can corrupt outputs even though the call no longer crashes.
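To illustrate why a dropped padding mask changes results, here is a small pure-Python sketch of attention weights over one query row. It is a toy model of the effect, not Unsloth or transformers code; `attention_weights` and the score values are assumptions chosen for the example.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(scores, padding_mask=None):
    """padding_mask uses 1 for real tokens and 0 for pads; None attends to all."""
    if padding_mask is not None:
        scores = [
            s if keep else float("-inf")
            for s, keep in zip(scores, padding_mask)
        ]
    return softmax(scores)


scores = [0.0, 1.0, 2.0]   # position 0 is a left-pad token
padding_mask = [0, 1, 1]

with_mask = attention_weights(scores, padding_mask)
without_mask = attention_weights(scores, None)

print(with_mask)
print(without_mask)
```

With the mask, the pad position receives zero weight; without it, the pad position takes a share of the probability mass, so the weights on the real tokens shift as well.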
    if token_type_ids is not None:
        token_type_ids_chunks.append(token_type_ids[start:end])
Repack token_type_ids after left-pack compaction
input_ids are left-packed earlier in this function for non-vision batches, but token_type_ids are chunked from the original un-packed tensor here. If token_type_ids are present with left padding (for models that use segment/token-type masking), each token can be paired with the wrong type id, producing incorrect attention masking and loss values.
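The misalignment can be shown with plain lists. This is a sketch under assumptions: `left_pack` is a hypothetical helper, not the tensor code in rl_replacements.py, and the token values are made up.

```python
def left_pack(rows, keep_mask, pad_value=0):
    """Move the kept entries of each row to the left and pad on the right."""
    packed = []
    for row, mask in zip(rows, keep_mask):
        kept = [value for value, keep in zip(row, mask) if keep]
        packed.append(kept + [pad_value] * (len(row) - len(kept)))
    return packed


# Two left-padded sequences; 0 in attention_mask marks padding.
input_ids = [[0, 0, 11, 12], [0, 21, 22, 23]]
token_type_ids = [[0, 0, 0, 1], [0, 0, 1, 1]]
attention_mask = [[0, 0, 1, 1], [0, 1, 1, 1]]

packed_ids = left_pack(input_ids, attention_mask)
packed_types = left_pack(token_type_ids, attention_mask)

# Token 12 now sits at column 1. Its type id is only correct if
# token_type_ids went through the same compaction.
print(packed_ids[0][1], packed_types[0][1], token_type_ids[0][1])
```

Applying the same compaction to both tensors keeps each token paired with its own type id; slicing the un-packed `token_type_ids` pairs token 12 with the type id of a pad position.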
Code Review
This pull request introduces several important compatibility fixes for transformers>=5.2.0. The changes in compiler.py to strip additional decorators, the forwarding of token_type_ids in rl_replacements.py, and the BlockMask prevention logic in gpt_oss.py all seem correct and well-justified. My review includes a suggestion to refactor the decorator stripping logic for better maintainability and a fix for a potential bug in gpt_oss.py related to dictionary key access. Overall, these are solid improvements.
    source = re.sub(r"@auto_docstring[\s]{0,}(\([^\)]{0,}\))?", "", source)
    source = re.sub(r"@check_model_inputs[\s]{0,}(\([^\)]{0,}\))?", "", source)
    source = re.sub(r"@merge_with_config_defaults[\s]{0,}(\([^\)]{0,}\))?", "", source)
    source = re.sub(r"@capture_outputs[\s]{0,}(\([^\)]{0,}\))?", "", source)
To improve maintainability and reduce code duplication, you can refactor these repeated re.sub calls for removing decorators into a loop. This makes it easier to add or remove decorators from the list in the future.
    - source = re.sub(r"@auto_docstring[\s]{0,}(\([^\)]{0,}\))?", "", source)
    - source = re.sub(r"@check_model_inputs[\s]{0,}(\([^\)]{0,}\))?", "", source)
    - source = re.sub(r"@merge_with_config_defaults[\s]{0,}(\([^\)]{0,}\))?", "", source)
    - source = re.sub(r"@capture_outputs[\s]{0,}(\([^\)]{0,}\))?", "", source)
    + decorators_to_remove = [
    +     "auto_docstring",
    +     "check_model_inputs",
    +     "merge_with_config_defaults",
    +     "capture_outputs",
    + ]
    + for decorator in decorators_to_remove:
    +     source = re.sub(rf"@{decorator}[\s]*(\([^)]*\))?", "", source)
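A self-contained sketch of the suggested loop applied to a sample source string follows. The sample function and the `example_flag` argument are hypothetical and exist only to exercise the optional-parentheses part of the pattern.

```python
import re

DECORATORS_TO_REMOVE = [
    "auto_docstring",
    "check_model_inputs",
    "merge_with_config_defaults",
    "capture_outputs",
]


def strip_decorators(source):
    """Remove each listed decorator, with or without a call argument list."""
    for decorator in DECORATORS_TO_REMOVE:
        source = re.sub(rf"@{decorator}[\s]*(\([^)]*\))?", "", source)
    return source


sample = '''\
@merge_with_config_defaults
@capture_outputs(example_flag=False)
@auto_docstring
def forward(self, input_ids, use_cache=None):
    return input_ids
'''

stripped = strip_decorators(sample)
print(stripped)
```

Note that the pattern matches any decorator name that merely starts with one of the listed names, so a list entry that is a prefix of another decorator would strip more than intended.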
      def wrap(f):
          def return_attention_mask(*args, **kwargs):
    -         if kwargs["input_embeds"].requires_grad:
    +         embeds = kwargs.get("input_embeds", kwargs.get("inputs_embeds", None))
The change from kwargs["input_embeds"] to kwargs.get("input_embeds", ...) is a good improvement for robustness. However, with transformers>=5.0, input_embeds was renamed to inputs_embeds. The current implementation correctly checks for both, which is great for backward compatibility. This is a solid fix.
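One detail of the nested `.get` form is worth spelling out: it falls back to the second key only when the first key is absent, not when it is present with a `None` value. The sketch below compares it with a hypothetical `get_embeds` helper (not part of the PR) that also handles that case.

```python
def get_embeds(kwargs):
    """Hypothetical helper: read embeddings under the new or the old key."""
    embeds = kwargs.get("inputs_embeds")
    if embeds is None:
        embeds = kwargs.get("input_embeds")
    return embeds


def nested_get(kwargs):
    """The lookup as written in the diff above."""
    return kwargs.get("input_embeds", kwargs.get("inputs_embeds", None))


old_style = {"input_embeds": "E_old"}
new_style = {"inputs_embeds": "E_new"}
both_keys = {"input_embeds": None, "inputs_embeds": "E_new"}

print(nested_get(old_style), nested_get(new_style), nested_get(both_keys))
print(get_embeds(old_style), get_embeds(new_style), get_embeds(both_keys))
```

Whether the `both_keys` case can occur depends on how the caller builds the mask kwargs, which this page does not show.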
The compiler's create_standalone_class uses inspect.getsource to grab the original GptOssMLP source, which references the original GptOssExperts class (with 3D Parameter down_proj). When the compiler replaces modeling_gpt_oss.GptOssMLP with the compiled version, the BnB4bit patch from step 1 is reverted in the compiled module's namespace. Even though patch_gpt_oss_bnb4bit re-runs later to fix the modeling_gpt_oss module attribute, the compiled GptOssMLP.__init__ still closes over its own module's GptOssExperts symbol (the original). Fix: After patching, scan sys.modules for compiled/combined modules and update their GptOssExperts symbol. Also replace GptOssMLP.__init__ with a closure that captures GptOssExpertsBnb4bit via default argument binding, making it immune to subsequent namespace changes.
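The default-argument technique mentioned here can be sketched in isolation. All class names below are hypothetical stand-ins for GptOssExperts, GptOssExpertsBnb4bit, and GptOssMLP; this shows only the binding behavior, not the actual patch.

```python
class OriginalExperts:
    kind = "original"


class PatchedExperts:
    kind = "patched"


# Name that later code may rebind, like a symbol in a module namespace.
Experts = PatchedExperts


class MLPByName:
    def __init__(self):
        # Resolves the name Experts each time __init__ runs.
        self.experts = Experts()


class MLPByDefault:
    def __init__(self, _experts_cls=PatchedExperts):
        # The default value was bound once, when __init__ was defined.
        self.experts = _experts_cls()


# Simulate a later step that reverts the patch in the namespace.
Experts = OriginalExperts

by_name = MLPByName().experts.kind
by_default = MLPByDefault().experts.kind
print(by_name, by_default)
```

The name-based lookup follows whatever the namespace currently holds, while the default-bound version keeps the patched class regardless of later rebinding.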
Qwen2.5 VL in transformers 5.x has the text model at new_model.model.language_model, not new_model.language_model. The hasattr check missed this nesting, causing embed_tokens_key to be "model.embed_tokens.weight" instead of "model.language_model.embed_tokens.weight", which then fails with KeyError when looking up the weight in quant_state_dict. Add intermediate check for model.model.language_model. Also guard the pad_token_id assertion with an embed_tokens_key existence check to avoid KeyError on unexpected hierarchies.
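A rough sketch of the nesting check and the guarded lookup is shown below. `find_embed_tokens_key`, `resolve`, and the candidate list are hypothetical; only the two key strings come from the commit message, and the fake models are built from `SimpleNamespace` purely for illustration.

```python
from types import SimpleNamespace

# Most specific layout first, so the nested text model wins when present.
CANDIDATE_PATHS = ["model.language_model", "model"]


def resolve(obj, dotted_path):
    """Follow a dotted attribute path, returning None if any hop is missing."""
    for name in dotted_path.split("."):
        obj = getattr(obj, name, None)
        if obj is None:
            return None
    return obj


def find_embed_tokens_key(new_model):
    """Return the state-dict key for embed_tokens, or None if not found."""
    for path in CANDIDATE_PATHS:
        module = resolve(new_model, path)
        if module is not None and hasattr(module, "embed_tokens"):
            return f"{path}.embed_tokens.weight"
    return None


nested = SimpleNamespace(
    model=SimpleNamespace(
        language_model=SimpleNamespace(embed_tokens=object())
    )
)
flat = SimpleNamespace(model=SimpleNamespace(embed_tokens=object()))

quant_state_dict = {"model.language_model.embed_tokens.weight": "W"}

nested_key = find_embed_tokens_key(nested)
flat_key = find_embed_tokens_key(flat)

# Guard the lookup so an unexpected hierarchy does not raise KeyError.
weight = quant_state_dict[nested_key] if nested_key in quant_state_dict else None
missing = quant_state_dict[flat_key] if flat_key in quant_state_dict else None
```

Checking the more deeply nested path first mirrors the ordering problem described above, where the shallower check matched and produced the wrong key.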
Summary
Fixes three issues that break Unsloth when running with transformers>=5.2.0:

1. compiler.py: Strip @merge_with_config_defaults and @capture_outputs decorators during source extraction. These decorators inspect func.__code__.co_varnames to find positional parameter names, but when stacked the wrapper signature (self, *args, **kwargs) is resolved instead of the original function's parameters. This caused "got multiple values for argument 'use_cache'" on Gemma3N and similar models.

2. rl_replacements.py: Forward token_type_ids through grpo_accumulated_loss to the model's forward call. In transformers 5.x, Gemma3's create_causal_mask_mapping requires token_type_ids during training to construct bidirectional attention masks for image vs text tokens. TRL and the Unsloth data collator already provide token_type_ids in the inputs dict, but the GRPO accumulated loss path never extracted or passed it to the model. Uses **_extra_fwd_kwargs pattern to maintain backwards compatibility with models that do not accept token_type_ids.

3. gpt_oss.py: Prevent BlockMask creation during inference. The GptOssModel.forward function calls create_causal_mask() which, when _attn_implementation="flex_attention", dispatches to flex_attention_mask() returning a BlockMask. But Unsloth uses eager attention for inference, which cannot handle BlockMask objects (TypeError: unsupported operand type(s) for +=: 'Tensor' and 'BlockMask'). The fix directly checks _attn_implementation in the forward function and passes None for the flex_attention case, rather than relying on the module-level wrapper which was bypassed due to closure variable capture. Also fixes the input_embeds to inputs_embeds rename in the mask kwargs dict.

Test plan