Skip to content

Fix causal attention in candle-flash-attn-v3: NULL tile_count_semaphore crash (dense) + silently non-causal varlen#3606

Open
Xueying-VirtueAI wants to merge 3 commits into
huggingface:mainfrom
Xueying-VirtueAI:fa3-fix-causal
Open

Fix causal attention in candle-flash-attn-v3: NULL tile_count_semaphore crash (dense) + silently non-causal varlen#3606
Xueying-VirtueAI wants to merge 3 commits into
huggingface:mainfrom
Xueying-VirtueAI:fa3-fix-causal

Conversation

@Xueying-VirtueAI

Copy link
Copy Markdown

This PR fixes two independent bugs in candle-flash-attn-v3 that both break causal=true. The existing test suite only exercises causal=false, which is why neither was caught. Both fixes come with regression tests against an eager f32 reference.

Note: this branch is stacked on #3597 (its one-line softmax_lse allocation fix is the first commit here) because both PRs touch the same allocation block in the dense path. Once #3597 merges, this PR reduces to the two causal commits.

Bug 1: flash_attn(..., causal=true) fails with an illegal global atomic

The causal/local (non-varlen) path selects the DynamicPersistentTileScheduler (flash_fwd_launch_template.h), which does

atomicAdd(params.tile_count_semaphore, 1)   // tile_scheduler.hpp:266

The binding memsets Flash_fwd_params to 0 and never allocates that counter, so the pointer stays NULL and every dense causal call performs an illegal global atomic (CUDA illegal memory access / unspecified launch failure).

Fix: allocate a zero-initialized int32 in the binding (mirroring softmax_lse) and thread it through run_mha_v3 into params.tile_count_semaphore. The varlen path currently selects the SingleTileScheduler, which ignores the counter, but it now passes a valid pointer as well.

Bug 2: flash_attn_varlen(..., causal=true) silently runs full non-causal attention

The varlen path (and only the varlen path) clamped both window sizes up to max_seqlen_k unconditionally:

if window_size_right < self.max_seqlen_k as i32 {
    window_size_right = self.max_seqlen_k.clone() as i32;
}

causal=true encodes as window_size_right = 0, so the clamp rewrote it to max_seqlen_k (full attention) before is_causal is derived from window_size_right == 0. The kernel then ran full non-causal attention and returned wrong values — no error, no warning. The dense path has no such clamp and derives is_causal correctly.

Fix: drop the clamps and keep the raw window sizes, exactly mirroring the dense path: is_causal is derived first, and only the unset (-1) side is extended to max_seqlen_k afterwards.

Verification

On unpatched main (3d3d9c4; H200, sm90a, CUDA 12.9), running the new tests only:

  • flash_attn_causal: 9/9 fail. With CUDA_LAUNCH_BLOCKING=1 the error surfaces at the kernel itself:
    CUDA error (hkernel/flash_fwd_launch_template.h:170): an illegal memory access was encountered
    (without launch blocking it shows up asynchronously on the next op, e.g. CublasError(CUBLAS_STATUS_EXECUTION_FAILED)).
  • flash_attn_varlen_causal: 9/9 fail with wrong values (output does not match the causal reference — it is the non-causal result).

With this PR (same machine):

test result: ok. 30 passed; 0 failed; 0 ignored; 0 measured; 3 filtered out

i.e. all 18 new causal tests (head_dim {64, 128, 256} x seq_len {2, 4, 9}, dense + varlen) plus the 12 pre-existing tests pass. The 3 filtered-out cases are the pre-existing head_dim=512 varlen tests, which already fail on unpatched main on this GPU with CUDA error (hkernel/flash_fwd_launch_template.h:142): invalid argument — unrelated to this PR.

In a separate harness we also cross-checked the fixed varlen causal output against vLLM's FlashAttention-3 on random bf16 inputs: bit-identical (max_abs = 0.0).

Known limitation (pre-existing, disclosed)

With causal now actually reaching the kernel, varlen batches whose cu_seqlens contain different sequence lengths (ragged) can hang in the SingleTileScheduler causal mainloop; single-sequence and uniform-length batches work. This is a pre-existing kernel-level defect that was previously masked by bug 2 (causal silently never reached the kernel). Reported separately in #3603 with details. The new tests deliberately use uniform lengths.

🤖 Generated with Claude Code

Xueying-VirtueAI and others added 3 commits June 9, 2026 23:59
The dense path allocated softmax_lse as b*128*nheads*seqlen_q, but the dense
LSE layout is [b, nheads, seqlen_q] (see flash.h). The stray 128x factor made
every forward cudaMalloc + memset (and later free) a multi-GB scratch buffer
-- e.g. 4.3GB at batch=128, seqlen=2048, 32 heads -- which dominated the
per-call host overhead: nsys showed the kernel itself at ~6.6ms while
wall-clock was ~28ms for that shape.

Allocate b*nheads*seqlen_q_rounded instead (seqlen_q_rounded guards
partial-tile epilogue writes). Outputs are unchanged; the varlen path already
allocates the correct nheads*total_q.

Latency (H200, batch=128, GQA 32/8, headdim=128, causal, bf16), us/forward:

  seqlen    before      after
     128    1603.7      302.5
    1024   12784.8     2201.9
    2048   28253.9     6971.1
    8192  183511.4    94019.9

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…semaphore)

The causal/local (non-varlen) path selects the DynamicPersistentTileScheduler,
which does an atomicAdd on params.tile_count_semaphore (tile_scheduler.hpp).
The binding memsets Flash_fwd_params to 0 and never allocated that counter, so
every flash_attn(..., causal=true) call performed an illegal global atomic on a
NULL pointer.

Allocate a zero-initialized int32 in the binding (mirroring softmax_lse) and
thread it through the C entry point into params.tile_count_semaphore. The
varlen path currently selects the SingleTileScheduler which ignores the
counter, but it now passes a valid pointer as well.

Adds a causal regression test against an eager f32 reference; existing tests
only covered causal=false, which is why this never fired in the suite.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…d the causal flag)

The varlen path clamped both window sizes up to max_seqlen_k unconditionally:

    if window_size_right < self.max_seqlen_k as i32 {
        window_size_right = self.max_seqlen_k.clone() as i32;
    }

causal=true encodes as window_size_right=0, so the clamp rewrote it to
max_seqlen_k (full attention) before is_causal was derived from it. As a
result flash_attn_varlen(..., causal=true) silently ran full NON-causal
attention and returned wrong values. The dense path has no such clamp and was
unaffected.

Drop the clamps and keep the raw window sizes, mirroring the dense path:
is_causal is derived below, and only the unset (-1) side is extended to
max_seqlen_k afterwards. Adds a varlen causal regression test against an eager
f32 reference; existing varlen tests only covered causal=false.

Note: with causal now actually reaching the kernel, varlen batches with
*different* sequence lengths in one cu_seqlens (ragged) can hang in the
SingleTileScheduler causal mainloop; single-sequence and uniform-length
batches work. That pre-existing kernel-level defect is reported separately.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant