Fix 128x softmax_lse over-allocation in the dense flash-attn-v3 path#3597
Open
Xueying-VirtueAI wants to merge 1 commit into
Open
Fix 128x softmax_lse over-allocation in the dense flash-attn-v3 path#3597Xueying-VirtueAI wants to merge 1 commit into
Xueying-VirtueAI wants to merge 1 commit into
Conversation
The dense path allocated softmax_lse as b*128*nheads*seqlen_q, but the dense
LSE layout is [b, nheads, seqlen_q] (see flash.h). The stray 128x factor made
every forward cudaMalloc + memset (and later free) a multi-GB scratch buffer
-- e.g. 4.3GB at batch=128, seqlen=2048, 32 heads -- which dominated the
per-call host overhead: nsys showed the kernel itself at ~6.6ms while
wall-clock was ~28ms for that shape.
Allocate b*nheads*seqlen_q_rounded instead (seqlen_q_rounded guards
partial-tile epilogue writes). Outputs are unchanged; the varlen path already
allocates the correct nheads*total_q.
Latency (H200, batch=128, GQA 32/8, headdim=128, causal, bf16), us/forward:
seqlen before after
128 1603.7 302.5
1024 12784.8 2201.9
2048 28253.9 6971.1
8192 183511.4 94019.9
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Problem
The dense path of
candle-flash-attn-v3allocates the softmax LSE scratch buffer asbut the dense LSE layout is
[b, nheads, seqlen_q](seeflash.h). The stray128xfactor turns every forward call into a multi-GBcudaMalloc+memset+ free — e.g. 4.3 GB of scratch at batch=128, seqlen=2048, 32 heads. Profiling with nsys showed the FA3 kernel itself taking ~6.6 ms for that shape while wall-clock was ~28 ms; the gap was almost entirely this allocation, not compute.The varlen path already allocates the correct
nheads * total_q.Fix
Allocate
b_sz * num_heads * seqlen_q_rounded(one line).seqlen_q_rounded(next multiple of 128) guards partial-tile epilogue writes. Outputs are unchanged — this buffer is scratch whose contents are produced by the kernel; only its size changes.Numbers
H200, batch=128, GQA 32/8, headdim=128, causal, bf16, us/forward (measured on a vendored copy of this crate at v0.10.2, CUDA 12.9, sm90a; a fix for the NULL
tile_count_semaphorecrash was also applied locally sincecausal=truedoes not run at all without it — that bug is being reported/fixed separately):Outputs verified unchanged (bit-identical before/after on random bf16 inputs). Existing test suite on this branch (H200, sm90a, CUDA 12.9):
The 3 filtered-out cases are the pre-existing
head_dim=512varlen tests, which already fail on unpatchedmainon this GPU (CUDA error (hkernel/flash_fwd_launch_template.h:142): invalid argument) — unrelated to this PR.🤖 Generated with Claude Code