Fix inf grad_norm on Qwen3.5 at seq_len > 65536 without flash-attn #582
danielhanchen wants to merge 1 commit into main from
Conversation
Adds patch_sdpa_bool_causal_mask to drop materialised [1, 1, Q, K] bool causal masks before dispatching to PyTorch SDPA, and instead use SDPA's native is_causal=True path. When patch_transformers_masks wraps create_causal_mask with torch.compile, find_packed_sequence_indices takes its is_tracing branch and forces allow_is_causal_skip=False, so the dense mask gets materialised. On builds without flash-attn, SDPA dispatches to the memory-efficient backend, which with bf16 and head_dim=256 produces wrong outputs and inf gradients at seq_len above 2**16. Routing through is_causal=True avoids that kernel path. Fixes unslothai/unsloth#4906.
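The swap the patch performs can be illustrated with a small, self-contained sketch (toy shapes only; this does not exercise the failing memory-efficient kernel at seq_len > 2**16, it just shows that a dense bool causal mask and SDPA's native `is_causal=True` path are the same computation):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy shapes; the real bug needs seq_len > 65536 on the mem-efficient backend.
B, H, S, D = 1, 2, 16, 8
q = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)

# Dense [1, 1, Q, K] bool causal mask, the shape sdpa_mask materialises.
# For bool masks, True means "attend", False means "mask out".
causal = torch.tril(torch.ones(S, S, dtype=torch.bool))[None, None]

out_mask   = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)
out_native = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# At these sizes the two paths agree; the patch swaps the first for the second.
assert torch.allclose(out_mask, out_native, atol=1e-5)
```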
```python
sdpa_attention_forward_unsloth_4906.__unsloth_bool_causal_mask_fix__ = True
sdpa_mod.sdpa_attention_forward = sdpa_attention_forward_unsloth_4906
ALL_ATTENTION_FUNCTIONS["sdpa"] = sdpa_attention_forward_unsloth_4906
pass
TEMPORARY_PATCHES.append(patch_transformers_masks)


def patch_sdpa_bool_causal_mask():
```
💡 Codex Review
Reviewed commit: d79628e7f7
```python
return _orig(
    module, query, key, value, None,
    dropout = dropout, scaling = scaling, is_causal = True, **kwargs,
)
```
Preserve non-causal bool masks when calling SDPA
This rewrite drops attention_mask for any 4D square torch.bool mask and forces is_causal=True, but that predicate also matches valid non-causal variants (for example sliding-window or packed/padding-constrained masks) during prefill when Q == K. In those cases, the call silently changes semantics from the provided mask to full causal attention, which can alter model behavior and training results; this is especially relevant for hybrid mask flows that build per-layer masks (see the create_sliding_window_causal_mask path in unsloth_zoo/temporary_patches/gpt_oss.py). Please gate this rewrite on a stricter proof that the mask is the plain dense causal mask from the failing path before nulling it out.
Code Review
This pull request introduces a patch to address incorrect gradients and forward outputs in PyTorch SDPA when processing dense boolean causal masks for sequences longer than 65536. The implementation wraps the SDPA forward pass to replace such masks with is_causal=True. I have provided feedback suggesting additional guardrails to ensure this logic does not inadvertently affect non-causal models or mask out necessary padding/packing information.
```python
if (
    isinstance(m, torch.Tensor)
    and m.dtype == torch.bool
    and m.dim() == 4
    and m.shape[-1] == m.shape[-2]
    and m.shape[-1] == query.shape[2]
):
```
The current logic might incorrectly apply causal attention to non-causal models (like BERT or ModernBERT) if they pass a square boolean mask. Additionally, if a causal model uses padding or packing, replacing the mask with is_causal=True would lose that information, leading to incorrect attention. It is safer to verify that the module is intended to be causal and that the mask is a full triangle (not containing padding/packing zeros).
```python
if (
    isinstance(m, torch.Tensor)
    and m.dtype == torch.bool
    and m.dim() == 4
    and m.shape[-1] == m.shape[-2]
    and m.shape[-1] == query.shape[2]
    and getattr(module, "is_causal", False)
    and m[..., -1, 0].all()
):
```
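As a quick illustration of why the extra `m[..., -1, 0].all()` check helps (a sketch with hypothetical masks, not code from this PR): the last query row of a plain dense causal mask can always see key position 0, while a left-padded mask cannot, so the guard would leave padded masks untouched.

```python
import torch

S = 8
# Plain dense causal mask, as sdpa_mask materialises it.
causal = torch.tril(torch.ones(S, S, dtype=torch.bool))[None, None]

# Hypothetical left-padded variant: the first two key positions are fully
# masked out, as a padding-aware mask builder would produce.
padded = causal.clone()
padded[..., :, :2] = False

# Last query row of a plain causal mask sees key 0 ...
assert causal[..., -1, 0].all()
# ... but the padded mask does not, so it would not be rewritten.
assert not padded[..., -1, 0].all()
```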
Absolutely stellar work. Can't wait to try this out.
```python
    dropout = dropout, scaling = scaling, is_causal = is_causal, **kwargs,
)


sdpa_attention_forward_unsloth_4906.__unsloth_bool_causal_mask_fix__ = True
```
NIT: We don't need the 4906 suffix
```python
if (
    isinstance(m, torch.Tensor)
    and m.dtype == torch.bool
    and m.dim() == 4
```
This might impact VLMs?
At the very least we need a small comment explaining the change, why it doesn't conflict with other model types, and how it stays compatible with them.
tACK, confirmed working:

Something I did notice though is that with no
Fixes unslothai/unsloth#4906.
What was happening

LoRA training `unsloth/Qwen3.5-4B` and `unsloth/Qwen3.5-9B` with `SFTTrainer` produced `grad_norm=inf` at `seq_len=69632` and NaN at `seq_len=73728` on setups without a working flash-attn install. The threshold sat exactly on `65536 = 2^16`, matching the failure pattern in issue #4906.

Both the original reporter (@djkazic) and I confirmed the bug does not appear when flash-attn is present (see the Axolotl-vs-Unsloth comparison in the issue, and @Datta0's note that the bug did not reproduce on a box with flash-attn 2.8.3 installed). The reporter also confirmed the bug persists after upgrading to `transformers==5.6.0.dev0` and `torch==2.10.0` when flash-attn is broken in the environment.

Root cause

1. `patch_transformers_masks` in `unsloth_zoo/temporary_patches/misc.py` wraps `transformers.masking_utils.create_causal_mask` with `torch.compile(fullgraph=False, dynamic=True)`.
2. Under that compiled wrapper, `transformers.masking_utils.find_packed_sequence_indices` takes its `is_tracing` branch and returns a non-None `packed_sequence_mask` even when position_ids are sequential.
3. This forces `allow_is_causal_skip = False` in `create_causal_mask`, so `sdpa_mask` materialises a dense `[1, 1, Q, K]` `torch.bool` causal mask instead of returning None.
4. The mask reaches `sdpa_attention_forward`, which calls `F.scaled_dot_product_attention` with `attn_mask=<bool>, is_causal=False`.
5. Without flash-attn, SDPA dispatches to the memory-efficient backend, which with bf16 and head_dim=256 at `seq_len > 65536` produces wrong forward outputs and wrong gradients.

Forward loss at `seq_len=69632` drops from the correct 9.4138 to 8.9536 when the bug triggers, so this is a forward-path bug too, not just a backward-path bug. `grad_norm` goes to inf, and at `seq_len=73728` some parameters get NaN gradients.

Fix
Adds a defensive wrapper around `transformers.integrations.sdpa_attention.sdpa_attention_forward`, registered through `ALL_ATTENTION_FUNCTIONS["sdpa"]`. When the incoming `attention_mask` is a full 4D `torch.bool` tensor with Q == K (a plain materialised causal mask), the wrapper discards it and calls SDPA with `attention_mask=None, is_causal=True`. SDPA's native causal path does not hit the broken kernel and is also faster on these shapes.

The wrapper has guardrails so it only triggers on:

- a 4D `torch.bool` tensor (the exact shape `sdpa_mask` produces)
- a square mask whose last dimension equals `query.shape[2]` (not a cross-attention mask or a cache slice)

Any other mask shape or dtype is forwarded unmodified. For hybrid models that pass a dict of per-layer-type masks (e.g. Qwen3.5), the wrapper also unwraps the dict before the check.
Verification
Environment: B200, torch 2.9.1+cu128, transformers 5.2.0, unsloth 2026.4.4, xformers 0.0.33.post2, no flash-attn.
Qwen3.5-4B, bf16, LoRA r=16 + rslora,
use_gradient_checkpointing="unsloth":Qwen3.5-9B:
Short-context regression, 21 steps at
seq_len=4096, bsz=2, AdamW lr=2e-4, synthetic data (Qwen3.5-4B):Non-Qwen3.5 sanity check (Llama-3.2-1B-Instruct, 4-bit, 10 steps at
seq_len=2048, synthetic data):[1,1,Q,K]bool tensor at long seq_lens.Seed stability at Qwen3.5-4B
seq_len=69632:grad_norm=inf.The fix also speeds up fwd+bwd by 20 to 25% at
seq_len >= 65536because it skips the mask materialisation overhead.Files changed
unsloth_zoo/temporary_patches/misc.py— addspatch_sdpa_bool_causal_maskdirectly afterpatch_transformers_masksand registers it inTEMPORARY_PATCHES.Test plan
grad_normfinite atseq_len in {65536, 69632, 73728, 131072}on Qwen3.5-4B.grad_normfinite atseq_len in {65536, 73728}on Qwen3.5-9B.transformers==5.6.0.dev0+torch==2.10.0(the versions @Datta0 tested).
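The finiteness check in the test plan boils down to the following shape (a toy stand-in with a tiny linear model, not the actual training script; the real runs compute `grad_norm` over the LoRA parameters after a full training step):

```python
import torch

# Toy stand-in for "run one step and assert grad_norm is finite".
torch.manual_seed(0)
model = torch.nn.Linear(16, 16)
x = torch.randn(4, 16)
loss = model(x).square().mean()
loss.backward()

# Same reduction trainers use: norm of per-parameter gradient norms.
grad_norm = torch.norm(
    torch.stack([p.grad.norm() for p in model.parameters() if p.grad is not None])
)
assert torch.isfinite(grad_norm), "grad_norm must be finite"
```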