fix(flash-attn): bump vendored FA2 kernels from Dec 2024 to v2.8.3 #3576
aryanputta wants to merge 1 commit into
Conversation
candle-flash-attn's vendored kernels were last updated in December 2024 (PRs huggingface#2688–huggingface#2690), locking us at an FA2 state before v2.7.0. The upstream flash_attn Python package is now at v2.8.3, and we have measured significant cosine divergence on long-context bf16 inference: P95 cosine drops from 0.9997 at 0–200 tokens to ~0.72 at 20k+ tokens vs HuggingFace transformers with flash_attention_2 forced. See huggingface#3515.

Changes
-------
* Replaced all updatable algorithmic headers and forward kernel .cu files with their FA2 v2.8.3 equivalents from Dao-AILab/flash-attention@v2.8.3 csrc/flash_attn/src/:
  - alibi.h, block_info.h, dropout.h, flash.h, hardware_info.h
  - kernel_traits.h, mask.h, philox.cuh, rotary.h, softmax.h
  - static_switch.h, utils.h
  - flash_fwd_launch_template.h, flash_fwd_kernel.h
  - All flash_fwd_hdim{32,64,96,128,192,256}_{fp16,bf16}{,_causal}_sm80.cu
* Added namespace_config.h (new in FA2 v2.8.3): defines FLASH_NAMESPACE which defaults to `flash`, keeping full ABI compatibility with candle's existing flash_api.cu.
* candle-specific adaptations (preserved from previous vendoring):
  - flash.h: PyTorch CUDAGeneratorImpl include and at::PhiloxCudaState member remain commented out (candle uses standalone CUDA, not PyTorch)
  - flash_fwd_launch_template.h: c10/cuda/CUDAException.h replaced with candle's own error.h which provides the C10_CUDA_CHECK macros
  - kernels not present upstream (hdim160/224/512): left unchanged as they are candle-specific additions using the old algorithmic headers

Needs GPU validation
--------------------
This PR cannot be fully validated without an SM80+ GPU. The expected observable fix is that long-context bf16 cosine similarity vs Python flash_attn 2.8.3 converges at every token band (target: P95 >= 0.999 for 0–500 tokens, improvement above 500).

Reproducer from issue huggingface#3515:
- Model: jinaai/jina-embeddings-v4 (Qwen2.5-VL base, 36 layers)
- Hardware: A6000 (sm86), bf16, single GPU
- Metric: pooled final-layer embedding cosine vs HuggingFace FA2

Fixes huggingface#3515
|
Hi @aryanputta, thanks for digging into this -- the Jina-V4 cosine numbers across token bands are a solid justification for the bump.

Heads-up that this overlaps with #3521, which I opened on May 7 and which also closes #3515. That PR vendors the same Tri Dao v2.8.3 sources and additionally includes the split-KV kernels (24 new files), the build.rs wiring, and the split-KV dispatch in flash_api.cu -- i.e. the kernel bump plus the integration needed to actually build and dispatch the new variants.

To avoid duplicated review effort, can we converge on one PR? Happy to fold your benchmark numbers into #3521 as the numerical justification (with credit), or to coordinate however works best.

@EricLBuehler -- since both are open against #3515, which approach would you prefer to take forward?
|
Thanks for the heads-up and for pointing me to #3521. I took a look, and it definitely covers the broader integration path beyond the kernel/vendor bump, especially with the split-KV dispatch and build wiring.

I’m completely on board with converging on a single PR to avoid duplicate review effort. Happy to contribute the Jina-V4 cosine benchmark data and token-band results as supporting validation for #3521, with attribution handled however makes the most sense.

@EricLBuehler since #3521 already contains the full integration stack, I’m good with consolidating there if that’s the preferred direction. If I see anything, I will let you know.
Problem
candle-flash-attn's vendored kernels were last updated in December 2024 (#2688–#2690). Since then, flash_attn upstream has shipped v2.7.0 through v2.8.3 with bf16 numerics fixes, sliding-window correctness improvements, and GQA edge case fixes.

The divergence is measurable and significant on long-context bf16 inference. Testing against Jina V4 (Qwen2.5-VL base, bf16, A6000) with both sides using FA2, comparing pooled final-layer embeddings:
Both sides use FA2; both sides use bf16; the only difference is the kernel version. Details in #3515.
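For readers who want to reproduce the comparison, the per-band metric can be sketched as below. This is a minimal illustration and not the script used in #3515: the band edges, the function names, and the reading of "P95" as the 95th percentile of per-sample cosines with linear interpolation are all assumptions, since the issue's exact definition is not restated here.

```python
import math

# Illustrative band edges (assumption): the PR text only names the
# 0-200, 0-500 and 20k+ token ranges.
BAND_EDGES = [0, 200, 500, 20000]


def cosine(a, b):
    """Cosine similarity between two equal-length pooled embedding vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine is undefined for a zero vector")
    return dot / (norm_a * norm_b)


def band_label(n_tokens, edges=BAND_EDGES):
    """Map a non-negative token count to a half-open band such as '0-200'."""
    for lo, hi in zip(edges, edges[1:]):
        if lo <= n_tokens < hi:
            return f"{lo}-{hi}"
    return f"{edges[-1]}+"


def percentile(values, q):
    """q-th percentile (0..100) with linear interpolation between ranks."""
    if not values:
        raise ValueError("percentile of an empty list")
    ordered = sorted(values)
    pos = (len(ordered) - 1) * q / 100.0
    lo, hi = math.floor(pos), math.ceil(pos)
    return ordered[lo] + (ordered[hi] - ordered[lo]) * (pos - lo)


def band_report(samples, q=95.0):
    """samples: iterable of (n_tokens, candle_embedding, reference_embedding).

    Returns {band label: q-th percentile of cosine} for every band seen.
    """
    by_band = {}
    for n_tokens, emb_candle, emb_ref in samples:
        by_band.setdefault(band_label(n_tokens), []).append(
            cosine(emb_candle, emb_ref)
        )
    return {band: percentile(vals, q) for band, vals in by_band.items()}
```

Feeding `band_report` the pooled final-layer embeddings from candle and from the HuggingFace FA2 run, one tuple per input document, gives one number per token band that can be compared before and after the kernel bump.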
Changes
Updated from Dao-AILab/flash-attention@v2.8.3 csrc/flash_attn/src/:

Headers (algorithmic, all updated):
- alibi.h, block_info.h, dropout.h, flash.h, hardware_info.h, kernel_traits.h, mask.h, philox.cuh, rotary.h, softmax.h, static_switch.h, utils.h, flash_fwd_launch_template.h, flash_fwd_kernel.h

New file added:
- namespace_config.h: introduced in v2.8.3 to support configurable namespace isolation. Defaults to namespace flash so candle's existing flash_api.cu is unaffected.

Forward kernels updated:
- flash_fwd_hdim{32,64,96,128,192,256}_{fp16,bf16}{,_causal}_sm80.cu (28 files)

candle-specific adaptations (preserved):
- flash.h: PyTorch CUDAGeneratorImpl include and at::PhiloxCudaState member remain commented out; candle is a standalone CUDA build without PyTorch
- flash_fwd_launch_template.h: c10/cuda/CUDAException.h replaced with candle's own error.h, which provides equivalent C10_CUDA_CHECK / C10_CUDA_KERNEL_LAUNCH_CHECK macros

GPU validation needed
I don't have access to SM80+ hardware to compile and run the full test suite. The PR is mechanically correct (all PyTorch deps removed, namespace_config.h defaults to flash) but needs a maintainer with GPU access to:
- cargo build -p candle-flash-attn

Closes #3515
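Part of the "mechanically correct" claim can be checked on a machine without a GPU by comparing the vendored tree against an upstream checkout. The sketch below is one possible way to do that, not part of this PR: `classify_vendored` and both directory arguments are hypothetical, and it assumes the vendored sources and the upstream csrc/flash_attn/src/ files each sit flat in a single directory.

```python
import filecmp
from pathlib import Path


def classify_vendored(vendored_dir, upstream_dir,
                      patterns=("*.h", "*.cuh", "*.cu")):
    """Sort vendored source files by how they relate to the upstream copy.

    Returns a dict with three sorted lists of file names:
      identical     - byte-for-byte equal to upstream
      modified      - present upstream but with different contents
      vendored_only - no file of that name upstream
    """
    vendored = Path(vendored_dir)
    upstream = Path(upstream_dir)
    report = {"identical": [], "modified": [], "vendored_only": []}

    # Collect every vendored file name matching one of the patterns.
    names = sorted({p.name for pat in patterns for p in vendored.glob(pat)})
    for name in names:
        theirs = upstream / name
        if not theirs.is_file():
            report["vendored_only"].append(name)
        elif filecmp.cmp(vendored / name, theirs, shallow=False):
            # shallow=False forces a content comparison, not just os.stat().
            report["identical"].append(name)
        else:
            report["modified"].append(name)
    return report
```

If the description above holds, running this against a v2.8.3 checkout should list flash.h and flash_fwd_launch_template.h under `modified`, and candle-only files such as error.h and any kernels without an upstream counterpart under `vendored_only`. Anything else under `modified` would be worth a closer look before the GPU run.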