
Fix inf grad_norm on Qwen3.5 at seq_len > 65536 without flash-attn #582

Open
danielhanchen wants to merge 1 commit into main from fix/issue-4906-qwen35-sdpa-bool-mask

Conversation

@danielhanchen (Contributor)

Fixes unslothai/unsloth#4906.

What was happening

LoRA training unsloth/Qwen3.5-4B and unsloth/Qwen3.5-9B with SFTTrainer produced grad_norm=inf at seq_len=69632 and NaN at seq_len=73728 on setups without a working flash-attn install. The threshold sat exactly on 65536 = 2^16, matching the failure pattern in issue #4906.

Both the original reporter (@djkazic) and I confirmed the bug does not appear when flash-attn is present (see the Axolotl-vs-Unsloth comparison in the issue, and @Datta0's note that the bug did not reproduce on a box with flash-attn 2.8.3 installed). The reporter also confirmed the bug persists after upgrading to transformers==5.6.0.dev0 and torch==2.10.0 when flash-attn is broken in the environment.

Root cause

  1. patch_transformers_masks in unsloth_zoo/temporary_patches/misc.py wraps transformers.masking_utils.create_causal_mask with torch.compile(fullgraph=False, dynamic=True).
  2. Under that compile wrap, transformers.masking_utils.find_packed_sequence_indices takes its is_tracing branch and returns a non-None packed_sequence_mask even when position_ids are sequential.
  3. That non-None value sets allow_is_causal_skip = False in create_causal_mask, so sdpa_mask materialises a dense [1, 1, Q, K] torch.bool causal mask instead of returning None.
  4. The mask is then passed to sdpa_attention_forward, which calls F.scaled_dot_product_attention with attn_mask=<bool>, is_causal=False.
  5. On builds without a working flash-attn package, SDPA cannot use FLASH_ATTENTION for Qwen3.5 (head_dim=256). CUDNN_ATTENTION also refuses head_dim=256. SDPA dispatches to the memory-efficient (Cutlass) backend. That backend on bf16 + head_dim=256 + a dense bool causal mask at seq_len > 65536 produces wrong forward outputs and wrong gradients.

Forward loss at seq_len=69632 drops from the correct 9.4138 to 8.9536 when the bug triggers, so this is a forward-path bug too, not just a backward-path bug. grad_norm goes to inf, and at seq_len=73728 some parameters get NaN gradients.
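The swap the fix relies on can be illustrated on toy shapes: a materialised [1, 1, Q, K] torch.bool causal mask and SDPA's native is_causal=True path are mathematically equivalent, so dropping the mask changes only which kernel runs. A minimal sketch (not from the PR; small shapes, no claim about backend dispatch here):

```python
import torch
import torch.nn.functional as F

B, H, Q, D = 1, 2, 8, 16  # toy shapes; the bug needs Q > 2**16, D = 256, bf16
q = torch.randn(B, H, Q, D)
k = torch.randn(B, H, Q, D)
v = torch.randn(B, H, Q, D)

# The dense mask sdpa_mask materialises when allow_is_causal_skip is False:
# torch.bool, shape [1, 1, Q, K], True = "may attend", lower-triangular.
causal_mask = torch.tril(torch.ones(Q, Q, dtype=torch.bool))[None, None]

out_mask   = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Equivalent in exact arithmetic; the bug is that only the attn_mask form
# reaches the broken memory-efficient kernel at long seq_len without flash-attn.
assert torch.allclose(out_mask, out_causal, atol=1e-5)
```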

Fix

Adds a defensive wrapper around transformers.integrations.sdpa_attention.sdpa_attention_forward, registered through ALL_ATTENTION_FUNCTIONS["sdpa"]. When the incoming attention_mask is a full 4D torch.bool tensor with Q == K (a plain materialised causal mask), the wrapper discards it and calls SDPA with attention_mask=None, is_causal=True. SDPA's native causal path does not hit the broken kernel and is also faster on these shapes.

The wrapper has guardrails so it only triggers on:

  • a 4D torch.bool tensor (the exact shape sdpa_mask produces)
  • last two dims equal (square mask)
  • last dim equal to query.shape[2] (not a cross-attention mask or a cache slice)

Any other mask shape or dtype is forwarded unmodified. For hybrid models that pass a dict of per-layer-type masks (e.g. Qwen3.5), the wrapper also unwraps the dict before the check.
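The guardrails above can be sketched as a standalone predicate. This is illustrative only: the function name and the "full_attention" dict key are assumptions, not the PR's actual identifiers.

```python
import torch

def is_plain_bool_causal_mask(mask, query):
    # Hybrid models (e.g. Qwen3.5) may pass a dict of per-layer-type masks;
    # unwrap before checking. The "full_attention" key is an assumption here.
    if isinstance(mask, dict):
        mask = mask.get("full_attention")
    return (
        isinstance(mask, torch.Tensor)
        and mask.dtype == torch.bool               # exact dtype sdpa_mask produces
        and mask.dim() == 4                        # [B, 1, Q, K] layout
        and mask.shape[-1] == mask.shape[-2]       # square: not a cache slice
        and mask.shape[-1] == query.shape[2]       # matches Q: not cross-attention
    )
```

A mask failing the predicate would be forwarded to SDPA unmodified.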

Verification

Environment: B200, torch 2.9.1+cu128, transformers 5.2.0, unsloth 2026.4.4, xformers 0.0.33.post2, no flash-attn.

Qwen3.5-4B, bf16, LoRA r=16 + rslora, use_gradient_checkpointing="unsloth":

| seq_len | before | after |
| --- | --- | --- |
| 65536 | loss 9.4130, grad_norm 6.3918 | loss 9.4131, grad_norm 6.3945 |
| 69632 | loss 8.9536, grad_norm inf | loss 9.4138, grad_norm 6.4041 |
| 73728 | loss 9.4339, grad_norm inf, 12 NaN params | loss 9.4092, grad_norm 6.3995 |
| 131072 | (not tested) | loss 9.4091, grad_norm 6.4435 |

Qwen3.5-9B:

| seq_len | before | after |
| --- | --- | --- |
| 65536 | loss 9.5640, grad_norm 5.4012 | loss 9.5641, grad_norm 5.4012 |
| 73728 | loss 9.5877, grad_norm inf, 3 inf + 22 NaN params | loss 9.5584, grad_norm 5.4170 |

Short-context regression, 21 steps at seq_len=4096, bsz=2, AdamW lr=2e-4, synthetic data (Qwen3.5-4B):

  • Max absolute loss delta with vs without fix: 0.0005 (0.006% relative).
  • Grad-norm delta: within 1.5% per step.

Non-Qwen3.5 sanity check (Llama-3.2-1B-Instruct, 4-bit, 10 steps at seq_len=2048, synthetic data):

  • Max loss delta with vs without fix: 0.0005 (0.005% relative).
  • The fix is a no-op for Llama since its mask path does not produce the [1,1,Q,K] bool tensor at long seq_lens.

Seed stability at Qwen3.5-4B seq_len=69632:

  • Without fix, seeds 3407, 1234, 42 all produce grad_norm=inf.
  • With fix, seeds 3407, 1234, 42 produce finite grad_norms 6.4041, 6.3545, 6.4087.
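The "grad_norm inf, N NaN params" style of reporting above comes down to a simple post-backward scan; one way to compute it (assumed, not the PR's actual test harness):

```python
import torch

def grad_report(model):
    """Return (global grad norm, number of params with non-finite grads)."""
    norms, bad = [], 0
    for p in model.parameters():
        if p.grad is None:
            continue
        if not torch.isfinite(p.grad).all():
            bad += 1  # parameter has at least one inf/NaN gradient entry
        norms.append(p.grad.float().norm())
    # Global norm = L2 norm over per-parameter L2 norms, as optimizers report it.
    grad_norm = torch.stack(norms).norm() if norms else torch.tensor(0.0)
    return grad_norm.item(), bad
```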

The fix also speeds up fwd+bwd by 20 to 25% at seq_len >= 65536 because it skips the mask materialisation overhead.

Files changed

  • unsloth_zoo/temporary_patches/misc.py — adds patch_sdpa_bool_causal_mask directly after patch_transformers_masks and registers it in TEMPORARY_PATCHES.

Test plan

  • grad_norm finite at seq_len in {65536, 69632, 73728, 131072} on Qwen3.5-4B.
  • grad_norm finite at seq_len in {65536, 73728} on Qwen3.5-9B.
  • Short-context training (21 steps) loss curve matches pre-fix to within 0.006%.
  • Non-Qwen3.5 sanity (Llama 3.2 1B) loss curve matches pre-fix to within 0.005%.
  • Stable across seeds 3407, 1234, 42.
  • To be confirmed by reviewer: no-op on transformers==5.6.0.dev0 + torch==2.10.0 (the versions @Datta0 tested).

Adds patch_sdpa_bool_causal_mask to drop materialised [1, 1, Q, K] bool
causal masks before dispatching to PyTorch SDPA, and instead use SDPA's
native is_causal=True path. When patch_transformers_masks wraps
create_causal_mask with torch.compile, find_packed_sequence_indices
takes its is_tracing branch and forces allow_is_causal_skip=False, so
the dense mask gets materialised. On builds without flash-attn, SDPA
dispatches to the memory-efficient backend, which with bf16 and
head_dim=256 produces wrong outputs and inf gradients at seq_len above
2**16. Routing through is_causal=True avoids that kernel path.

Fixes unslothai/unsloth#4906.
Diff excerpt:

```python
sdpa_attention_forward_unsloth_4906.__unsloth_bool_causal_mask_fix__ = True
sdpa_mod.sdpa_attention_forward = sdpa_attention_forward_unsloth_4906
ALL_ATTENTION_FUNCTIONS["sdpa"] = sdpa_attention_forward_unsloth_4906
pass
TEMPORARY_PATCHES.append(patch_transformers_masks)


def patch_sdpa_bool_causal_mask():
```

@chatgpt-codex-connector (bot) left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: d79628e7f7


Comment on lines +582 to +585:

```python
return _orig(
    module, query, key, value, None,
    dropout = dropout, scaling = scaling, is_causal = True, **kwargs,
)
```

P1: Preserve non-causal bool masks when calling SDPA

This rewrite drops attention_mask for any 4D square torch.bool mask and forces is_causal=True, but that predicate also matches valid non-causal variants (for example sliding-window or packed/padding-constrained masks) during prefill when Q == K. In those cases, the call silently changes semantics from the provided mask to full causal attention, which can alter model behavior and training results; this is especially relevant for hybrid mask flows that build per-layer masks (see the create_sliding_window_causal_mask path in unsloth_zoo/temporary_patches/gpt_oss.py). Please gate this rewrite on a stricter proof that the mask is the plain dense causal mask from the failing path before nulling it out.


@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a patch to address incorrect gradients and forward outputs in PyTorch SDPA when processing dense boolean causal masks for sequences longer than 65536. The implementation wraps the SDPA forward pass to replace such masks with is_causal=True. I have provided feedback suggesting additional guardrails to ensure this logic does not inadvertently affect non-causal models or mask out necessary padding/packing information.

Comment on lines +575 to +581:

```python
if (
    isinstance(m, torch.Tensor)
    and m.dtype == torch.bool
    and m.dim() == 4
    and m.shape[-1] == m.shape[-2]
    and m.shape[-1] == query.shape[2]
):
```

high

The current logic might incorrectly apply causal attention to non-causal models (like BERT or ModernBERT) if they pass a square boolean mask. Additionally, if a causal model uses padding or packing, replacing the mask with is_causal=True would lose that information, leading to incorrect attention. It is safer to verify that the module is intended to be causal and that the mask is a full triangle (not containing padding/packing zeros).

```python
if (
    isinstance(m, torch.Tensor)
    and m.dtype == torch.bool
    and m.dim() == 4
    and m.shape[-1] == m.shape[-2]
    and m.shape[-1] == query.shape[2]
    and getattr(module, "is_causal", False)
    and m[..., -1, 0].all()
):
```

@djkazic

djkazic commented Apr 9, 2026

Absolutely stellar work. Can't wait to try this out.

```python
dropout = dropout, scaling = scaling, is_causal = is_causal, **kwargs,
)

sdpa_attention_forward_unsloth_4906.__unsloth_bool_causal_mask_fix__ = True
```
Collaborator

NIT: We don't need the 4906 suffix

```python
if (
    isinstance(m, torch.Tensor)
    and m.dtype == torch.bool
    and m.dim() == 4
```
Collaborator

This might impact VLMs?
At the very least we need a small comment explaining the change and why it doesn't conflict with other model types and is compatible

@djkazic

djkazic commented Apr 9, 2026

tACK, confirmed working:

```
🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
ERROR 04-09 15:18:35 [gpt_oss_triton_kernels_moe.py:34] Failed to import Triton kernels. Please make sure your triton version is compatible. Error: No module named 'triton_kernels.routing'
🦥 Unsloth Zoo will now patch everything to make training faster!
Loading unsloth/Qwen3.5-9B...
==((====))==  Unsloth 2026.4.4: Fast Qwen3_5 patching. Transformers: 5.2.0. vLLM: 0.16.0.
   \\   /|    NVIDIA RTX PRO 6000 Blackwell Server Edition. Num GPUs = 1. Max memory: 94.981 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.1+cu128. CUDA: 12.0. CUDA Toolkit: 12.8. Triton: 3.5.1
\        /    Bfloat16 = TRUE. FA [Xformers = None. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Loading weights: 100% 760/760 [00:20<00:00, 37.88it/s, Materializing param=model.visual.pos_embed.weight]
Unsloth: Making `model.base_model.model.model.language_model` require gradients
Using unsloth gradient checkpointing
Model loaded. Testing sequence lengths...

Unsloth: Will smartly offload gradients to save VRAM!
seq_len= 65536: loss=9.5637 fwd_nan=False grad_nan=False grad_norm=49.14
seq_len= 69632: loss=9.6417 fwd_nan=False grad_nan=False grad_norm=56.65
seq_len= 73728: loss=9.5687 fwd_nan=False grad_nan=False grad_norm=47.69
seq_len= 98304: loss=9.5703 fwd_nan=False grad_nan=False grad_norm=49.71
```

@djkazic

djkazic commented Apr 9, 2026

Something I did notice though is that with no flash-attn (but flash-linear-attention and causal-conv1d installed) I can't seem to hit the same seq len as axolotl. Might need to do some more testing.



Development

Successfully merging this pull request may close these issues.

[Bug] Gradient explosion (inf/NaN) training Qwen 3.5 at >65536 sequence length

3 participants