fix(qwen3): build causal mask batch-independently (#3582)#3586
fix(qwen3): build causal mask batch-independently (#3582)#3586pjdurden wants to merge 2 commits into
Conversation
`Model::causal_mask` built the additive mask buffer with `tgt*(tgt+offset)` elements (independent of the batch) but shaped it `(b, 1, tgt, tgt+offset)`. For `b > 1` the tensor claims b× the elements actually present, so every batch row but the first reads past the buffer and is masked incorrectly, producing wrong output for batched forwards (the bug disappears at b=1). This is hit on the standard matmul attention path (e.g. Metal), as reported in huggingface#3582. Extract the mask construction into a `build_causal_mask` free function that shapes the mask `(1, 1, tgt, tgt+offset)` and relies on the existing `broadcast_add` to apply it across the batch. Add a CPU regression test asserting the mask is batch-independent and broadcasts to a causal, per-row-identical mask. The same `(b, 1, tgt, ...)` pattern exists in several sibling models (qwen3_moe, quantized_qwen3{,_moe}, glm4_new, quantized_glm4, smollm3, z_image/text_encoder); happy to fix those in this PR or a follow-up. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
|
While fixing this I noticed the same I kept this PR scoped to |
There was a problem hiding this comment.
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
Fixes a batch-dependent corruption in Qwen3’s additive causal attention mask by constructing a batch-independent mask and relying on broadcasting, with a regression test to prevent reintroduction.
Changes:
- Extracted causal mask construction into a batch-independent
build_causal_maskhelper (shape(1, 1, tgt, tgt + offset)). - Updated
Model::forwardto use the new helper and avoid building a(b, ...)-shaped mask. - Added a regression test validating correct broadcasting across batch.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| (0..(tgt + offset)).map(move |j| { | ||
| let past_ok = j <= i + offset; | ||
| let sw_ok = match sw { | ||
| Some(w) => (i + offset) as i64 - j as i64 <= w as i64, | ||
| None => true, | ||
| }; |
There was a problem hiding this comment.
Switched to j + w >= i + offset — equivalent to the original check but stays in usize, so no signed casts and no subtraction underflow when j > i + offset. Also added a sliding-window regression test; the existing tests only covered the no-window path.
Address review feedback on huggingface#3586: the sliding-window check computed `(i + offset) as i64 - j as i64 <= w as i64`, casting to signed to allow a negative result when `j > i + offset`. Rearranged to `j + w >= i + offset` — equivalent, but stays in `usize` with no signed casts and no subtraction underflow. Add a sliding-window regression test; the prior tests only covered the no-window path. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
|
@ivarflakstad gentle nudge when you get a chance. Small qwen3 causal mask fix for batch greater than 1, with a regression test. |
Motivation
Closes #3582.
Qwen3produces incorrect output for batch size > 1.Model::causal_maskbuildsthe additive mask buffer with
tgt * (tgt + offset)elements — independent of thebatch — but then shapes it
(b, 1, tgt, tgt + offset).Tensor::from_slicewith afully-specified shape doesn't validate the element count (the blanket
ShapeWithOneHoleimpl ignores it), so forb > 1the tensor claimsb×theelements that actually exist in storage. Batch row 0 reads the correct mask; every
row after it reads past the buffer and is masked incorrectly. The mask only feeds
the standard matmul attention path via
scores.broadcast_add(m), so this surfacesthere (e.g. Metal) and disappears at
b = 1— matching the report.Modifications
build_causal_maskfree function thatshapes the mask
(1, 1, tgt, tgt + offset)and lets the existingbroadcast_addapply it across the batch. This is correct for anyband alsodrops the redundant per-batch allocation.
Model::forwardnow callsbuild_causal_maskdirectly; thebit used to passinto the mask is no longer needed.
Tests
causal_mask_is_batch_independent_and_broadcasts,asserting the mask carries a leading batch dim of 1 and, broadcast onto a
2-sequence batch, produces an identical causal mask for both rows. It fails on
the old
(b, 1, …)shape and passes after the fix.cargo test -p candle-transformers --libpasses (14 tests).cargo fmt --checkandcargo clippyare clean.