
Fix 128x softmax_lse over-allocation in the dense flash-attn-v3 path#3597

Open
Xueying-VirtueAI wants to merge 1 commit into
huggingface:mainfrom
Xueying-VirtueAI:fa3-fix-lse-overalloc

Problem

The dense path of candle-flash-attn-v3 allocates the softmax LSE scratch buffer as

let mut softmax_lse = dev.alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q)?;

but the dense LSE layout is [b, nheads, seqlen_q] (see flash.h). The stray 128x factor turns every forward call into a multi-GB cudaMalloc + memset + free: at batch=128, seqlen=2048, 32 heads, the scratch buffer is 4.3 GB, where the [b, nheads, seqlen_q] layout needs only 32 MiB. Profiling with nsys showed the FA3 kernel itself taking ~6.6 ms for that shape while wall-clock was ~28 ms; the gap was almost entirely this allocation, not compute.

The varlen path already allocates the correct nheads * total_q.
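
As a sanity check on the sizes quoted above, the arithmetic can be sketched in a few lines of Rust. This is only an illustration: the function names `lse_bytes_before` and `lse_bytes_layout` are hypothetical and do not exist in the crate, and the only inputs are the shape from this description and the 4-byte size of an f32.

```rust
/// Size of one f32 element in bytes.
const F32_BYTES: u64 = 4;

/// Bytes requested by the old allocation: b * 128 * nheads * seqlen_q f32 elements.
/// (Hypothetical helper, for illustration only.)
fn lse_bytes_before(b_sz: u64, num_heads: u64, seqlen_q: u64) -> u64 {
    b_sz * 128 * num_heads * seqlen_q * F32_BYTES
}

/// Bytes needed by the dense layout [b, nheads, seqlen_q].
/// (Hypothetical helper, for illustration only.)
fn lse_bytes_layout(b_sz: u64, num_heads: u64, seqlen_q: u64) -> u64 {
    b_sz * num_heads * seqlen_q * F32_BYTES
}

fn main() {
    // Shape from the PR description: batch=128, 32 heads, seqlen=2048.
    let (b_sz, num_heads, seqlen_q) = (128, 32, 2048);
    let before = lse_bytes_before(b_sz, num_heads, seqlen_q);
    let layout = lse_bytes_layout(b_sz, num_heads, seqlen_q);
    println!("old allocation: {} bytes ({:.2} GB)", before, before as f64 / 1e9);
    println!("dense layout:   {} bytes ({} MiB)", layout, layout / (1024 * 1024));
    println!("ratio:          {}x", before / layout);
}
```

The ratio is exactly the stray factor of 128, independent of the shape.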

Fix

Allocate b_sz * num_heads * seqlen_q_rounded instead (a one-line change). seqlen_q_rounded is seqlen_q rounded up to the next multiple of 128; the padding guards partial-tile epilogue writes. Outputs are unchanged: this buffer is scratch whose contents are produced by the kernel, so only its size changes.
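
The rounding and the resulting element count can be sketched as follows. This is a minimal illustration, not the crate's code: `round_up`, `lse_elem_count`, and the `TILE` constant are hypothetical names, and the multiple of 128 is taken from the description above.

```rust
/// Rounding granularity, taken from the PR description ("next multiple of 128").
/// (Hypothetical constant name.)
const TILE: usize = 128;

/// Round `n` up to the next multiple of `multiple` (which must be non-zero).
fn round_up(n: usize, multiple: usize) -> usize {
    (n + multiple - 1) / multiple * multiple
}

/// Number of f32 elements in the resized buffer: b * nheads * seqlen_q_rounded.
/// (Hypothetical helper, for illustration only.)
fn lse_elem_count(b_sz: usize, num_heads: usize, seqlen_q: usize) -> usize {
    b_sz * num_heads * round_up(seqlen_q, TILE)
}

fn main() {
    // A sequence length that is already a multiple of 128 is left unchanged.
    println!("round_up(2048) = {}", round_up(2048, TILE));
    // A partial tile is padded up to the next full tile.
    println!("round_up(129)  = {}", round_up(129, TILE));
    // Element count for the shape used in the benchmark row seqlen=2048.
    println!("elements       = {}", lse_elem_count(128, 32, 2048));
}
```

For sequence lengths that are already multiples of 128, as in all four benchmark rows, the padded size equals the exact [b, nheads, seqlen_q] size.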

Numbers

H200, batch=128, GQA 32/8, headdim=128, causal, bf16, us/forward. Measured on a vendored copy of this crate at v0.10.2 (CUDA 12.9, sm90a). A fix for the NULL tile_count_semaphore crash was also applied locally, since causal=true does not run at all without it; that bug is being reported/fixed separately.

  seqlen   before (us)   after (us)   speedup
     128        1603.7        302.5      5.3x
    1024       12784.8       2201.9      5.8x
    2048       28253.9       6971.1      4.1x
    8192      183511.4      94019.9      2.0x
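
The speedup column is just before/after, rounded to one decimal. A short sketch that recomputes it from the measured latencies (the `speedup` helper is a hypothetical name used only here):

```rust
/// Ratio of the old latency to the new one. (Hypothetical helper.)
fn speedup(before_us: f64, after_us: f64) -> f64 {
    before_us / after_us
}

fn main() {
    // (seqlen, before us/forward, after us/forward), copied from the table above.
    let rows = [
        (128, 1603.7, 302.5),
        (1024, 12784.8, 2201.9),
        (2048, 28253.9, 6971.1),
        (8192, 183511.4, 94019.9),
    ];
    for (seqlen, before, after) in rows {
        println!("{:>6}  {:.1}x", seqlen, speedup(before, after));
    }
}
```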

Outputs verified unchanged (bit-identical before/after on random bf16 inputs). Existing test suite on this branch (H200, sm90a, CUDA 12.9):

test result: ok. 12 passed; 0 failed; 0 ignored; 0 measured; 3 filtered out

The 3 filtered-out cases are the pre-existing head_dim=512 varlen tests, which already fail on unpatched main on this GPU (CUDA error (hkernel/flash_fwd_launch_template.h:142): invalid argument) — unrelated to this PR.

🤖 Generated with Claude Code

The dense path allocated softmax_lse as b*128*nheads*seqlen_q, but the dense
LSE layout is [b, nheads, seqlen_q] (see flash.h). The stray 128x factor made
every forward call cudaMalloc + memset (and later free) a multi-GB scratch
buffer -- e.g. 4.3GB at batch=128, seqlen=2048, 32 heads -- which dominated
the per-call host overhead: nsys showed the kernel itself at ~6.6ms while
wall-clock was ~28ms for that shape.

Allocate b*nheads*seqlen_q_rounded instead (seqlen_q_rounded guards
partial-tile epilogue writes). Outputs are unchanged; the varlen path already
allocates the correct nheads*total_q.

Latency (H200, batch=128, GQA 32/8, headdim=128, causal, bf16), us/forward:

  seqlen    before      after
     128    1603.7      302.5
    1024   12784.8     2201.9
    2048   28253.9     6971.1
    8192  183511.4    94019.9

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>