Fix inf grad_norm on Qwen3.5 at seq_len > 65536 with tighter SDPA guards#587
danielhanchen wants to merge 2 commits into main from
Conversation
When `patch_transformers_masks` wraps `create_causal_mask` with `torch.compile(dynamic=True)`, the `is_tracing` branch in `find_packed_sequence_indices` materializes a dense `[1,1,Q,K]` bool causal mask. PyTorch SDPA's Cutlass backend with bf16 + a bool mask at `seq_len > 65536` produces wrong outputs and gradients (int16 overflow). This PR wraps `sdpa_attention_forward` to detect materialized bool causal masks and replace them with `attention_mask=None, is_causal=True`. Unlike the original approach, this version includes tighter guards:

- Check `module.is_causal` to protect BERT/bidirectional encoders
- Check the `sliding_window` kwarg to protect Gemma2/Mistral/Qwen2
- Spot-check the upper triangle to distinguish pure causal from packed masks
- Convert non-pure-causal bool masks to a float additive bias as a fallback

Fixes unslothai/unsloth#4906
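A minimal sketch of the guarded dispatch described above, assuming PyTorch's `F.scaled_dot_product_attention` API. The function and parameter names (`guarded_sdpa`, `module_is_causal`, `sliding_window`) are illustrative, not the PR's actual code; the fallback here uses `-inf` rather than a finite minimum so fully masked rows keep boolean-mask semantics:

```python
import torch
import torch.nn.functional as F

def guarded_sdpa(query, key, value, attention_mask, *,
                 module_is_causal=True, sliding_window=None):
    """Illustrative sketch: drop a materialized pure-causal bool mask so SDPA
    can take its fused is_causal path instead of the bf16 + bool-mask path."""
    m = attention_mask
    if (m is not None and m.dtype == torch.bool and m.dim() == 4
            and module_is_causal          # guard: skip bidirectional encoders
            and sliding_window is None):  # guard: skip sliding-window models
        S = m.shape[-1]
        # O(1) spot-checks: a pure lower-triangular causal mask has
        # False at (0, 1) and True at (-1, 0).
        pure_causal = S < 2 or (not m[0, 0, 0, 1].item()
                                and m[0, 0, -1, 0].item())
        if pure_causal:
            return F.scaled_dot_product_attention(
                query, key, value, attn_mask=None, is_causal=True)
        # Fallback: bool -> float additive bias avoids the bool-mask kernel.
        bias = torch.where(m, 0.0, float("-inf")).to(query.dtype)
        return F.scaled_dot_product_attention(query, key, value, attn_mask=bias)
    return F.scaled_dot_product_attention(query, key, value, attn_mask=m)
```

On small fp32 CPU tensors, both branches match the reference SDPA outputs for the corresponding mask, which is a cheap way to sanity-check the dispatch logic.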
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 37d5da764e
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
```python
# (first query cannot see second key). Packed-sequence masks and other
# non-trivial patterns will have True in the upper triangle.
S = m.shape[-1]
is_pure_causal = (S < 2) or (not m[0, 0, 0, 1].item())
```
Do not classify packed masks via a single upper-triangle check
is_pure_causal is inferred from only m[0,0,0,1], but packed-sequence and padding-constrained causal masks also have this entry set to False due to the base causal rule (kv_idx > q_idx is always masked). In those cases this branch misclassifies the mask as “pure causal” and drops it, so SDPA runs with attention_mask=None, is_causal=True and loses sequence/padding boundaries. That can silently let tokens attend across packed samples (or into padded tokens) and corrupt training/inference whenever packed or left-padded batches are used.
Warning: Gemini encountered an error creating the review. You can try again by commenting.
The single-cell check `m[0,0,0,1] == False` is insufficient: packed masks like `[[1,0,0,0],[1,1,0,0],[0,0,1,0],[0,0,1,1]]` also satisfy it. Add `m[0,0,-1,0] == True` (last query can see first key), which is true for pure lower-triangular causal masks but false for packed-sequence masks, padded masks, and sliding-window masks. Both checks are O(1).
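The combined two-cell check can be sketched as follows (the helper name `looks_pure_causal` is illustrative):

```python
import torch

def looks_pure_causal(m: torch.Tensor) -> bool:
    """Two O(1) spot-checks on a [1,1,Q,K] bool mask: a pure lower-triangular
    causal mask has False at (0, 1) and True at (-1, 0); packed, padded, and
    sliding-window masks fail at least one of the two."""
    S = m.shape[-1]
    if S < 2:
        return True
    return (not m[0, 0, 0, 1].item()) and m[0, 0, -1, 0].item()

pure = torch.tril(torch.ones(4, 4, dtype=torch.bool))[None, None]
packed = torch.tensor([[1, 0, 0, 0],
                       [1, 1, 0, 0],
                       [0, 0, 1, 0],
                       [0, 0, 1, 1]], dtype=torch.bool)[None, None]

print(looks_pure_causal(pure), looks_pure_causal(packed))  # True False
```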
Updated with a second commit to address the packed-sequence mask regression identified by review.
What changed: The single-cell spot-check
Verification:
💡 Codex Review
Reviewed commit: f54a06b00b
```python
# Convert bool -> float additive bias to avoid the Cutlass bool-mask
# bug while preserving mask semantics. SDPA dispatches to a different
# (working) kernel for float attn_mask inputs.
m_float = torch.where(m, 0.0, torch.finfo(query.dtype).min).to(query.dtype)
```
Use -inf when converting bool masks to additive bias
The bool→float fallback uses torch.finfo(query.dtype).min, which is finite, so rows that are fully masked (a common case for left-padded causal masks) no longer behave like the original boolean mask. In SDPA, an all-finite masked row becomes an all-equal logit row and produces a uniform attention mix instead of zero output, so this path does not preserve mask semantics for padded/custom masks and can change hidden states whenever is_pure_causal is false.
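The reviewer's point about fully masked rows can be seen directly in the softmax (standalone sketch, not the PR's code):

```python
import torch

logits = torch.randn(4)

# Additive bias with a finite "min" value: every logit gets the same huge
# negative offset, the differences are absorbed by float rounding, and
# softmax degenerates to a uniform mix over keys that should be masked.
finite_row = torch.softmax(logits + torch.finfo(torch.float32).min, dim=-1)

# With -inf the row degenerates to NaN (0/0), matching "no valid key".
neg_inf_row = torch.softmax(logits + float("-inf"), dim=-1)

print(finite_row)   # uniform, 0.25 each
print(neg_inf_row)  # all nan
```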
Summary
- `grad_norm=inf` at `seq_len > 65536` when flash-attn is not installed
- Wrap `sdpa_attention_forward` to detect materialized bool causal masks and replace them with `attention_mask=None, is_causal=True`

Root cause
`patch_transformers_masks` wraps `create_causal_mask` with `torch.compile(dynamic=True)`. Under tracing, `find_packed_sequence_indices` takes the `is_tracing` branch, forcing `allow_is_causal_skip=False`. This materializes a dense `[1,1,Q,K]` bool causal mask. PyTorch SDPA's Cutlass backend with bf16 + a bool mask at `seq_len > 65536` produces wrong outputs/gradients (int16 sequence-index overflow at 2^16).

Guards (all O(1), no CUDA ops)
- `module.is_causal` check -- protects BERT/bidirectional encoders
- `sliding_window` kwarg check -- protects Gemma2/Mistral/Qwen2/3
- `layer_type` not in dict
- `m[0,0,0,1]==False` -- distinguishes pure causal from packed-sequence masks

Test results (NVIDIA B200, torch 2.9.1, transformers 5.5.1, no flash-attn)
Qwen3.5-4B at seq_len=69632, 4-bit LoRA, 9 steps:
Llama-3.2-1B regression at seq_len=2048, 21 steps:
Test plan
- `grad_norm=inf` at seq_len=69632 without the fix