85 changes: 85 additions & 0 deletions unsloth_zoo/temporary_patches/misc.py
@@ -510,6 +510,91 @@ def inner_mask(batch_idx, head_idx, q_idx, kv_idx):
TEMPORARY_PATCHES.append(patch_transformers_masks)


def patch_sdpa_bool_causal_mask():
"""Drop materialised bool causal masks before dispatching to SDPA.

Fixes unslothai/unsloth#4906 — inf grad_norm on Qwen3.5 at seq_len > 65536.

When `patch_transformers_masks` (above) wraps `create_causal_mask` with
`torch.compile(dynamic=True)`, HF's `find_packed_sequence_indices` takes
the `is_tracing` branch and returns a non-None packed_sequence_mask even
for sequential positions. That sets `allow_is_causal_skip=False` inside
`create_causal_mask`, so `sdpa_mask` materialises a dense [1, 1, Q, K]
bool causal mask instead of returning None.

On builds that fall through to PyTorch SDPA's memory-efficient (Cutlass)
backend — e.g. head_dim=256 models without flash-attn installed —
that backend with bf16 + a dense bool causal mask at seq_len > 65536
produces wrong forward outputs AND wrong gradients. The 65536 = 2**16
boundary matches a hard kernel limit.

Fix: wrap `sdpa_attention_forward` so that when the incoming
attention_mask is a full 4D bool tensor with Q == K (i.e. a plain
materialised causal mask), we discard it and call SDPA with
`attention_mask=None, is_causal=True`. SDPA's native causal path does
not hit the broken backward kernel and is also faster on these shapes.

Guardrails make this a no-op for any other mask:
- Only trigger on a 4D torch.bool tensor.
- Last two dims must be equal (square mask, not a kv-cache slice).
- Last dim must equal query.shape[2] (not cross-attention or a
past-kv decode step).
Any other mask shape/dtype is forwarded unmodified.
"""
if os.environ.get("UNSLOTH_COMPILE_DISABLE", "0") == "1":
return
try:
import transformers.integrations.sdpa_attention as sdpa_mod
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
except Exception as e:
return raise_error("transformers.integrations.sdpa_attention", e)

current = getattr(sdpa_mod, "sdpa_attention_forward", None)
if current is None:
return
if getattr(current, "__unsloth_bool_causal_mask_fix__", False):
return # already installed

_orig = current

def sdpa_attention_forward_unsloth_4906(
module,
query,
key,
value,
attention_mask,
dropout = 0.0,
scaling = None,
is_causal = None,
**kwargs,
):
# Hybrid models (e.g. Qwen3.5) pass a dict keyed by layer type.
m = attention_mask
if isinstance(m, dict):
m = m.get(getattr(module, "layer_type", None), None)
if (
isinstance(m, torch.Tensor)
and m.dtype == torch.bool
and m.dim() == 4
    [Collaborator review comment]
    This might impact VLMs? At the very least we need a small comment
    explaining the change, why it doesn't conflict with other model
    types, and why it is compatible.
and m.shape[-1] == m.shape[-2]
and m.shape[-1] == query.shape[2]
):
    [Contributor review comment on lines +575 to +581, severity: high]
    The current logic might incorrectly apply causal attention to
    non-causal models (like BERT or ModernBERT) if they pass a square
    boolean mask. Additionally, if a causal model uses padding or
    packing, replacing the mask with is_causal=True would lose that
    information, leading to incorrect attention. It is safer to verify
    that the module is intended to be causal and that the mask is a full
    triangle (not containing padding/packing zeros). Suggested condition:
        if (
            isinstance(m, torch.Tensor)
            and m.dtype == torch.bool
            and m.dim() == 4
            and m.shape[-1] == m.shape[-2]
            and m.shape[-1] == query.shape[2]
            and getattr(module, "is_causal", False)
            and m[..., -1, 0].all()
        ):
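The suggested guardrail can be exercised in isolation. The sketch below (a hypothetical standalone helper, not part of the patch) mirrors the reviewer's stricter predicate: a plain causal mask keeps the last query row attending to position 0, so `m[..., -1, 0].all()` holds, while a left-padded mask zeroes that entry and is rejected.

```python
import torch

def looks_like_plain_causal(m: torch.Tensor) -> bool:
    # Hypothetical check mirroring the reviewer's suggested condition:
    # 4D square bool mask whose last query row still attends to position 0.
    return (
        m.dtype == torch.bool
        and m.dim() == 4
        and m.shape[-1] == m.shape[-2]
        and bool(m[..., -1, 0].all())
    )

S = 8
causal = torch.tril(torch.ones(S, S, dtype=torch.bool))[None, None]
padded = causal.clone()
padded[..., :, 0] = False  # first token masked out, e.g. left padding

print(looks_like_plain_causal(causal))  # True: plain dense causal mask
print(looks_like_plain_causal(padded))  # False: padding breaks the heuristic
```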

return _orig(
module, query, key, value, None,
dropout = dropout, scaling = scaling, is_causal = True, **kwargs,
)
    [Review comment on lines +582 to +585, P1: Preserve non-causal bool
    masks when calling SDPA]
    This rewrite drops attention_mask for any 4D square torch.bool mask
    and forces is_causal=True, but that predicate also matches valid
    non-causal variants (for example sliding-window or packed/
    padding-constrained masks) during prefill when Q == K. In those
    cases, the call silently changes semantics from the provided mask to
    full causal attention, which can alter model behavior and training
    results; this is especially relevant for hybrid mask flows that
    build per-layer masks (see the create_sliding_window_causal_mask
    path in unsloth_zoo/temporary_patches/gpt_oss.py). Please gate this
    rewrite on a stricter proof that the mask is the plain dense causal
    mask from the failing path before nulling it out.
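The sliding-window concern is easy to reproduce. This sketch (illustrative only, assuming a prefill step where Q == K) builds a sliding-window bool mask that satisfies every clause of the loose predicate in the patch, yet is not a plain causal mask; the stricter last-row/first-column check from the earlier suggestion does reject it.

```python
import torch

# A sliding-window causal mask during prefill: 4D, square, torch.bool,
# so the loose predicate would wrongly replace it with full causal attention.
S, window = 8, 3
i = torch.arange(S)
sliding = ((i[:, None] >= i[None, :]) &
           (i[:, None] - i[None, :] < window))[None, None]

assert sliding.dtype == torch.bool and sliding.dim() == 4
assert sliding.shape[-1] == sliding.shape[-2]  # matches the loose predicate

# The last query row cannot see position 0 (outside the window),
# so the stricter m[..., -1, 0].all() guard correctly rejects this mask.
print(bool(sliding[..., -1, 0].all()))  # False
```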

return _orig(
module, query, key, value, attention_mask,
dropout = dropout, scaling = scaling, is_causal = is_causal, **kwargs,
)

sdpa_attention_forward_unsloth_4906.__unsloth_bool_causal_mask_fix__ = True
    [Collaborator review comment]
    NIT: We don't need the 4906 suffix.
sdpa_mod.sdpa_attention_forward = sdpa_attention_forward_unsloth_4906
ALL_ATTENTION_FUNCTIONS["sdpa"] = sdpa_attention_forward_unsloth_4906
pass
TEMPORARY_PATCHES.append(patch_sdpa_bool_causal_mask)
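A small CPU sanity check of the patch's core claim: on shapes where the kernels are healthy, passing a dense bool causal mask and passing `is_causal=True` to SDPA are numerically equivalent, so dropping a plain causal mask is semantics-preserving. This is a minimal sketch with arbitrary small shapes, not a reproduction of the seq_len > 65536 failure.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 2, 8, 16)
k = torch.randn(1, 2, 8, 16)
v = torch.randn(1, 2, 8, 16)
mask = torch.tril(torch.ones(8, 8, dtype=torch.bool))[None, None]

# Explicit dense bool causal mask vs. SDPA's native causal path.
with_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
native = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(with_mask, native, atol=1e-6))  # True
```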


def patch_modernbert_attention_mask():
"""Fix ModernBERT attn_bias stride alignment for SDPA backward pass.
