Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 84 additions & 29 deletions candle-transformers/src/models/qwen3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,41 @@ impl DecoderLayer {
}
}

/// Builds the additive causal attention mask of shape `(1, 1, tgt, tgt + offset)`.
///
/// The mask values depend only on the query/key positions, never on the batch,
/// so it is built with a leading batch dim of `1` and broadcast across the batch
/// by `broadcast_add` in attention. Shaping this `b`-independent buffer as
/// `(b, 1, tgt, tgt + offset)` claims `b×` the elements actually present, so every
/// batch row but the first reads past the buffer and is masked incorrectly
/// (see https://github.com/huggingface/candle/issues/3582).
fn build_causal_mask(
tgt: usize,
offset: usize,
sw: Option<usize>,
device: &Device,
dtype: DType,
) -> Result<Tensor> {
let minf = f32::NEG_INFINITY;
let mask: Vec<_> = (0..tgt)
.flat_map(|i| {
(0..(tgt + offset)).map(move |j| {
let past_ok = j <= i + offset;
let sw_ok = match sw {
Some(w) => (i + offset) as i64 - j as i64 <= w as i64,
None => true,
};
Comment on lines +424 to +432

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Switched to j + w >= i + offset — equivalent to the original check but stays in usize, so no signed casts and no subtraction underflow when j > i + offset. Also added a sliding-window regression test; the existing tests only covered the no-window path.

if past_ok && sw_ok {
0.
} else {
minf
}
})
})
.collect();
Tensor::from_slice(&mask, (1, 1, tgt, tgt + offset), device)?.to_dtype(dtype)
}

#[derive(Debug, Clone)]
pub struct Model {
embed_tokens: candle_nn::Embedding,
Expand Down Expand Up @@ -437,35 +472,8 @@ impl Model {
}
}

/// Builds the additive causal attention mask for a batch of `b` sequences.
///
/// Entry `(i, j)` is `0` when key position `j` may be attended to from query
/// row `i` (absolute position `i + offset`): `j <= i + offset` and, when `sw`
/// is `Some(w)`, the key is at most `w` positions behind the query. Every
/// other entry is `-inf`.
///
/// The values depend only on the query/key positions, so the buffer holds
/// `tgt * (tgt + offset)` elements regardless of `b`. It is therefore wrapped
/// with a leading batch dim of `1` and then broadcast to
/// `(b, 1, tgt, tgt + offset)`, instead of labelling the `b`-independent
/// buffer with a shape that claims `b` times more elements than it contains.
///
/// # Errors
/// Propagates any error from tensor creation, dtype conversion or broadcasting.
fn causal_mask(
    &self,
    b: usize,
    tgt: usize,
    offset: usize,
    sw: Option<usize>,
) -> Result<Tensor> {
    let minf = f32::NEG_INFINITY;
    let kv_len = tgt + offset;
    let mask: Vec<f32> = (0..tgt)
        .flat_map(|i| {
            (0..kv_len).map(move |j| {
                let past_ok = j <= i + offset;
                let sw_ok = match sw {
                    Some(w) => (i + offset) as i64 - j as i64 <= w as i64,
                    None => true,
                };
                if past_ok && sw_ok {
                    0.
                } else {
                    minf
                }
            })
        })
        .collect();
    // One (tgt, kv_len) plane is all the data there is; expose it to every
    // batch row through a broadcast view rather than a mismatched shape.
    Tensor::from_slice(&mask, (1, 1, tgt, kv_len), &self.device)?
        .to_dtype(self.dtype)?
        .broadcast_as((b, 1, tgt, kv_len))
}

pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
let (b, l) = input.dims2()?;
let (_b, l) = input.dims2()?;
let mut h = self.embed_tokens.forward(input)?;

// Build causal mask only for the standard attention fallback path.
Expand All @@ -475,7 +483,13 @@ impl Model {
#[cfg(feature = "flash-attn")]
let needs_mask = self.device.is_cpu() && l > 1;
let causal = if needs_mask {
Some(self.causal_mask(b, l, offset, None)?)
Some(build_causal_mask(
l,
offset,
None,
&self.device,
self.dtype,
)?)
} else {
None
};
Expand Down Expand Up @@ -516,3 +530,44 @@ impl ModelForCausalLM {
self.base.clear_kv_cache();
}
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Regression test for https://github.com/huggingface/candle/issues/3582.
    ///
    /// The additive causal mask does not depend on the batch dimension, so it
    /// has to be created with a leading batch dim of 1 and broadcast over the
    /// batch. Giving a `b`-independent buffer the shape `(b, 1, tgt, tgt + offset)`
    /// claims `b×` the elements actually present, corrupting every batch row
    /// but the first.
    #[test]
    fn causal_mask_is_batch_independent_and_broadcasts() {
        let cpu = Device::Cpu;
        let blocked = f32::NEG_INFINITY;

        let causal = build_causal_mask(3, 0, None, &cpu, DType::F32).unwrap();
        assert_eq!(
            causal.dims(),
            &[1, 1, 3, 3],
            "mask must carry a leading batch dim of 1 so broadcast_add applies it to every row",
        );

        // Add the mask onto an all-zero score tensor for a batch of two sequences:
        // both rows must end up with the same causal pattern.
        let zero_scores = Tensor::zeros((2, 1, 3, 3), DType::F32, &cpu).unwrap();
        let summed = zero_scores.broadcast_add(&causal).unwrap();
        let per_batch = summed.squeeze(1).unwrap().to_vec3::<f32>().unwrap();

        assert_eq!(
            per_batch[0], per_batch[1],
            "both batch rows must receive the same causal mask",
        );

        let first = &per_batch[0];
        assert_eq!(first[0], vec![0.0, blocked, blocked]);
        assert_eq!(first[1], vec![0.0, 0.0, blocked]);
        assert_eq!(first[2], vec![0.0, 0.0, 0.0]);
    }
}