fix(qwen3): build causal mask batch-independently (#3582) #3586

Open
pjdurden wants to merge 2 commits into huggingface:main from pjdurden:fix/qwen3-causal-mask-batch-gt-1

pjdurden commented Jun 6, 2026

Motivation

Closes #3582.

Qwen3 produces incorrect output for batch size > 1. Model::causal_mask builds
the additive mask buffer with tgt * (tgt + offset) elements — independent of the
batch — but then shapes it (b, 1, tgt, tgt + offset). Tensor::from_slice with a
fully-specified shape doesn't validate the element count (the blanket
ShapeWithOneHole impl ignores it), so for b > 1 the tensor claims b× the
elements that actually exist in storage. Batch row 0 reads the correct mask; every
row after it reads past the buffer and is masked incorrectly. The mask only feeds
the standard matmul attention path via scores.broadcast_add(m), so this surfaces
there (e.g. Metal) and disappears at b = 1 — matching the report.
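
To make the size mismatch concrete, here is a minimal sketch in plain Rust (no candle types) that compares how many values the mask buffer holds with how many the `(b, 1, tgt, tgt + offset)` shape implies. The helper names are hypothetical and exist only for this illustration; they are not part of the PR.

```rust
/// Values actually stored: one `tgt x (tgt + offset)` grid, independent of batch.
fn buffer_len(tgt: usize, offset: usize) -> usize {
    tgt * (tgt + offset)
}

/// Values implied by the shape `(b, 1, tgt, tgt + offset)`.
/// The singleton head dimension contributes a factor of 1.
fn claimed_len(b: usize, tgt: usize, offset: usize) -> usize {
    b * tgt * (tgt + offset)
}

/// How many values the shape claims beyond what storage holds.
/// Zero when `b <= 1`, which is why the bug is invisible at batch size 1.
fn overclaim(b: usize, tgt: usize, offset: usize) -> usize {
    claimed_len(b, tgt, offset).saturating_sub(buffer_len(tgt, offset))
}
```

For example, with `tgt = 4`, `offset = 2` and `b = 2`, the buffer holds 24 values while the shape implies 48, so the second batch row has no backing data.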

Modifications

  • Extracted the mask construction into a build_causal_mask free function that
    shapes the mask (1, 1, tgt, tgt + offset) and lets the existing
    broadcast_add apply it across the batch. This is correct for any b and also
    drops the redundant per-batch allocation.
  • Model::forward now calls build_causal_mask directly; the b it used to pass
    into the mask is no longer needed.
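
The idea behind the change can be sketched without candle: build one mask of logical shape `(1, 1, tgt, tgt + offset)` and add it to every batch row. The code below is an illustration under stated assumptions, not the PR's `build_causal_mask`: it uses flat `Vec<f32>` buffers, assumes the additive convention of `0.0` for allowed positions and negative infinity for masked ones, omits the head dimension and the sliding window, and both function names are hypothetical.

```rust
/// Build a flat, row-major additive causal mask with `tgt * (tgt + offset)`
/// values. Query position `i` may attend to key positions `j <= i + offset`.
/// Assumed convention: 0.0 = allowed, -inf = masked.
fn sketch_causal_mask(tgt: usize, offset: usize) -> Vec<f32> {
    (0..tgt)
        .flat_map(|i| {
            (0..tgt + offset).map(move |j| {
                if j <= i + offset {
                    0.0
                } else {
                    f32::NEG_INFINITY
                }
            })
        })
        .collect()
}

/// Add the single mask to every batch row of `scores`, which is laid out as
/// consecutive blocks of `mask.len()` values (one block per batch row).
fn broadcast_add_mask(scores: &mut [f32], mask: &[f32]) {
    if mask.is_empty() {
        return; // nothing to add; also avoids a zero-sized chunk
    }
    for row in scores.chunks_mut(mask.len()) {
        for (s, m) in row.iter_mut().zip(mask) {
            *s += *m;
        }
    }
}
```

Because the same buffer is reused for each row, every batch row receives an identical mask regardless of `b`, and only one `tgt * (tgt + offset)` allocation is needed.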

Tests

  • Added regression test causal_mask_is_batch_independent_and_broadcasts,
    asserting the mask carries a leading batch dim of 1 and, broadcast onto a
    2-sequence batch, produces an identical causal mask for both rows. It fails on
    the old (b, 1, …) shape and passes after the fix.
  • cargo test -p candle-transformers --lib passes (14 tests).
  • cargo fmt --check and cargo clippy are clean.

`Model::causal_mask` built the additive mask buffer with `tgt*(tgt+offset)`
elements (independent of the batch) but shaped it `(b, 1, tgt, tgt+offset)`.
For `b > 1` the tensor claims b× the elements actually present, so every batch
row but the first reads past the buffer and is masked incorrectly, producing
wrong output for batched forwards (the bug disappears at b=1). This is hit on
the standard matmul attention path (e.g. Metal), as reported in huggingface#3582.

Extract the mask construction into a `build_causal_mask` free function that
shapes the mask `(1, 1, tgt, tgt+offset)` and relies on the existing
`broadcast_add` to apply it across the batch. Add a CPU regression test
asserting the mask is batch-independent and broadcasts to a causal,
per-row-identical mask.

The same `(b, 1, tgt, ...)` pattern exists in several sibling models
(qwen3_moe, quantized_qwen3{,_moe}, glm4_new, quantized_glm4, smollm3,
z_image/text_encoder); happy to fix those in this PR or a follow-up.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Copilot AI review requested due to automatic review settings June 6, 2026 23:31

pjdurden commented Jun 6, 2026

While fixing this I noticed the same (b, 1, tgt, …) pattern — a batch-independent mask buffer shaped with the batch dimension — in several other models: qwen3_moe, quantized_qwen3, quantized_qwen3_moe, glm4_new, quantized_glm4, smol/smollm3, and z_image/text_encoder. They look like the same copy-pasted bug and would hit the same way for batch size > 1.

I kept this PR scoped to qwen3 (the reported model) for easy review, but I'm happy to fix the rest — either folded into this PR or as a follow-up. Just let me know which you'd prefer.

Copilot AI left a comment

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

Fixes a batch-dependent corruption in Qwen3’s additive causal attention mask by constructing a batch-independent mask and relying on broadcasting, with a regression test to prevent reintroduction.

Changes:

  • Extracted causal mask construction into a batch-independent build_causal_mask helper (shape (1, 1, tgt, tgt + offset)).
  • Updated Model::forward to use the new helper and avoid building a (b, ...)-shaped mask.
  • Added a regression test validating correct broadcasting across batch.

Comment on lines +424 to +429
    (0..(tgt + offset)).map(move |j| {
        let past_ok = j <= i + offset;
        let sw_ok = match sw {
            Some(w) => (i + offset) as i64 - j as i64 <= w as i64,
            None => true,
        };

Author

Switched to j + w >= i + offset — equivalent to the original check but stays in usize, so no signed casts and no subtraction underflow when j > i + offset. Also added a sliding-window regression test; the existing tests only covered the no-window path.
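
The equivalence follows from moving `j` to the other side of the inequality: `(i + offset) - j <= w` holds exactly when `i + offset <= j + w`. As a quick sanity check, the sketch below compares both forms exhaustively over a small grid. The function names are hypothetical, and the comparison ignores overflow for values near `usize::MAX`, which are not realistic sequence lengths.

```rust
/// Original form: casts to i64 so the difference is allowed to go negative.
fn sw_ok_signed(i: usize, j: usize, offset: usize, w: usize) -> bool {
    (i + offset) as i64 - j as i64 <= w as i64
}

/// Rearranged form: stays in usize and performs no subtraction.
fn sw_ok_unsigned(i: usize, j: usize, offset: usize, w: usize) -> bool {
    j + w >= i + offset
}

/// Returns true if both forms agree for all i, j, offset, w in `0..limit`.
fn forms_agree(limit: usize) -> bool {
    (0..limit).all(|i| {
        (0..limit).all(|j| {
            (0..limit).all(|offset| {
                (0..limit).all(|w| {
                    sw_ok_signed(i, j, offset, w) == sw_ok_unsigned(i, j, offset, w)
                })
            })
        })
    })
}
```

The grid includes cases with `j > i + offset`, where a naive `usize` subtraction would underflow; the rearranged form handles them with plain addition.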

Address review feedback on huggingface#3586: the sliding-window check computed
`(i + offset) as i64 - j as i64 <= w as i64`, casting to signed to allow a
negative result when `j > i + offset`. Rearranged to `j + w >= i + offset` —
equivalent, but stays in `usize` with no signed casts and no subtraction
underflow. Add a sliding-window regression test; the prior tests only covered
the no-window path.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@pjdurden

Author

@ivarflakstad gentle nudge when you get a chance. Small qwen3 causal mask fix for batch greater than 1, with a regression test.

Development

Successfully merging this pull request may close these issues.

Qwen3: causal_mask shaped (b, 1, tgt, tgt) from a b-independent buffer → wrong output for batch size > 1