unsloth_zoo/compiler.py (11 changes: 9 additions & 2 deletions)

@@ -1151,8 +1151,14 @@ def create_standalone_class(
     for line in lines:
         stripped = line.strip()
         if stripped.startswith("@"):
-            if "use_experts_implementation" in stripped:
-                logger.info(f'Unsloth: stripped use_experts_implementation decorator from {module}')
+            if (
+                "use_experts_implementation" in stripped
+                or "use_kernel_forward_from_hub" in stripped
+                or "use_kernelized_func" in stripped
+                or stripped.startswith("@auto_docstring")
+            ):
+                decorator_name = stripped.split("(")[0].lstrip("@")
+                logger.info(f"Unsloth: stripped {decorator_name} decorator from {module}")
                 continue # Strip it
             else:
                 logger.warning(f"Unsloth: Warning: Unknown decorator {stripped} found for {module}.")
@@ -1269,6 +1275,7 @@ def create_standalone_class(

     # Remove @auto_docstring
     source = re.sub(r"@auto_docstring[\s]{0,}(\([^\)]{0,}\))?", "", source)
+    source = re.sub(r"@use_kernelized_func[\s]{0,}(\([^\)]{0,}\))?", "", source)
     source = re.sub(r"@check_model_inputs[\s]{0,}(\([^\)]{0,}\))?", "", source)
     # source = source.replace("@auto_docstring", "")

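The regex variant removes both the bare form and the parenthesized form of a decorator. A quick check of the pattern's behavior on a made-up source string:

import re

# Made-up source string; the pattern matches the decorator with or without arguments.
source = '@use_kernelized_func(grouped_mm)\ndef forward(self, x):\n    ...\n'
source = re.sub(r"@use_kernelized_func[\s]{0,}(\([^\)]{0,}\))?", "", source)
print(repr(source))  # '\ndef forward(self, x):\n    ...\n'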
unsloth_zoo/temporary_patches/gpt_oss.py (61 changes: 49 additions & 12 deletions)

@@ -648,6 +648,26 @@ def _load_from_state_dict(
     )


+def patch_gpt_oss_compiler_exports():
+    model_name = os.environ.get("UNSLOTH_MODEL_NAME", "").replace("-", "_")
Review comment (collaborator): Ideally, if our temporary_patches creates and references a class, compiler.py should automatically handle copying it over. But for now this is fine.

if "gpt_oss" not in model_name:
return
try:
import transformers.models.gpt_oss.modeling_gpt_oss
except Exception as e:
raise_error("transformers.models.gpt_oss.modeling_gpt_oss", e)
return

# Export helpers so compiler generated GPT-OSS modules can resolve symbols.
m = transformers.models.gpt_oss.modeling_gpt_oss
m.ParameterModule = ParameterModule
Review comment (collaborator): My only worry here is that this would leave from unsloth_zoo.temporary_patches.gpt_oss import ParameterModule in the unsloth_compiled_module_gpt_oss.py. Can you please verify where the import resolves to and that it does not create a dependency cycle?

+    m.swiglu_torch_forward = swiglu_torch_forward
+    m.dtype_from_config = dtype_from_config
+    m.transformers_version = transformers_version
+    m.Version = Version
+TEMPORARY_PATCHES.append(patch_gpt_oss_compiler_exports)


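The patch exports symbols by assigning attributes on the already-imported modeling module, so a generated "from transformers.models.gpt_oss.modeling_gpt_oss import ParameterModule" resolves without importing the patch file again. A minimal sketch of the same export-and-verify pattern on a stand-in module (names here are illustrative, not the real Unsloth ones), which is also one way to check the dependency-cycle concern raised above:

import sys
import types

# Stand-in for transformers.models.gpt_oss.modeling_gpt_oss (illustrative only).
target = types.ModuleType("fake_modeling_gpt_oss")
sys.modules["fake_modeling_gpt_oss"] = target

class ParameterModule:  # stand-in for the helper defined in this patch file
    pass

# Export the symbol onto the target module, as patch_gpt_oss_compiler_exports does.
target.ParameterModule = ParameterModule

# A compiler-generated module can now import it from the target module...
from fake_modeling_gpt_oss import ParameterModule as Imported

# ...and it resolves to the very same object, so no second import of the patch
# file is triggered.
assert Imported is ParameterModule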
 class GptOssExperts(nn.Module):
     """
     GPT OSS MoE Experts layer with 3D stacked parameters.
@@ -1316,15 +1336,19 @@ def _should_use_gpt_oss_bnb4bit() -> bool:
     Default: True when load_in_4bit is active.
     Set UNSLOTH_GPT_OSS_BNB4BIT_DISABLE=1 to force BF16 path.
     """
-    if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""):
+    if "gpt_oss" not in _normalized_unsloth_model_name():
         return False
-    if "_load_in_4bit_" not in os.environ.get("UNSLOTH_MODEL_NAME", ""):
+    if "_load_in_4bit_" not in _normalized_unsloth_model_name():
         return False
     return os.environ.get("UNSLOTH_GPT_OSS_BNB4BIT_DISABLE", "0") != "1"


 def _is_gpt_oss_4bit_load() -> bool:
-    return "_load_in_4bit_" in os.environ.get("UNSLOTH_MODEL_NAME", "")
+    return "_load_in_4bit_" in _normalized_unsloth_model_name()
+
+
+def _normalized_unsloth_model_name() -> str:
+    return os.environ.get("UNSLOTH_MODEL_NAME", "").replace("-", "_")

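These helpers key all gating off UNSLOTH_MODEL_NAME with dashes normalized to underscores, so the gpt-oss and gpt_oss spellings hit the same branches. A small sketch of the gating behavior; the environment values are examples only:

import os

# Hypothetical value; real names are set by Unsloth at load time.
os.environ["UNSLOTH_MODEL_NAME"] = "unsloth/gpt-oss-20b_load_in_4bit_"

name = os.environ.get("UNSLOTH_MODEL_NAME", "").replace("-", "_")
print("gpt_oss" in name)         # True: the dash spelling now matches too
print("_load_in_4bit_" in name)  # True: the 4-bit path is active

# The BF16 escape hatch still wins when set:
os.environ["UNSLOTH_GPT_OSS_BNB4BIT_DISABLE"] = "1"
print(os.environ.get("UNSLOTH_GPT_OSS_BNB4BIT_DISABLE", "0") != "1")  # False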

 def _is_transformers_v5() -> bool:

@@ -1340,7 +1364,7 @@ def patch_gpt_oss_moe_for_lora():
     IMPORTANT: We only patch the forward method, NOT replace the entire class.
     This preserves the original class structure so weights load correctly.
     """
-    if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""):
+    if "gpt_oss" not in _normalized_unsloth_model_name():
         return
     if _is_gpt_oss_4bit_load() or _should_use_gpt_oss_bnb4bit():
         # 4-bit loads should keep quantized weights and use default PEFT LoRA.
@@ -1774,8 +1798,8 @@ def patch_gpt_oss_linearized():
     Patch GPT OSS for 4bit loading with grouped_mm support.
     Only patches the GptOssExperts forward method - keeps original classes for proper weight loading.
     """
-    if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return
-    if "_load_in_4bit_" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return
+    if "gpt_oss" not in _normalized_unsloth_model_name(): return
+    if "_load_in_4bit_" not in _normalized_unsloth_model_name(): return
     if _should_use_gpt_oss_bnb4bit(): return
     try:
         import transformers.models.gpt_oss.modeling_gpt_oss
@@ -1813,7 +1837,7 @@ def experts_forward(

 def patch_GptOssAttention():
     if os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "1") == "0": return
-    if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return
+    if "gpt_oss" not in _normalized_unsloth_model_name(): return
     try:
         from ..flex_attention import (
             flex_attention_with_sink,
@@ -2054,7 +2078,7 @@ def forward(

 def patch_GptOssModel():
     if os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "1") == "0": return
-    if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return
+    if "gpt_oss" not in _normalized_unsloth_model_name(): return
     try:
         import transformers.models.gpt_oss.modeling_gpt_oss
         transformers.models.gpt_oss.modeling_gpt_oss.GptOssModel
@@ -2075,12 +2099,25 @@ def patch_GptOssModel():
     import transformers.generation.utils
     def wrap(f):
         def return_attention_mask(*args, **kwargs):
-            if kwargs["input_embeds"].requires_grad:
+            input_embeds = kwargs.get("input_embeds", None)
+            if input_embeds is None:
+                input_embeds = kwargs.get("inputs_embeds", None)
+            if input_embeds is None:
+                for arg in args:
+                    if type(arg) is torch.Tensor and arg.is_floating_point():
+                        input_embeds = arg
+                        break
+
+            if input_embeds is not None and input_embeds.requires_grad:
                 if "attention_mask" in kwargs:
                     return kwargs["attention_mask"]
                 for arg in args:
-                    if type(arg) is torch.Tensor and arg.dtype == torch.int32:
+                    if (
+                        type(arg) is torch.Tensor and
+                        arg.dtype in (torch.int32, torch.int64, torch.bool)
+                    ):
                         return arg
+                return f(*args, **kwargs)
             else:
                 # Eager
                 return f(*args, **kwargs)
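The rewritten wrapper now tolerates both the input_embeds and inputs_embeds spellings, embeddings passed positionally, and int64 or bool attention masks. A condensed, self-contained sketch of that control flow, where f is a stand-in for the wrapped transformers function:

import torch

def f(*args, **kwargs):
    return "called original"  # stand-in for the wrapped transformers function

def return_attention_mask(*args, **kwargs):
    # Condensed version of the patched wrapper above, for illustration only.
    input_embeds = kwargs.get("input_embeds", kwargs.get("inputs_embeds", None))
    if input_embeds is None:
        for arg in args:
            if type(arg) is torch.Tensor and arg.is_floating_point():
                input_embeds = arg
                break
    if input_embeds is not None and input_embeds.requires_grad:
        if "attention_mask" in kwargs:
            return kwargs["attention_mask"]
        for arg in args:
            if type(arg) is torch.Tensor and arg.dtype in (torch.int32, torch.int64, torch.bool):
                return arg
    return f(*args, **kwargs)

embeds = torch.randn(1, 4, 8, requires_grad=True)    # training path
mask = torch.ones(1, 4, dtype=torch.int64)           # missed by the old int32-only check
print(return_attention_mask(embeds, mask))           # the mask tensor, not "called original"
print(return_attention_mask(embeds.detach(), mask))  # "called original": eager path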
@@ -2739,7 +2776,7 @@ def patch_gpt_oss_config():


 def patch_gpt_oss_init_weights_modulelist_fix():
-    if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""):
+    if "gpt_oss" not in _normalized_unsloth_model_name():
         return
     try:
         import transformers.models.gpt_oss.modeling_gpt_oss
@@ -2784,7 +2821,7 @@ def patch_gpt_oss_for_grpo():
     When UNSLOTH_RETURN_HIDDEN_STATES=1, return hidden_states instead of logits.
     This fixes the matrix multiplication dimension mismatch issue in GRPO training.
     """
-    if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""):
+    if "gpt_oss" not in _normalized_unsloth_model_name():
         return

     try: