Fix missing ParameterModule export in GPT-OSS compiler path #519
base: main
Changes from all commits
2524b1b
5776d3c
d164b79
2e1b178
4dd1ed6
a870ecc
bf5e103
@@ -648,6 +648,26 @@ def _load_from_state_dict(
     )


+def patch_gpt_oss_compiler_exports():
+    model_name = os.environ.get("UNSLOTH_MODEL_NAME", "").replace("-", "_")

Collaborator: Ideally, if our temporary_patches creates and references a class, compiler.py should automatically handle copying it over. But for now this is fine.

+    if "gpt_oss" not in model_name:
+        return
+    try:
+        import transformers.models.gpt_oss.modeling_gpt_oss
+    except Exception as e:
+        raise_error("transformers.models.gpt_oss.modeling_gpt_oss", e)
+        return
+
+    # Export helpers so compiler generated GPT-OSS modules can resolve symbols.
+    m = transformers.models.gpt_oss.modeling_gpt_oss
+    m.ParameterModule = ParameterModule

Collaborator: My only worry here is that this would leave …

+    m.swiglu_torch_forward = swiglu_torch_forward
+    m.dtype_from_config = dtype_from_config
+    m.transformers_version = transformers_version
+    m.Version = Version
+TEMPORARY_PATCHES.append(patch_gpt_oss_compiler_exports)
+
+
 class GptOssExperts(nn.Module):
     """
     GPT OSS MoE Experts layer with 3D stacked parameters.
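Side note on why these exports are needed at all: the compiler-generated GPT-OSS modules presumably resolve bare names through modeling_gpt_oss's global namespace, so any helper they reference must exist as an attribute of that module. A minimal sketch of the failure mode and the fix, using a toy module and a toy ParameterModule rather than the actual compiler path:

import types

# Toy stand-in for transformers.models.gpt_oss.modeling_gpt_oss.
modeling = types.ModuleType("modeling_gpt_oss_toy")

# Source of the kind a compiler might generate: it references
# ParameterModule by bare name and expects the module to provide it.
generated_src = "def build():\n    return ParameterModule()"

# Compile the generated source into the module's namespace.
exec(generated_src, modeling.__dict__)

try:
    modeling.build()
except NameError as e:
    print("before export:", e)  # name 'ParameterModule' is not defined

class ParameterModule:  # toy placeholder for the real helper class
    pass

# This is the essence of what patch_gpt_oss_compiler_exports() does:
modeling.ParameterModule = ParameterModule
print("after export:", modeling.build())  # now resolves and instantiates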
@@ -1316,15 +1336,19 @@ def _should_use_gpt_oss_bnb4bit() -> bool:
     Default: True when load_in_4bit is active.
     Set UNSLOTH_GPT_OSS_BNB4BIT_DISABLE=1 to force BF16 path.
     """
-    if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""):
+    if "gpt_oss" not in _normalized_unsloth_model_name():
         return False
-    if "_load_in_4bit_" not in os.environ.get("UNSLOTH_MODEL_NAME", ""):
+    if "_load_in_4bit_" not in _normalized_unsloth_model_name():
         return False
     return os.environ.get("UNSLOTH_GPT_OSS_BNB4BIT_DISABLE", "0") != "1"


 def _is_gpt_oss_4bit_load() -> bool:
-    return "_load_in_4bit_" in os.environ.get("UNSLOTH_MODEL_NAME", "")
+    return "_load_in_4bit_" in _normalized_unsloth_model_name()


+def _normalized_unsloth_model_name() -> str:
+    return os.environ.get("UNSLOTH_MODEL_NAME", "").replace("-", "_")


 def _is_transformers_v5() -> bool:
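The helper exists because the model name can arrive with hyphens (Hub-style names such as gpt-oss-20b) while every gate in this file checks for the underscore spelling gpt_oss. A quick sketch of the behavior; the example model-name value below is hypothetical, Unsloth sets UNSLOTH_MODEL_NAME itself:

import os

def _normalized_unsloth_model_name() -> str:
    # Hyphens become underscores so Hub-style names match the "gpt_oss" checks.
    return os.environ.get("UNSLOTH_MODEL_NAME", "").replace("-", "_")

# Hypothetical value for illustration only.
os.environ["UNSLOTH_MODEL_NAME"] = "unsloth/gpt-oss-20b_load_in_4bit_"

name = _normalized_unsloth_model_name()
assert "gpt_oss" in name            # matches after normalization
assert "_load_in_4bit_" in name     # 4-bit marker the gates look for
# The raw value never matches, which is the bug the helper fixes:
assert "gpt_oss" not in os.environ["UNSLOTH_MODEL_NAME"]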
@@ -1340,7 +1364,7 @@ def patch_gpt_oss_moe_for_lora():
     IMPORTANT: We only patch the forward method, NOT replace the entire class.
     This preserves the original class structure so weights load correctly.
     """
-    if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""):
+    if "gpt_oss" not in _normalized_unsloth_model_name():
         return
     if _is_gpt_oss_4bit_load() or _should_use_gpt_oss_bnb4bit():
         # 4-bit loads should keep quantized weights and use default PEFT LoRA.
@@ -1774,8 +1798,8 @@ def patch_gpt_oss_linearized():
     Patch GPT OSS for 4bit loading with grouped_mm support.
     Only patches the GptOssExperts forward method - keeps original classes for proper weight loading.
     """
-    if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return
-    if "_load_in_4bit_" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return
+    if "gpt_oss" not in _normalized_unsloth_model_name(): return
+    if "_load_in_4bit_" not in _normalized_unsloth_model_name(): return
     if _should_use_gpt_oss_bnb4bit(): return
     try:
         import transformers.models.gpt_oss.modeling_gpt_oss
@@ -1813,7 +1837,7 @@ def experts_forward(


 def patch_GptOssAttention():
     if os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "1") == "0": return
-    if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return
+    if "gpt_oss" not in _normalized_unsloth_model_name(): return
     try:
         from ..flex_attention import (
             flex_attention_with_sink,
@@ -2054,7 +2078,7 @@ def forward(


 def patch_GptOssModel():
     if os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "1") == "0": return
-    if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return
+    if "gpt_oss" not in _normalized_unsloth_model_name(): return
     try:
         import transformers.models.gpt_oss.modeling_gpt_oss
         transformers.models.gpt_oss.modeling_gpt_oss.GptOssModel
@@ -2075,12 +2099,25 @@ def patch_GptOssModel():
     import transformers.generation.utils
     def wrap(f):
         def return_attention_mask(*args, **kwargs):
-            if kwargs["input_embeds"].requires_grad:
+            input_embeds = kwargs.get("input_embeds", None)
+            if input_embeds is None:
+                input_embeds = kwargs.get("inputs_embeds", None)
+            if input_embeds is None:
+                for arg in args:
+                    if type(arg) is torch.Tensor and arg.is_floating_point():
+                        input_embeds = arg
+                        break
+
+            if input_embeds is not None and input_embeds.requires_grad:
                 if "attention_mask" in kwargs:
                     return kwargs["attention_mask"]
                 for arg in args:
-                    if type(arg) is torch.Tensor and arg.dtype == torch.int32:
+                    if (
+                        type(arg) is torch.Tensor and
+                        arg.dtype in (torch.int32, torch.int64, torch.bool)
+                    ):
                         return arg
                 return f(*args, **kwargs)
             else:
                 # Eager
                 return f(*args, **kwargs)
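The widened detection above is easy to sanity-check in isolation. The sketch below replicates just the selection logic with toy tensors; it is not the real transformers.generation.utils hook, and pick_mask is a made-up name:

import torch

def pick_mask(*args, **kwargs):
    # Find the embeddings under either spelling, else the first float tensor.
    embeds = kwargs.get("input_embeds", None)
    if embeds is None:
        embeds = kwargs.get("inputs_embeds", None)
    if embeds is None:
        embeds = next(
            (a for a in args
             if type(a) is torch.Tensor and a.is_floating_point()),
            None,
        )
    if embeds is not None and embeds.requires_grad:
        if "attention_mask" in kwargs:
            return kwargs["attention_mask"]
        # Masks may be int32, int64, or bool depending on the caller.
        for a in args:
            if type(a) is torch.Tensor and a.dtype in (torch.int32, torch.int64, torch.bool):
                return a
    return None

embeds = torch.randn(1, 4, 8, requires_grad=True)
mask = torch.ones(1, 4, dtype=torch.int64)  # int64 masks were missed before this PR
assert pick_mask(embeds, mask) is mask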
@@ -2739,7 +2776,7 @@ def patch_gpt_oss_config():


 def patch_gpt_oss_init_weights_modulelist_fix():
-    if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""):
+    if "gpt_oss" not in _normalized_unsloth_model_name():
         return
     try:
         import transformers.models.gpt_oss.modeling_gpt_oss
@@ -2784,7 +2821,7 @@ def patch_gpt_oss_for_grpo():
     When UNSLOTH_RETURN_HIDDEN_STATES=1, return hidden_states instead of logits.
     This fixes the matrix multiplication dimension mismatch issue in GRPO training.
     """
-    if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""):
+    if "gpt_oss" not in _normalized_unsloth_model_name():
         return

     try:
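For readers who have not seen the GRPO hook this docstring describes, the underlying idea is a flag-controlled return value: when UNSLOTH_RETURN_HIDDEN_STATES=1, skip the lm_head projection and hand back hidden states. A toy sketch under that reading; the function name and shapes are illustrative, not the PR's actual forward:

import os
import torch

def head_or_hidden(hidden_states: torch.Tensor, lm_head: torch.nn.Linear) -> torch.Tensor:
    # GRPO computes logits itself from hidden states, so skipping the lm_head
    # matmul here avoids the dimension mismatch the docstring mentions.
    if os.environ.get("UNSLOTH_RETURN_HIDDEN_STATES", "0") == "1":
        return hidden_states
    return lm_head(hidden_states)

hidden = torch.randn(2, 5, 16)
head = torch.nn.Linear(16, 32, bias=False)

os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
assert head_or_hidden(hidden, head).shape == (2, 5, 16)  # hidden states
os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0"
assert head_or_hidden(hidden, head).shape == (2, 5, 32)  # logits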