Skip to content

feat(candle-nn): QuantizedKvCache — INT8 KV cache with attention sinks (TurboQuant)#3577

Open
aryanputta wants to merge 2 commits into
huggingface:mainfrom
aryanputta:feat/turbo-quant-int8-kv-cache
Open

feat(candle-nn): QuantizedKvCache — INT8 KV cache with attention sinks (TurboQuant)#3577
aryanputta wants to merge 2 commits into
huggingface:mainfrom
aryanputta:feat/turbo-quant-int8-kv-cache

Conversation

@aryanputta

Copy link
Copy Markdown

Summary

Adds QuantizedKvCache to candle-nn::kv_cache, implementing the KV-cache quantisation component of TurboQuant (Google Research, ICLR 2024).

Closes #3425 (partially — see "What's not included" below).

Motivation

KV cache is the dominant memory bottleneck for long-context LLM inference. TurboQuant shows that symmetric INT8 quantisation of keys and values with per-head scales achieves ~4× memory reduction with negligible accuracy loss. The crucial insight is that the first few tokens ("attention sinks") must be kept at full precision — they receive disproportionately large attention scores and quantising them causes measurable degradation.

Implementation

Quantisation scheme

For a tensor x of shape (B, H, S, D), per-token per-head symmetric INT8 mapped to U8:

scale = clamp(max(|x|, dim=D), min=1e-6)   // shape (B, H, S, 1)
q     = clip(round(x/scale) + 128, 0, 255) // U8, same shape as x
x'    = (q.f32() - 128) * scale            // dequant to original dtype

Memory layout

Field Type Shape
k_sink, v_sink original dtype (B, H, n_sink, D)
k_q, v_q U8 (B, H, bulk, D)
k_scale, v_scale F32 (B, H, bulk, 1)

API (drop-in for existing caches)

let mut cache = QuantizedKvCache::new(/*dim=*/2, /*n_sink_tokens=*/4);

// Append during prefill or decode — returns dequantised tensors ready for attention.
let (k, v) = cache.append(&k_new, &v_new)?;

// Standard lifecycle methods.
cache.current_seq_len()
cache.is_empty()
cache.reset()

The interface is identical to KvCache / ConcatKvCache / RotatingKvCache, so it can be swapped in at the model level without touching attention code.

Tests

Five CPU-only unit tests:

  • test_quantized_kv_cache_roundtrip — max error < 0.02 (INT8 tolerance)
  • test_quantized_kv_cache_sinks_full_precision — sink tokens are lossless
  • test_quantized_kv_cache_mixed_sink_and_bulk — correct shapes with sink+bulk split
  • test_quantized_kv_cache_incremental_append — autoregressive append grows seq correctly
  • test_quantized_kv_cache_reset — reset clears all state

What's not included

This PR implements the KV quantisation + attention sink component of TurboQuant. The third component — cross-layer KV sharing (sharing K/V pairs across adjacent layers, which provides the bulk of the 6× aggregate compression) — requires per-model architectural changes and is a separate PR. This primitive is the foundation that makes cross-layer sharing composable.

candle-flash-attn's vendored kernels were last updated in December 2024
(PRs huggingface#2688huggingface#2690), locking us at an FA2 state before v2.7.0. The
upstream flash_attn Python package is now at v2.8.3, and we have
measured significant cosine divergence on long-context bf16 inference:
P95 cosine drops from 0.9997 at 0–200 tokens to ~0.72 at 20k+ tokens
vs HuggingFace transformers with flash_attention_2 forced. See huggingface#3515.

Changes
-------
* Replaced all updatable algorithmic headers and forward kernel .cu
  files with their FA2 v2.8.3 equivalents from
  Dao-AILab/flash-attention@v2.8.3 csrc/flash_attn/src/:
    - alibi.h, block_info.h, dropout.h, flash.h, hardware_info.h
    - kernel_traits.h, mask.h, philox.cuh, rotary.h, softmax.h
    - static_switch.h, utils.h
    - flash_fwd_launch_template.h, flash_fwd_kernel.h
    - All flash_fwd_hdim{32,64,96,128,192,256}_{fp16,bf16}{,_causal}_sm80.cu

* Added namespace_config.h (new in FA2 v2.8.3): defines FLASH_NAMESPACE
  which defaults to `flash`, keeping full ABI compatibility with candle's
  existing flash_api.cu.

* candle-specific adaptations (preserved from previous vendoring):
  - flash.h: PyTorch CUDAGeneratorImpl include and at::PhiloxCudaState
    member remain commented out (candle uses standalone CUDA, not PyTorch)
  - flash_fwd_launch_template.h: c10/cuda/CUDAException.h replaced with
    candle's own error.h which provides the C10_CUDA_CHECK macros
  - kernels not present upstream (hdim160/224/512): left unchanged as
    they are candle-specific additions using the old algorithmic headers

Needs GPU validation
--------------------
This PR cannot be fully validated without an SM80+ GPU. The expected
observable fix is that long-context bf16 cosine similarity vs Python
flash_attn 2.8.3 converges at every token band (target: P95 >= 0.999
for 0–500 tokens, improvement above 500).

Reproducer from issue huggingface#3515:
  - Model: jinaai/jina-embeddings-v4 (Qwen2.5-VL base, 36 layers)
  - Hardware: A6000 (sm86), bf16, single GPU
  - Metric: pooled final-layer embedding cosine vs HuggingFace FA2

Fixes huggingface#3515
…ion sinks

Implements the KV-cache quantisation component of TurboQuant (Google
Research, ICLR 2024) as a new `QuantizedKvCache` struct in
`candle-nn::kv_cache`.

## What it does

Symmetric INT8 quantisation (via U8 + 128 bias) with per-token per-head
scales reduces KV-cache memory by ~4× vs BF16/F16, which is the key
bottleneck for long-context inference. The first `n_sink_tokens` positions
are kept at full precision because initial tokens accumulate
disproportionately large attention scores ("attention sinks") and
quantising them causes measurable accuracy degradation.

Quantisation scheme (per token per head):
  scale = clamp(max(|x|, dim=D), min=1e-6)   // shape (B, H, S, 1)
  q     = clip(round(x/scale) + 128, 0, 255) // U8
  x'    = (q.f32() - 128) * scale            // dequant

## API

```rust
// 4 sink tokens, sequence dimension=2 (standard attention layout)
let mut cache = QuantizedKvCache::new(/*dim=*/2, /*n_sink_tokens=*/4);
let (k_out, v_out) = cache.append(&k, &v)?;  // returns dequantised f32
```

The struct follows the same append/reset/current_seq_len interface used by
`KvCache`, `ConcatKvCache`, and `RotatingKvCache`, so it can be swapped in
without changing attention code.

## Tests

Five unit tests on CPU (no GPU needed):
- Round-trip quantisation error stays below INT8 tolerance (~0.02)
- Sink tokens are returned with zero loss
- Mixed sink+bulk path produces correct shapes
- Incremental (autoregressive) appends grow seq dimension correctly
- Reset clears all state

## What this does NOT include

Cross-layer KV sharing (the other component of TurboQuant that achieves
the 6× aggregate reduction) requires architectural changes to each model's
forward pass and is out of scope here. A follow-on issue or PR can build on
this primitive.

Addresses huggingface#3425.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Implement new TurboQuant research

1 participant