feat(candle-nn): QuantizedKvCache — INT8 KV cache with attention sinks (TurboQuant) #3577
Open
aryanputta wants to merge 2 commits into
candle-flash-attn's vendored kernels were last updated in December 2024 (PRs huggingface#2688–huggingface#2690), locking us at an FA2 state before v2.7.0. The upstream flash_attn Python package is now at v2.8.3, and we have measured significant cosine divergence on long-context bf16 inference: P95 cosine drops from 0.9997 at 0–200 tokens to ~0.72 at 20k+ tokens vs HuggingFace transformers with flash_attention_2 forced. See huggingface#3515.

Changes
-------
* Replaced all updatable algorithmic headers and forward kernel .cu files with their FA2 v2.8.3 equivalents from Dao-AILab/flash-attention@v2.8.3 csrc/flash_attn/src/:
  - alibi.h, block_info.h, dropout.h, flash.h, hardware_info.h
  - kernel_traits.h, mask.h, philox.cuh, rotary.h, softmax.h
  - static_switch.h, utils.h
  - flash_fwd_launch_template.h, flash_fwd_kernel.h
  - All flash_fwd_hdim{32,64,96,128,192,256}_{fp16,bf16}{,_causal}_sm80.cu
* Added namespace_config.h (new in FA2 v2.8.3): defines FLASH_NAMESPACE which defaults to `flash`, keeping full ABI compatibility with candle's existing flash_api.cu.
* candle-specific adaptations (preserved from previous vendoring):
  - flash.h: PyTorch CUDAGeneratorImpl include and at::PhiloxCudaState member remain commented out (candle uses standalone CUDA, not PyTorch)
  - flash_fwd_launch_template.h: c10/cuda/CUDAException.h replaced with candle's own error.h which provides the C10_CUDA_CHECK macros
  - kernels not present upstream (hdim160/224/512): left unchanged as they are candle-specific additions using the old algorithmic headers

Needs GPU validation
--------------------
This PR cannot be fully validated without an SM80+ GPU. The expected observable fix is that long-context bf16 cosine similarity vs Python flash_attn 2.8.3 converges at every token band (target: P95 >= 0.999 for 0–500 tokens, improvement above 500).

Reproducer from issue huggingface#3515:
- Model: jinaai/jina-embeddings-v4 (Qwen2.5-VL base, 36 layers)
- Hardware: A6000 (sm86), bf16, single GPU
- Metric: pooled final-layer embedding cosine vs HuggingFace FA2

Fixes huggingface#3515
…ion sinks
Implements the KV-cache quantisation component of TurboQuant (Google
Research, ICLR 2024) as a new `QuantizedKvCache` struct in
`candle-nn::kv_cache`.
## What it does
Symmetric INT8 quantisation (stored as U8 with a +128 bias) with per-token
per-head scales reduces KV-cache memory by ~2× vs BF16/F16 (~4× vs F32).
The KV cache is the key memory bottleneck for long-context inference. The
first `n_sink_tokens` positions are kept at full precision because initial
tokens accumulate disproportionately large attention scores ("attention
sinks") and quantising them causes measurable accuracy degradation.
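The memory saving can be checked with a back-of-envelope calculation. The sketch below is plain Rust, not candle code; it assumes one byte per quantised element plus one `f32` scale per token per head (the scale dtype is an assumption, the PR text does not state it), and the function names are hypothetical.

```rust
/// Bytes needed to cache one token for one head (K or V alone)
/// at full precision, given the element size in bytes.
fn full_precision_bytes(head_dim: usize, bytes_per_elem: usize) -> usize {
    head_dim * bytes_per_elem
}

/// Bytes for the quantised form: one U8 per element plus a single scale.
/// Assumption: the scale is stored as f32 (4 bytes).
fn quantized_bytes(head_dim: usize) -> usize {
    head_dim + std::mem::size_of::<f32>()
}

/// Ratio of full-precision size to quantised size.
fn compression_ratio(head_dim: usize, bytes_per_elem: usize) -> f64 {
    full_precision_bytes(head_dim, bytes_per_elem) as f64 / quantized_bytes(head_dim) as f64
}
```

For a hypothetical head dimension of 128, `compression_ratio(128, 4)` is 512/132 (about 3.9× against F32) and `compression_ratio(128, 2)` is 256/132 (about 1.9× against BF16/F16). The sink tokens stay at full precision, but they are a fixed handful of positions and do not change the long-context ratio noticeably.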
Quantisation scheme (per token per head):

    scale = clamp(max(|x|, dim=D), min=1e-6) / 127   // shape (B, H, S, 1)
    q     = clip(round(x / scale) + 128, 0, 255)     // U8
    x'    = (q.f32() - 128) * scale                  // dequant
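To make the scheme concrete, here is a minimal plain-Rust sketch for a single (token, head) row of length D. It is not the candle implementation: the function names are hypothetical, and it assumes the scale is `max|x| / 127` so that the largest magnitude maps to ±127 before the +128 bias.

```rust
/// Quantise one (token, head) row to U8 with a single shared scale.
/// Assumption: scale = max|x| / 127, clamped away from zero.
fn quantize_row(x: &[f32]) -> (Vec<u8>, f32) {
    // Largest magnitude in the row (0.0 for an empty or all-zero row).
    let max_abs = x.iter().fold(0.0f32, |m, v| m.max(v.abs()));
    // The 1e-6 floor avoids dividing by zero on all-zero rows.
    let scale = max_abs.max(1e-6) / 127.0;
    let q: Vec<u8> = x
        .iter()
        .map(|v| ((v / scale).round() + 128.0).clamp(0.0, 255.0) as u8)
        .collect();
    (q, scale)
}

/// Undo the bias and rescale back to f32.
fn dequantize_row(q: &[u8], scale: f32) -> Vec<f32> {
    q.iter().map(|&b| (b as f32 - 128.0) * scale).collect()
}
```

Under these assumptions the round-trip error per element is at most half a quantisation step, i.e. `scale / 2`, which is `max|x| / 254` for the row.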
## API
```rust
// 4 sink tokens, sequence dimension=2 (standard attention layout)
let mut cache = QuantizedKvCache::new(/*dim=*/2, /*n_sink_tokens=*/4);
let (k_out, v_out) = cache.append(&k, &v)?; // returns dequantised f32
```
The struct follows the same append/reset/current_seq_len interface used by
`KvCache`, `ConcatKvCache`, and `RotatingKvCache`, so it can be swapped in
without changing attention code.
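The sink/bulk split behind this interface can be illustrated with a toy cache. The sketch below is plain Rust over a single (batch, head) slice and a single tensor (the real `append` takes both `k` and `v`); `ToyQuantizedCache` and its fields are hypothetical names, and the `/ 127` scale is an assumption, so treat it as an illustration rather than the PR's code.

```rust
/// Toy cache for one (batch, head) slice: each row is one token of length D.
struct ToyQuantizedCache {
    n_sink_tokens: usize,
    sink: Vec<Vec<f32>>,  // first n_sink_tokens rows, stored unquantised
    bulk_q: Vec<Vec<u8>>, // remaining rows, quantised to U8
    bulk_scale: Vec<f32>, // one scale per bulk row
}

impl ToyQuantizedCache {
    fn new(n_sink_tokens: usize) -> Self {
        Self { n_sink_tokens, sink: Vec::new(), bulk_q: Vec::new(), bulk_scale: Vec::new() }
    }

    fn current_seq_len(&self) -> usize {
        self.sink.len() + self.bulk_q.len()
    }

    fn reset(&mut self) {
        self.sink.clear();
        self.bulk_q.clear();
        self.bulk_scale.clear();
    }

    /// Append new token rows and return the whole cached sequence, dequantised.
    fn append(&mut self, rows: &[Vec<f32>]) -> Vec<Vec<f32>> {
        for row in rows {
            if self.sink.len() < self.n_sink_tokens {
                // Sink positions are kept exactly as given.
                self.sink.push(row.clone());
            } else {
                // Assumption: scale = max|x| / 127, floored to avoid division by zero.
                let max_abs = row.iter().fold(0.0f32, |m, v| m.max(v.abs()));
                let scale = max_abs.max(1e-6) / 127.0;
                let q: Vec<u8> = row
                    .iter()
                    .map(|v| ((v / scale).round() + 128.0).clamp(0.0, 255.0) as u8)
                    .collect();
                self.bulk_q.push(q);
                self.bulk_scale.push(scale);
            }
        }
        // Sinks first, then dequantised bulk rows, in sequence order.
        let mut out = self.sink.clone();
        for (q, &scale) in self.bulk_q.iter().zip(&self.bulk_scale) {
            out.push(q.iter().map(|&b| (b as f32 - 128.0) * scale).collect());
        }
        out
    }
}
```

With `ToyQuantizedCache::new(2)`, the first two appended rows come back bit-identical and every later row comes back within half a quantisation step, which mirrors the sink and round-trip behaviour the PR's tests describe.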
## Tests
Five unit tests on CPU (no GPU needed):
- Round-trip quantisation error stays below INT8 tolerance (~0.02)
- Sink tokens are returned with zero loss
- Mixed sink+bulk path produces correct shapes
- Incremental (autoregressive) appends grow seq dimension correctly
- Reset clears all state
## What this does NOT include
Cross-layer KV sharing (the other component of TurboQuant that achieves
the 6× aggregate reduction) requires architectural changes to each model's
forward pass and is out of scope here. A follow-on issue or PR can build on
this primitive.
Addresses huggingface#3425.
Summary
Adds `QuantizedKvCache` to `candle-nn::kv_cache`, implementing the KV-cache quantisation component of TurboQuant (Google Research, ICLR 2024).

Closes #3425 (partially; see "What's not included" below).
Motivation
KV cache is the dominant memory bottleneck for long-context LLM inference. TurboQuant shows that symmetric INT8 quantisation of keys and values with per-head scales achieves ~4× memory reduction with negligible accuracy loss. The crucial insight is that the first few tokens ("attention sinks") must be kept at full precision — they receive disproportionately large attention scores and quantising them causes measurable degradation.
Implementation
Quantisation scheme
For a tensor `x` of shape `(B, H, S, D)`, quantisation is per-token per-head symmetric INT8 mapped to U8.

Memory layout

| Tensors | Shape |
| --- | --- |
| `k_sink`, `v_sink` | `(B, H, n_sink, D)` |
| `k_q`, `v_q` | `(B, H, bulk, D)` |
| `k_scale`, `v_scale` | `(B, H, bulk, 1)` |

API (drop-in for existing caches)
The interface is identical to `KvCache` / `ConcatKvCache` / `RotatingKvCache`, so it can be swapped in at the model level without touching attention code.
Tests
Five CPU-only unit tests:
- `test_quantized_kv_cache_roundtrip`: max error < 0.02 (INT8 tolerance)
- `test_quantized_kv_cache_sinks_full_precision`: sink tokens are lossless
- `test_quantized_kv_cache_mixed_sink_and_bulk`: correct shapes with sink+bulk split
- `test_quantized_kv_cache_incremental_append`: autoregressive append grows seq correctly
- `test_quantized_kv_cache_reset`: reset clears all state

What's not included
This PR implements the KV quantisation + attention sink component of TurboQuant. The third component — cross-layer KV sharing (sharing K/V pairs across adjacent layers, which provides the bulk of the 6× aggregate compression) — requires per-model architectural changes and is a separate PR. This primitive is the foundation that makes cross-layer sharing composable.