
Fix inf grad_norm on Qwen3.5 at seq_len > 65536 with tighter SDPA guards#587

Open
danielhanchen wants to merge 2 commits into main from fix/issue-4906-sdpa-bool-mask-tighter-guards

Conversation

@danielhanchen
Contributor

Summary

  • Fixes #4906: LoRA training of Qwen3.5-4B/9B produces grad_norm=inf at seq_len > 65536 when flash-attn is not installed
  • Wraps sdpa_attention_forward to detect materialized bool causal masks and replace with attention_mask=None, is_causal=True
  • Includes tighter guards than #582 (Fix inf grad_norm on Qwen3.5 at seq_len > 65536 without flash-attn) to protect sliding-window, bidirectional, and packed-sequence masks
  • Adds bool-to-float fallback for non-pure-causal masks to avoid the Cutlass bug while preserving mask semantics

Root cause

patch_transformers_masks wraps create_causal_mask with torch.compile(dynamic=True). Under tracing, find_packed_sequence_indices takes the is_tracing branch, forcing allow_is_causal_skip=False. This materializes a dense [1,1,Q,K] bool causal mask. PyTorch SDPA's Cutlass backend with bf16 + bool mask at seq_len > 65536 produces wrong outputs/gradients (int16 sequence-index overflow at 2^16).

Guards (all O(1), no CUDA ops)

  1. module.is_causal check -- protects BERT/bidirectional encoders
  2. sliding_window kwarg check -- protects Gemma2/Mistral/Qwen2/3
  3. Dict-mask unwrap with safe fallback when layer_type not in dict
  4. 4D bool + square + Q==K shape check -- not cross-attention or kv-cache decode
  5. Upper-triangle spot-check m[0,0,0,1]==False -- distinguishes pure causal from packed-sequence masks
  6. Bool-to-float additive bias fallback for non-pure-causal masks -- avoids Cutlass bug while preserving mask semantics
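
The six guards above could be composed roughly as follows. This is a minimal sketch, not the PR's verbatim code: the wrapper name `make_guarded_sdpa`, the `sliding_window` kwarg plumbing, and the `layer_type` dict key are illustrative assumptions, and the pure-causal test here already includes both spot-checks (the `m[0, 0, -1, 0]` one is added in the second commit):

```python
import torch

def make_guarded_sdpa(orig_sdpa_forward):
    # Hypothetical wrapper illustrating the six O(1) guards.
    def sdpa_attention_forward_unsloth(module, query, key, value,
                                       attention_mask=None, **kwargs):
        m = attention_mask
        # Guard 3: per-layer_type dict masks -- unwrap, else keep original.
        if isinstance(m, dict):
            m = m.get(getattr(module, "layer_type", None), attention_mask)
        if (
            getattr(module, "is_causal", True)        # Guard 1: skip encoders
            and kwargs.get("sliding_window") is None  # Guard 2: skip SWA models
            and isinstance(m, torch.Tensor)
            and m.dtype == torch.bool
            and m.dim() == 4
            and m.shape[-1] == m.shape[-2]            # Guard 4: square ...
            and query.shape[-2] == key.shape[-2]      # ... and Q == K
        ):
            S = m.shape[-1]
            # Guard 5: two spot-checks for pure lower-triangular causality
            # (the second, m[0, 0, -1, 0], is the follow-up commit's addition).
            if S < 2 or (not m[0, 0, 0, 1].item() and m[0, 0, -1, 0].item()):
                return orig_sdpa_forward(module, query, key, value,
                                         attention_mask=None,
                                         is_causal=True, **kwargs)
            # Guard 6: non-pure-causal bool mask -> float additive bias.
            m = torch.where(m, 0.0, torch.finfo(query.dtype).min).to(query.dtype)
        return orig_sdpa_forward(module, query, key, value,
                                 attention_mask=m, **kwargs)
    return sdpa_attention_forward_unsloth
```

A pure causal bool mask takes the `attention_mask=None, is_causal=True` fast path; anything that fails the guards falls through to the original SDPA call, at most with its bool mask rewritten as a float bias.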

Test results (NVIDIA B200, torch 2.9.1, transformers 5.5.1, no flash-attn)

Qwen3.5-4B at seq_len=69632, 4-bit LoRA, 9 steps:

Method          Time    Peak Mem   Grad Norms [1st, 5th, 9th]
Without fix     3649s   74.27 GB   [inf, nan, nan]
With this fix   2765s   74.27 GB   [3.319, 3.219, 1.057]

Llama-3.2-1B regression at seq_len=2048, 21 steps:

  • Max loss delta vs baseline: 0.003
  • Max grad_norm delta vs baseline: 0.06
  • No regression

Test plan

  • Reproduce grad_norm=inf at seq_len=69632 without fix
  • Verify finite grad_norms with fix at seq_len=69632
  • Llama short-context regression (max loss delta < 0.01)
  • Sliding-window model test (Gemma2/Mistral)
  • Packed-sequence training test

When patch_transformers_masks wraps create_causal_mask with
torch.compile(dynamic=True), the is_tracing branch in
find_packed_sequence_indices materializes a dense [1,1,Q,K] bool causal
mask. PyTorch SDPA's Cutlass backend with bf16 + bool mask at
seq_len > 65536 produces wrong outputs and gradients (int16 overflow).

Wrap sdpa_attention_forward to detect materialized bool causal masks and
replace with attention_mask=None, is_causal=True. Unlike the original
approach, this version includes tighter guards:

- Check module.is_causal to protect BERT/bidirectional encoders
- Check sliding_window kwarg to protect Gemma2/Mistral/Qwen2
- Spot-check upper triangle to distinguish pure causal from packed masks
- Convert non-pure-causal bool masks to float additive bias as fallback

Fixes unslothai/unsloth#4906
```python
TEMPORARY_PATCHES.append(patch_transformers_masks)


def patch_sdpa_bool_causal_mask():
```

@chatgpt-codex-connector (bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 37d5da764e


```python
# (first query cannot see second key). Packed-sequence masks and other
# non-trivial patterns will have True in the upper triangle.
S = m.shape[-1]
is_pure_causal = (S < 2) or (not m[0, 0, 0, 1].item())
```

P1: Do not classify packed masks via a single upper-triangle check

is_pure_causal is inferred from only m[0,0,0,1], but packed-sequence and padding-constrained causal masks also have this entry set to False due to the base causal rule (kv_idx > q_idx is always masked). In those cases this branch misclassifies the mask as “pure causal” and drops it, so SDPA runs with attention_mask=None, is_causal=True and loses sequence/padding boundaries. That can silently let tokens attend across packed samples (or into padded tokens) and corrupt training/inference whenever packed or left-padded batches are used.
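
The misclassification is easy to reproduce with the 4x4 packed mask from the follow-up discussion (a standalone sketch, not the PR's code; the mask encodes two packed sequences of two tokens each):

```python
import torch

# Two packed sequences: tokens 0-1 form one sample, tokens 2-3 another.
packed = torch.tensor([[1, 0, 0, 0],
                       [1, 1, 0, 0],
                       [0, 0, 1, 0],
                       [0, 0, 1, 1]], dtype=torch.bool)[None, None]

# The single-cell check from the first commit: entry [0, 1] must be False.
single_check_says_pure_causal = not packed[0, 0, 0, 1].item()
print(single_check_says_pure_causal)  # True -- misclassified as pure causal

# A genuine lower-triangular causal mask gives the same answer:
causal = torch.tril(torch.ones(4, 4, dtype=torch.bool))[None, None]
print(not causal[0, 0, 0, 1].item())  # True
```

Both masks satisfy the check, so the packed mask would be dropped in favor of `is_causal=True`, letting tokens 2-3 attend to tokens 0-1 across the sample boundary.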


@gemini-code-assist
Contributor

Warning: Gemini encountered an error creating the review. You can try again by commenting /gemini review.

The single-cell check m[0,0,0,1]==False is insufficient: packed masks
like [[1,0,0,0],[1,1,0,0],[0,0,1,0],[0,0,1,1]] also satisfy it.

Add m[0,0,-1,0]==True (last query can see first key) which is true for
pure lower-triangular causal masks but false for packed-sequence masks,
padded masks, and sliding-window masks. Both checks are O(1).
@danielhanchen
Contributor Author

Updated with a second commit to address the packed-sequence mask regression identified by review.

What changed: The single-cell spot-check m[0,0,0,1]==False was insufficient -- packed masks like [[1,0,0,0],[1,1,0,0],[0,0,1,0],[0,0,1,1]] also satisfy it. Added a second O(1) check: m[0,0,-1,0]==True (last query can see first key), which correctly rejects packed, padded, and sliding-window masks while accepting pure lower-triangular causal masks.

Verification:

  • Unit test confirms all 5 mask types are correctly classified:
    • Pure causal: accepted (is_causal=True fast path)
    • Packed sequence: rejected (falls through to float fallback)
    • Right-padded: rejected (falls through to float fallback)
    • Left-padded: rejected (falls through to float fallback)
    • Sliding window: rejected (falls through to float fallback)
  • GPU test: Qwen3.5-4B at seq_len=69632, 9 steps -- all grad_norms finite (3.319, 3.320, 3.012, 2.932, 3.209, 3.671, 2.848, 1.612, 1.058)
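
The classification in that unit test can be sketched as follows. This is a standalone illustration with my own mask constructions, not the PR's test code; in particular, the padded masks here zero out both the pad row and the pad column, which is an assumption about how the test builds them:

```python
import torch

def is_pure_causal(mask):
    # The two O(1) spot-checks: the first query must NOT see the second key,
    # and the last query MUST see the first key.
    m = mask[0, 0]
    if m.shape[-1] < 2:
        return True
    return (not m[0, 1].item()) and m[-1, 0].item()

S = 4
causal = torch.tril(torch.ones(S, S, dtype=torch.bool))
packed = torch.tensor([[1, 0, 0, 0],
                       [1, 1, 0, 0],
                       [0, 0, 1, 0],
                       [0, 0, 1, 1]], dtype=torch.bool)
# Padded masks: pad row and pad column both masked (illustrative assumption).
right_padded = causal.clone(); right_padded[:, -1] = False; right_padded[-1, :] = False
left_padded  = causal.clone(); left_padded[:, 0]  = False; left_padded[0, :]  = False
# Sliding window of 2: each query sees itself and one previous key.
window = causal & torch.triu(torch.ones(S, S, dtype=torch.bool), diagonal=-1)

for name, m in [("causal", causal), ("packed", packed),
                ("right-padded", right_padded), ("left-padded", left_padded),
                ("sliding-window", window)]:
    print(name, is_pure_causal(m[None, None]))  # only "causal" prints True
```

Only the pure lower-triangular mask passes both checks; the other four fail `m[-1, 0]` and fall through to the float-bias fallback.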

```python
sdpa_attention_forward_unsloth.__unsloth_bool_causal_mask_fix__ = True
sdpa_mod.sdpa_attention_forward = sdpa_attention_forward_unsloth
ALL_ATTENTION_FUNCTIONS["sdpa"] = sdpa_attention_forward_unsloth
pass
```

@chatgpt-codex-connector (bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: f54a06b00b


```python
# Convert bool -> float additive bias to avoid the Cutlass bool-mask
# bug while preserving mask semantics. SDPA dispatches to a different
# (working) kernel for float attn_mask inputs.
m_float = torch.where(m, 0.0, torch.finfo(query.dtype).min).to(query.dtype)
```

P2: Use -inf when converting bool masks to additive bias

The bool→float fallback uses torch.finfo(query.dtype).min, which is finite, so rows that are fully masked (a common case for left-padded causal masks) no longer behave like the original boolean mask. In SDPA, an all-finite masked row becomes an all-equal logit row and produces a uniform attention mix instead of zero output, so this path does not preserve mask semantics for padded/custom masks and can change hidden states whenever is_pure_causal is false.
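
The numerical difference is easy to see on a fully masked row (a standalone sketch of the softmax arithmetic, not the PR's code): with a finite `finfo.min` bias every logit in the row is equal, so softmax yields a uniform attention mix, whereas `-inf` leaves the row undefined.

```python
import torch

# All positions masked with a large-but-finite negative bias: the logits are
# equal, so softmax returns a uniform row.
finite = torch.full((4,), torch.finfo(torch.float32).min)
print(torch.softmax(finite, dim=-1))  # tensor([0.2500, 0.2500, 0.2500, 0.2500])

# All positions masked with -inf: the row has no valid key, softmax is NaN.
neg_inf = torch.full((4,), float("-inf"))
print(torch.softmax(neg_inf, dim=-1))  # tensor([nan, nan, nan, nan])
```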

