Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions candle-flash-attn-v3/hkernel/flash_api.cu
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ extern "C" void run_mha_v3(
void *v_ptr,
void *o_ptr,
void *softmax_lse_ptr,
void *tile_count_semaphore_ptr,
void *alibi_slopes_ptr,

int32_t *cu_seqlens_q_ptr,
Expand Down Expand Up @@ -260,6 +261,11 @@ extern "C" void run_mha_v3(
params.o_ptr = o_ptr;

params.softmax_lse_ptr = softmax_lse_ptr;
// Global tile counter for the DynamicPersistentTileScheduler, which is selected for
// the causal/local (non-varlen) path and does an atomicAdd on this pointer. The
// caller passes a zero-initialized int32; without it the pointer stays NULL after
// the memset above and the causal kernel performs an illegal global atomic.
params.tile_count_semaphore = static_cast<int *>(tile_count_semaphore_ptr);
params.alibi_slopes_ptr = alibi_slopes_ptr;

// All strides are in elements, not bytes.
Expand Down
4 changes: 4 additions & 0 deletions candle-flash-attn-v3/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ extern "C" {
v_ptr: *const c_void,
o_ptr: *const c_void,
softmax_lse_ptr: *const c_void,
// Zero-initialized int32 global counter for the DynamicPersistentTileScheduler
// (causal/local non-varlen path). Leaving it NULL makes the causal kernel fail
// with an illegal global atomic.
tile_count_semaphore_ptr: *const c_void,
alibi_slopes_ptr: *const c_void,

cu_seqlens_q_ptr: *const i32,
Expand Down
32 changes: 25 additions & 7 deletions candle-flash-attn-v3/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,18 @@ impl FlashAttn {

let elem_count = out_shape.elem_count();
let mut dst = unsafe { dev.alloc::<T>(elem_count) }?;
let mut softmax_lse = dev.alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q)?;
// The dense-path LSE layout is [b, nheads, seqlen_q] (padded, see flash.h), so
// b*nheads*seqlen_q_rounded is sufficient; seqlen_q_rounded guards partial-tile
// epilogue writes. The previous b*128*nheads*seqlen_q allocation was 128x too
// large, turning every forward into a multi-GB cudaMalloc+memset (e.g. 4.3GB at
// batch=128, seqlen=2048, 32 heads) that dominated the per-call host overhead.
let mut softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q_rounded)?;
// Zero-initialized global tile counter for the DynamicPersistentTileScheduler,
// which is selected for the causal/local path and does an atomicAdd on
// params.tile_count_semaphore. Without this allocation the pointer stays NULL
// (params are memset to 0 in run_mha_v3) and causal=true fails with an illegal
// global atomic.
let mut tile_count_semaphore = dev.alloc_zeros::<i32>(1)?;

let is_bf16 = if is_bf16 { 1 } else { 0 };

Expand All @@ -188,12 +199,14 @@ impl FlashAttn {
let (v_ptr, _guard) = v.device_ptr(&stream);
let (dst_ptr, _guard) = dst.device_ptr_mut(&stream);
let (softmax_lse_ptr, _guard) = softmax_lse.device_ptr_mut(&stream);
let (tile_count_semaphore_ptr, _guard) = tile_count_semaphore.device_ptr_mut(&stream);
ffi::run_mha_v3(
q_ptr as *const core::ffi::c_void,
k_ptr as *const core::ffi::c_void,
v_ptr as *const core::ffi::c_void,
dst_ptr as *const core::ffi::c_void,
softmax_lse_ptr as *const core::ffi::c_void,
tile_count_semaphore_ptr as *const core::ffi::c_void,
/* alibi_slopes_ptr */ alibi_slopes_ptr,
/* cu_seqlens_q_ptr */ std::ptr::null(),
/* cu_seqlens_k_ptr */ std::ptr::null(),
Expand Down Expand Up @@ -583,24 +596,23 @@ impl FlashAttnVarLen {
};

// if window_size_left > self.max_seqlen_k or None => -1
// Keep the raw window sizes here (like the dense path): is_causal is derived
// below from window_size_right == 0, and only the unset (-1) side is extended
// to max_seqlen_k after that. The previous unconditional clamp to max_seqlen_k
// clobbered the causal signal (window_size_right 0 -> max_seqlen_k), so
// flash_attn_varlen(..., causal=true) silently ran full non-causal attention.
let mut window_size_left = self
.window_size_left
.filter(|v| v <= &self.max_seqlen_k)
.map(|v| v as i32)
.unwrap_or(-1);
if window_size_left < self.max_seqlen_k as i32 {
window_size_left = self.max_seqlen_k.clone() as i32;
}

// if window_size_right > self.max_seqlen_k or None => -1
let mut window_size_right = self
.window_size_right
.filter(|v| v <= &self.max_seqlen_k)
.map(|v| v as i32)
.unwrap_or(-1);
if window_size_right < self.max_seqlen_k as i32 {
window_size_right = self.max_seqlen_k.clone() as i32;
}

let head_size = round_multiple(head_size_og, 8);
let head_size_rounded = round_multiple(head_size, 32);
Expand All @@ -610,6 +622,10 @@ impl FlashAttnVarLen {
let elem_count = out_shape.elem_count();
let mut dst = unsafe { dev.alloc::<T>(elem_count) }?;
let mut softmax_lse = dev.alloc_zeros::<f32>(num_heads * total_q)?;
// Zero-initialized global tile counter; see the dense path above. The varlen
// path currently selects the SingleTileScheduler (which ignores it), but the
// C entry point expects a valid pointer either way.
let mut tile_count_semaphore = dev.alloc_zeros::<i32>(1)?;

let is_bf16 = if is_bf16 { 1 } else { 0 };

Expand All @@ -632,6 +648,7 @@ impl FlashAttnVarLen {
let (v_ptr, _guard) = v.device_ptr(&stream);
let (dst_ptr, _guard) = dst.device_ptr_mut(&stream);
let (softmax_lse_ptr, _guard) = softmax_lse.device_ptr_mut(&stream);
let (tile_count_semaphore_ptr, _guard) = tile_count_semaphore.device_ptr_mut(&stream);
let (seqlens_q_ptr, _guard) = seqlens_q.device_ptr(&stream);
let (seqlens_k_ptr, _guard) = seqlens_k.device_ptr(&stream);
ffi::run_mha_v3(
Expand All @@ -640,6 +657,7 @@ impl FlashAttnVarLen {
v_ptr as *const core::ffi::c_void,
dst_ptr as *const core::ffi::c_void,
softmax_lse_ptr as *const core::ffi::c_void,
tile_count_semaphore_ptr as *const core::ffi::c_void,
/* alibi_slopes_ptr */ alibi_slopes_ptr,
/* cu_seqlens_q_ptr */ seqlens_q_ptr as *const i32,
/* cu_seqlens_k_ptr */ seqlens_k_ptr as *const i32,
Expand Down
106 changes: 106 additions & 0 deletions candle-flash-attn-v3/tests/flash_attn_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,34 @@ fn fa_acausal(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result<
Ok(output)
}

/// Unfused reference implementation of causal attention, used to cross-check
/// the flash-attention kernels in the tests.
///
/// All inputs are promoted to f32; the scores `q @ k^T * softmax_scale` are
/// offset by an additive causal mask, soft-maxed over the last dimension and
/// multiplied by `v`. The result is cast back to the dtype `q` had on entry.
///
/// The mask is anchored at the bottom-right corner (the convention the flash
/// attention kernels are expected to follow): query row `i` may attend to key
/// columns `0..=(i + seq_k - seq_q)`; every other position gets `-inf`.
fn fa_causal(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> Result<Tensor> {
    let out_dtype = q.dtype();
    let rows = q.dim(D::Minus2)?;
    let cols = k.dim(D::Minus2)?;

    let q = q.to_dtype(DType::F32)?;
    let k = k.to_dtype(DType::F32)?;
    let v = v.to_dtype(DType::F32)?;

    let scores = (q.matmul(&k.t()?)? * softmax_scale as f64)?;

    // Row-major (rows x cols) additive mask: 0 where the key is visible to the
    // query, -inf where it lies past the bottom-right-aligned diagonal.
    let mut mask_values: Vec<f32> = Vec::with_capacity(rows * cols);
    for i in 0..rows {
        for j in 0..cols {
            let visible = j + rows <= cols + i;
            mask_values.push(if visible { 0.0 } else { f32::NEG_INFINITY });
        }
    }
    let mask = Tensor::from_vec(mask_values, (rows, cols), q.device())?;

    let masked_scores = scores.broadcast_add(&mask)?;
    let probs = candle_nn::ops::softmax(&masked_scores, D::Minus1)?;
    let weighted = probs.matmul(&v.contiguous()?)?;
    Ok(weighted.to_dtype(out_dtype)?)
}

#[test]
fn flash_attn_acausal() -> Result<()> {
let device = Device::new_cuda(0)?;
Expand Down Expand Up @@ -242,6 +270,38 @@ fn flash_attn_acausal_gqa() -> Result<()> {
Ok(())
}

#[rstest(
    head_dim => [64, 128, 256],
    seq_len => [2, 4, 9],
)]
fn flash_attn_causal(head_dim: usize, seq_len: usize) -> Result<()> {
    // Regression test for the NULL tile_count_semaphore. The causal (non-varlen)
    // path selects the DynamicPersistentTileScheduler, which atomicAdds on
    // params.tile_count_semaphore. The binding used to leave that pointer
    // unallocated (NULL), so every flash_attn(..., causal=true) call failed with
    // an illegal global atomic.
    let device = Device::new_cuda(0)?;

    // One shared ramp of values, scaled differently for q, k and v so the three
    // inputs are distinct.
    let ramp = Tensor::arange(0u32, (3 * seq_len * head_dim) as u32, &device)?
        .to_dtype(DType::F16)?
        .reshape((1, 3, seq_len, head_dim))?;
    let unit = (head_dim * seq_len) as f64;
    let q = (&ramp / (unit * 3.))?;
    let k = (&ramp / (unit * 4.))?;
    let v = (&ramp / (unit * 2.))?;

    // The kernel is fed the tensors with dims 1 and 2 swapped; its output is
    // swapped back before being compared against the reference.
    let ys = candle_flash_attn_v3::flash_attn(
        &q.transpose(1, 2)?,
        &k.transpose(1, 2)?,
        &v.transpose(1, 2)?,
        0.5,
        true,
        false,
    )?
    .transpose(1, 2)?
    .i(0)?
    .to_dtype(DType::F32)?;
    assert_eq!(ys.dims(), &[3, seq_len, head_dim]);

    // Unfused reference result on the same inputs.
    let expected = fa_causal(&q, &k, &v, 0.5)?.i(0)?.to_dtype(DType::F32)?;

    // Largest absolute element-wise deviation between kernel and reference.
    let diff = ys.sub(&expected)?.abs()?.flatten_all()?.max(0)?;
    assert!(diff.to_vec0::<f32>()?.abs() < 5e-3);
    Ok(())
}

#[test]
fn flash_attn_varlen() -> Result<()> {
let device = Device::new_cuda(0)?;
Expand Down Expand Up @@ -392,3 +452,49 @@ fn flash_attn_varlen_param(head_dim: usize, seq_len: usize, use_gqa_packing: boo
assert!(diff.to_vec0::<f32>()?.abs() < 5e-3);
Ok(())
}

#[rstest(
    head_dim => [64, 128, 256],
    seq_len => [2, 4, 9],
)]
fn flash_attn_varlen_causal(head_dim: usize, seq_len: usize) -> Result<()> {
    // Regression test for the window-size clamp that silently disabled causal
    // masking on the varlen path. flash_attn_varlen(..., causal=true) used to run
    // full (non-causal) attention because window_size_right=0 was overwritten
    // with max_seqlen_k before is_causal was derived from it.
    let device = Device::new_cuda(0)?;

    // One shared ramp of values, scaled differently for q, k and v so the three
    // inputs are distinct.
    let ramp = Tensor::arange(0u32, (3 * seq_len * head_dim) as u32, &device)?
        .to_dtype(DType::F16)?
        .reshape((3, seq_len, head_dim))?;
    let unit = (head_dim * seq_len) as f64;
    let q = (&ramp / (unit * 3.))?;
    let k = (&ramp / (unit * 4.))?;
    let v = (&ramp / (unit * 2.))?;

    // A single sequence of length seq_len, described by its cumulative offsets.
    let seqlens_q = Tensor::new(&[0u32, seq_len as u32], &device)?;
    let seqlens_k = Tensor::new(&[0u32, seq_len as u32], &device)?;

    // The kernel is fed the tensors with dims 0 and 1 swapped; its output is
    // swapped back before being compared against the reference.
    let ys = candle_flash_attn_v3::flash_attn_varlen(
        &q.transpose(0, 1)?,
        &k.transpose(0, 1)?,
        &v.transpose(0, 1)?,
        &seqlens_q,
        &seqlens_k,
        seq_len,
        seq_len,
        0.5,
        true,
        false,
    )?
    .transpose(0, 1)?
    .to_dtype(DType::F32)?;
    assert_eq!(ys.dims(), &[3, seq_len, head_dim]);

    // Reference implementation: add a leading batch dimension, run the unfused
    // causal attention, then drop the batch dimension again.
    let expected = fa_causal(&q.unsqueeze(0)?, &k.unsqueeze(0)?, &v.unsqueeze(0)?, 0.5)?
        .i(0)?
        .to_dtype(DType::F32)?;

    // Largest absolute element-wise deviation between kernel and reference.
    let diff = ys.sub(&expected)?.abs()?.flatten_all()?.max(0)?;
    assert!(diff.to_vec0::<f32>()?.abs() < 5e-3);
    Ok(())
}