fix(flash-attn): bump vendored FA2 kernels from Dec 2024 to v2.8.3 #3576

Open
aryanputta wants to merge 1 commit into huggingface:main from aryanputta:fix/bump-flash-attn-kernels-v2.8.3

@aryanputta

Problem

candle-flash-attn's vendored kernels were last updated in December 2024 (#2688, #2690). Since then, flash_attn upstream has shipped v2.7.0 through v2.8.3 with bf16 numerics fixes, sliding-window correctness improvements, and GQA edge-case fixes.

The divergence is measurable and significant on long-context bf16 inference. Testing against Jina V4 (Qwen2.5-VL base, bf16, A6000) with both sides using FA2, comparing pooled final-layer embeddings:

| Token band | P95 cosine (candle vs Python FA2 2.8.3) |
|---|---|
| 0–200 tokens | 0.9997 ✅ |
| 500–1000 tokens | 0.9867 ⚠️ |
| 8k–12k tokens | 0.8667 ❌ |
| 16k–20k tokens | 0.8265 ❌ |
| 20k–25k tokens | 0.7018 ❌ |

Both sides use FA2; both sides use bf16; the only difference is the kernel version. Details in #3515.
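
For reference, the per-band metric can be sketched roughly as follows. This is a minimal illustration and not the script used in #3515: the function names, the band edges passed in, and the reading of "P95" as a percentile taken over per-document cosine similarities are all assumptions, so the percentile `q` is left as a parameter.

```python
import math

def cosine(u, v):
    """Cosine similarity between two pooled embeddings (plain float sequences)."""
    dot = sum(x * y for x, y in zip(u, v))
    norm_u = math.sqrt(sum(x * x for x in u))
    norm_v = math.sqrt(sum(y * y for y in v))
    return dot / (norm_u * norm_v)

def percentile(values, q):
    """q-th percentile (0..100) using linear interpolation between ranks."""
    xs = sorted(values)
    if not xs:
        raise ValueError("percentile of empty sequence")
    pos = (len(xs) - 1) * q / 100.0
    lo, hi = math.floor(pos), math.ceil(pos)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)

def band_percentiles(token_counts, emb_a, emb_b, bands, q):
    """Group per-document cosines by token-count band [lo, hi) and report one percentile per band.

    Bands with no documents are omitted from the result.
    """
    cosines = [cosine(a, b) for a, b in zip(emb_a, emb_b)]
    result = {}
    for lo, hi in bands:
        in_band = [c for n, c in zip(token_counts, cosines) if lo <= n < hi]
        if in_band:
            result[(lo, hi)] = percentile(in_band, q)
    return result
```

In practice `emb_a` and `emb_b` would be the pooled final-layer embeddings from the two implementations for the same inputs, upcast from bf16 to a wider float type before the comparison.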

Changes

Updated from Dao-AILab/flash-attention@v2.8.3 csrc/flash_attn/src/:

Headers (algorithmic — all updated):
alibi.h, block_info.h, dropout.h, flash.h, hardware_info.h, kernel_traits.h, mask.h, philox.cuh, rotary.h, softmax.h, static_switch.h, utils.h, flash_fwd_launch_template.h, flash_fwd_kernel.h

New file added: namespace_config.h — introduced in v2.8.3 to support configurable namespace isolation. Defaults to namespace flash so candle's existing flash_api.cu is unaffected.

Forward kernels updated: flash_fwd_hdim{32,64,96,128,192,256}_{fp16,bf16}{,_causal}_sm80.cu (24 files)
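
As a quick sanity check, the brace pattern above can be expanded mechanically. This is a sketch only: the helper name `expand_fwd_kernels` is hypothetical, and the list is derived purely from the pattern as written, not from a checkout of the upstream tree. The pattern expands to 6 × 2 × 2 = 24 names; comparing that list against the files actually touched in the diff is a cheap way to confirm nothing was missed.

```python
from itertools import product

def expand_fwd_kernels():
    """Expand flash_fwd_hdim{...}_{fp16,bf16}{,_causal}_sm80.cu into file names."""
    head_dims = [32, 64, 96, 128, 192, 256]
    dtypes = ["fp16", "bf16"]
    suffixes = ["", "_causal"]
    return [
        f"flash_fwd_hdim{dim}_{dtype}{suffix}_sm80.cu"
        for dim, dtype, suffix in product(head_dims, dtypes, suffixes)
    ]

files = expand_fwd_kernels()
# 6 head dims x 2 dtypes x 2 causal variants
print(len(files))
```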

candle-specific adaptations (preserved):

  • flash.h: PyTorch CUDAGeneratorImpl include and at::PhiloxCudaState member remain commented out — candle is a standalone CUDA build without PyTorch
  • flash_fwd_launch_template.h: c10/cuda/CUDAException.h replaced with candle's own error.h which provides equivalent C10_CUDA_CHECK / C10_CUDA_KERNEL_LAUNCH_CHECK macros
  • Candle-specific head dimensions (160, 224, 512) are not in upstream FA2 v2.8.3 — left unchanged

GPU validation needed

I don't have access to SM80+ hardware to compile and run the full test suite. The PR is mechanically correct (all PyTorch deps removed, namespace_config.h defaults to flash) but needs a maintainer with GPU access to:

  1. cargo build -p candle-flash-attn
  2. Run the divergence measurement from #3515 to confirm the long-context cosine gap closes

Closes #3515

candle-flash-attn's vendored kernels were last updated in December 2024
(PRs huggingface#2688, huggingface#2690), locking us at an FA2 state before v2.7.0. The
upstream flash_attn Python package is now at v2.8.3, and we have
measured significant cosine divergence on long-context bf16 inference:
P95 cosine drops from 0.9997 at 0–200 tokens to ~0.70 at 20k+ tokens
vs HuggingFace transformers with flash_attention_2 forced. See huggingface#3515.

Changes
-------
* Replaced all updatable algorithmic headers and forward kernel .cu
  files with their FA2 v2.8.3 equivalents from
  Dao-AILab/flash-attention@v2.8.3 csrc/flash_attn/src/:
    - alibi.h, block_info.h, dropout.h, flash.h, hardware_info.h
    - kernel_traits.h, mask.h, philox.cuh, rotary.h, softmax.h
    - static_switch.h, utils.h
    - flash_fwd_launch_template.h, flash_fwd_kernel.h
    - All flash_fwd_hdim{32,64,96,128,192,256}_{fp16,bf16}{,_causal}_sm80.cu

* Added namespace_config.h (new in FA2 v2.8.3): defines FLASH_NAMESPACE
  which defaults to `flash`, keeping full ABI compatibility with candle's
  existing flash_api.cu.

* candle-specific adaptations (preserved from previous vendoring):
  - flash.h: PyTorch CUDAGeneratorImpl include and at::PhiloxCudaState
    member remain commented out (candle uses standalone CUDA, not PyTorch)
  - flash_fwd_launch_template.h: c10/cuda/CUDAException.h replaced with
    candle's own error.h which provides the C10_CUDA_CHECK macros
  - kernels not present upstream (hdim160/224/512): left unchanged as
    they are candle-specific additions using the old algorithmic headers

Needs GPU validation
--------------------
This PR cannot be fully validated without an SM80+ GPU. The expected
observable fix is that long-context bf16 cosine similarity vs Python
flash_attn 2.8.3 converges at every token band (target: P95 >= 0.999
for 0–500 tokens, improvement above 500).

Reproducer from issue huggingface#3515:
  - Model: jinaai/jina-embeddings-v4 (Qwen2.5-VL base, 36 layers)
  - Hardware: A6000 (sm86), bf16, single GPU
  - Metric: pooled final-layer embedding cosine vs HuggingFace FA2

Fixes huggingface#3515

@toddwbucy

Hi @aryanputta, thanks for digging into this -- the Jina-V4 cosine numbers across token bands are a solid justification for the bump.

Heads-up that this overlaps with #3521, which I opened on May 7 and which also closes #3515. That PR vendors the same Tri Dao v2.8.3 sources and additionally includes the split-KV kernels (24 new files), the build.rs wiring, and the split-KV dispatch in flash_api.cu -- i.e. the kernel bump plus the integration needed to actually build and dispatch the new variants.

To avoid duplicated review effort, can we converge on one PR? Happy to fold your benchmark numbers into #3521 as the numerical justification (with credit), or to coordinate however works best. @EricLBuehler -- since both are open against #3515, which approach would you prefer to take forward?

@aryanputta (Author)

Thanks for the heads-up and for pointing me to #3521. I took a look and it definitely covers the broader integration path beyond just the kernel/vendor bump, especially with the split-KV dispatch and build wiring.

I’m completely on board with converging on a single PR to avoid duplicate review effort. Happy to contribute the Jina-V4 cosine benchmark data and token-band results as supporting validation for #3521, with attribution handled however makes the most sense.

@EricLBuehler since #3521 already contains the full integration stack, I’m good with consolidating there if that’s the preferred direction. If I see anything, I will let you know.

Development

Successfully merging this pull request may close these issues.

candle-flash-attn vendored kernels ~13 months stale; long-context bf16 divergence vs PyTorch flash_attn 2.8.3