Fix Qwen3 causal mask for batch size > 1 #3610

Open
NahButch wants to merge 1 commit into huggingface:main from NahButch:qwen3-batch-mask

@NahButch

Model::causal_mask filled tgt * (tgt + offset) elements but declared the shape (b, 1, tgt, tgt + offset). As a result, only batch row 0 carried a valid mask, and rows >= 1 attended over a corrupt, effectively non-causal pattern: two identical sequences in one batch produced different outputs (row 0 correct, row 1 garbage).

Build the mask as (1, 1, tgt, tgt + offset) and expand it across the batch (a zero-copy view). Adds a unit test asserting every batch row gets the correct causal pattern; it fails on the previous behavior.
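
For reviewers who want the idea without the tensor library, here is a minimal stdlib-only sketch of the approach described above. It is not the code in this PR: `build_causal_mask` and `ExpandedMask` are hypothetical names, and the convention of 0.0 for allowed positions and negative infinity for masked ones is an assumption. The buffer holds exactly tgt * (tgt + offset) elements, and the batch dimension is modelled as a stride of 0, so every batch row reads the same data without a copy.

```rust
/// Hypothetical helper (not the PR's code): row-major causal mask with
/// logical shape (tgt, tgt + offset). Assumed convention: 0.0 marks an
/// allowed position, -inf marks a masked (future) position.
fn build_causal_mask(tgt: usize, offset: usize) -> Vec<f32> {
    let src = tgt + offset;
    (0..tgt)
        .flat_map(|i| {
            // Query i sits at absolute position i + offset and may attend
            // to every key position j up to and including its own.
            (0..src).map(move |j| if j <= i + offset { 0.0 } else { f32::NEG_INFINITY })
        })
        .collect()
}

/// Hypothetical stand-in for an expanded view with logical shape
/// (batch, 1, tgt, src) backed by a single (tgt, src) buffer.
struct ExpandedMask {
    data: Vec<f32>,
    batch: usize,
    tgt: usize,
    src: usize,
}

impl ExpandedMask {
    fn new(batch: usize, tgt: usize, offset: usize) -> Self {
        Self {
            data: build_causal_mask(tgt, offset),
            batch,
            tgt,
            src: tgt + offset,
        }
    }

    /// The batch stride is 0: the batch index is bounds-checked but does
    /// not contribute to the offset into the shared buffer.
    fn get(&self, b: usize, i: usize, j: usize) -> f32 {
        assert!(b < self.batch && i < self.tgt && j < self.src, "index out of bounds");
        self.data[i * self.src + j]
    }
}
```

With `ExpandedMask::new(3, 2, 1)`, `get(b, 0, 2)` is masked and `get(b, 1, 2)` is allowed for every `b` in 0..3, which is the per-row property the new unit test is described as asserting.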

Related to #3609 (Qwen2 mask fix); this one is the batch-dimension variant in Qwen3.

Fixes #3582

🤖 Generated with Claude Code

causal_mask filled tgt * (tgt + offset) elements but declared the shape
(b, 1, tgt, tgt + offset), so only batch row 0 carried a valid mask and
rows >= 1 attended over a corrupt, effectively non-causal pattern.

Build the mask as (1, 1, tgt, tgt + offset) and expand it across the
batch (a zero-copy view). Adds a unit test asserting every batch row
gets the correct causal pattern; it fails on the previous behavior.

Fixes huggingface#3582

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Development

Successfully merging this pull request may close these issues.

Qwen3: causal_mask shaped (b, 1, tgt, tgt) from a b-independent buffer → wrong output for batch size > 1