Fix causal attention in candle-flash-attn-v3: NULL tile_count_semaphore crash (dense) + silently non-causal varlen#3606
Open
Xueying-VirtueAI wants to merge 3 commits into
Conversation
The dense path allocated softmax_lse as b*128*nheads*seqlen_q, but the dense
LSE layout is [b, nheads, seqlen_q] (see flash.h). The stray 128x factor made
every forward cudaMalloc + memset (and later free) a multi-GB scratch buffer
-- e.g. 4.3GB at batch=128, seqlen=2048, 32 heads -- which dominated the
per-call host overhead: nsys showed the kernel itself at ~6.6ms while
wall-clock was ~28ms for that shape.
Allocate b*nheads*seqlen_q_rounded instead (seqlen_q_rounded guards
partial-tile epilogue writes). Outputs are unchanged; the varlen path already
allocates the correct nheads*total_q.
Latency (H200, batch=128, GQA 32/8, headdim=128, causal, bf16), us/forward:
seqlen before after
128 1603.7 302.5
1024 12784.8 2201.9
2048 28253.9 6971.1
8192 183511.4 94019.9
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…semaphore) The causal/local (non-varlen) path selects the DynamicPersistentTileScheduler, which does an atomicAdd on params.tile_count_semaphore (tile_scheduler.hpp). The binding memsets Flash_fwd_params to 0 and never allocated that counter, so every flash_attn(..., causal=true) call performed an illegal global atomic on a NULL pointer. Allocate a zero-initialized int32 in the binding (mirroring softmax_lse) and thread it through the C entry point into params.tile_count_semaphore. The varlen path currently selects the SingleTileScheduler which ignores the counter, but it now passes a valid pointer as well. Adds a causal regression test against an eager f32 reference; existing tests only covered causal=false, which is why this never fired in the suite. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…d the causal flag)
The varlen path clamped both window sizes up to max_seqlen_k unconditionally:
if window_size_right < self.max_seqlen_k as i32 {
window_size_right = self.max_seqlen_k.clone() as i32;
}
causal=true encodes as window_size_right=0, so the clamp rewrote it to
max_seqlen_k (full attention) before is_causal was derived from it. As a
result flash_attn_varlen(..., causal=true) silently ran full NON-causal
attention and returned wrong values. The dense path has no such clamp and was
unaffected.
Drop the clamps and keep the raw window sizes, mirroring the dense path:
is_causal is derived below, and only the unset (-1) side is extended to
max_seqlen_k afterwards. Adds a varlen causal regression test against an eager
f32 reference; existing varlen tests only covered causal=false.
Note: with causal now actually reaching the kernel, varlen batches with
*different* sequence lengths in one cu_seqlens (ragged) can hang in the
SingleTileScheduler causal mainloop; single-sequence and uniform-length
batches work. That pre-existing kernel-level defect is reported separately.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This PR fixes two independent bugs in
candle-flash-attn-v3that both breakcausal=true. The existing test suite only exercisescausal=false, which is why neither was caught. Both fixes come with regression tests against an eager f32 reference.Bug 1:
flash_attn(..., causal=true)fails with an illegal global atomicThe causal/local (non-varlen) path selects the
DynamicPersistentTileScheduler(flash_fwd_launch_template.h), which doesThe binding memsets
Flash_fwd_paramsto 0 and never allocates that counter, so the pointer stays NULL and every dense causal call performs an illegal global atomic (CUDA illegal memory access / unspecified launch failure).Fix: allocate a zero-initialized
int32in the binding (mirroringsoftmax_lse) and thread it throughrun_mha_v3intoparams.tile_count_semaphore. The varlen path currently selects theSingleTileScheduler, which ignores the counter, but it now passes a valid pointer as well.Bug 2:
flash_attn_varlen(..., causal=true)silently runs full non-causal attentionThe varlen path (and only the varlen path) clamped both window sizes up to
max_seqlen_kunconditionally:causal=trueencodes aswindow_size_right = 0, so the clamp rewrote it tomax_seqlen_k(full attention) beforeis_causalis derived fromwindow_size_right == 0. The kernel then ran full non-causal attention and returned wrong values — no error, no warning. The dense path has no such clamp and derivesis_causalcorrectly.Fix: drop the clamps and keep the raw window sizes, exactly mirroring the dense path:
is_causalis derived first, and only the unset (-1) side is extended tomax_seqlen_kafterwards.Verification
On unpatched
main(3d3d9c4; H200, sm90a, CUDA 12.9), running the new tests only:flash_attn_causal: 9/9 fail. WithCUDA_LAUNCH_BLOCKING=1the error surfaces at the kernel itself:CUDA error (hkernel/flash_fwd_launch_template.h:170): an illegal memory access was encountered(without launch blocking it shows up asynchronously on the next op, e.g.
CublasError(CUBLAS_STATUS_EXECUTION_FAILED)).flash_attn_varlen_causal: 9/9 fail with wrong values (output does not match the causal reference — it is the non-causal result).With this PR (same machine):
i.e. all 18 new causal tests (head_dim {64, 128, 256} x seq_len {2, 4, 9}, dense + varlen) plus the 12 pre-existing tests pass. The 3 filtered-out cases are the pre-existing
head_dim=512varlen tests, which already fail on unpatchedmainon this GPU withCUDA error (hkernel/flash_fwd_launch_template.h:142): invalid argument— unrelated to this PR.In a separate harness we also cross-checked the fixed varlen causal output against vLLM's FlashAttention-3 on random bf16 inputs: bit-identical (max_abs = 0.0).
Known limitation (pre-existing, disclosed)
With causal now actually reaching the kernel, varlen batches whose
cu_seqlenscontain different sequence lengths (ragged) can hang in theSingleTileSchedulercausal mainloop; single-sequence and uniform-length batches work. This is a pre-existing kernel-level defect that was previously masked by bug 2 (causal silently never reached the kernel). Reported separately in #3603 with details. The new tests deliberately use uniform lengths.🤖 Generated with Claude Code