Add StaticScatterKeyValueCache: drive mobius bias-aware external-KV (TensorScatter + nonpad_kv_seqlen) static cache#2235

Open
titaiwangms wants to merge 8 commits into main from static-scatter-kv-cache

@titaiwangms

Contributor

What

Adds StaticScatterKeyValueCache, a new KV-cache variant that lets onnxruntime-genai
drive a mobius-exported bias-aware external-KV static-cache decoder directly. It:

  • Reuses the existing past_present_share_buffer lifecycle: each key_cache.{i} /
    value_cache.{i} is a pre-allocated 3D [batch, max_seq_len, kv_hidden] buffer that
    the graph updates in place via opset-24 TensorScatter, so past and present alias
    one OrtValue and never need rebinding between steps.
  • Adds a small [B] int64 write_indices / nonpad_kv_seqlen producer in
    input_ids (driven by a standalone StaticScatterIndexTracker): write_indices is the
    scatter row offset (valid tokens before this step), nonpad_kv_seqlen is the valid
    cached length after this step (Attention's per-batch seqlens_k).
  • Factory auto-detect: CreateKeyValueCache selects this variant when the model
    declares both write_indices and nonpad_kv_seqlen inputs (kept in lockstep with
    the producer gate so a half-declared model can't create the cache with unbound indices).
  • Consumes mobius's 3D external-KV emission directly (vs the 4D
    [batch, num_kv_heads, seq, head_dim] default layout), and supports per-layer-varying
    kv_hidden (e.g. Gemma-4 sliding GQA vs global MQA), read from each layer's own
    declared input shape. The batch dim is taken from BatchBeamSize() (symbolic in the
    graph); only max_seq_len / kv_hidden must be static.
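
The index contract in the second bullet can be sketched as follows. This is a minimal illustration, not the StaticScatterIndexTracker added by this PR: the names `StepIndices` and `IndexTrackerSketch` and the `Advance` signature are assumptions made for the example. Only the semantics (write index = valid tokens before the step, nonpad = valid tokens after) come from the description above.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the index pair fed to the graph each step.
struct StepIndices {
  int64_t write_index;       // scatter row offset: valid tokens before this step
  int64_t nonpad_kv_seqlen;  // valid cached length after this step
};

// Hypothetical single-stream tracker (batch == 1); illustrates the contract only.
class IndexTrackerSketch {
 public:
  // Record that `new_tokens` valid tokens are appended in the upcoming step.
  StepIndices Advance(int64_t new_tokens) {
    assert(new_tokens >= 0);
    StepIndices out{valid_, valid_ + new_tokens};
    valid_ = out.nonpad_kv_seqlen;  // the next step appends where this one ended
    return out;
  }

 private:
  int64_t valid_ = 0;  // the first prefill writes at row 0, not -1
};
```

With a 4-token prefill followed by one decode step, this sketch yields write 0 / nonpad 4 and then write 4 / nonpad 5, the same values listed for the slice-A fixture in the Testing section.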

Why

genai's existing GQA past_present_share_buffer path cannot carry an arbitrary float
attention bias. mobius now emits float-bias decoders (causal + sliding-window built
in-graph) whose KV is an external TensorScatter cache, and that path can carry the bias.
This PR is the host-side consumer for that graph, recovering near-GQA decode behaviour
on the memory-efficient-attention (MEA) path for bias-aware models like Gemma-4.

Relationship to onnxruntime-genai#2204

#2204 proposed Path A = a prefill/decode GQA graph-split. This PR implements Path B
= a single TensorScatter external-KV graph, which sidesteps the cross-stage KV handoff
that #2204 flagged as the crux. Referenced as related; this PR does not close #2204.

Mobius side

The bias-aware external-KV graph this cache consumes is emitted by mobius
(onnxruntime/mobius#367, issue onnxruntime/mobius#366). This PR is the genai-side consumer.

Testing

  • StaticScatterKeyValueCache.EndToEndFixtureParity — bit-exact vs an ORT-CPU-MEA golden
    on a committed slice-A fixture (prefill seq=4 @ write 0/nonpad 4, then decode @ write
    4/nonpad 5): logits argmax/sum + updated_key_cache.0 sums.
  • StaticScatterKeyValueCache.RewindToThrows — deny-by-default guard: RewindTo() throws
    (the write_indices/nonpad_kv_seqlen stream has no rewind hook, so a silent no-op would
    desync the tracker; matches the LFM2Cache / WindowedKeyValueCache siblings).
  • 5 StaticScatterIndexTracker producer tests pinning the off-by-one / init contract.
  • 7/7 green. CPU / MEA path — not gated on onnxruntime#28958.
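
The deny-by-default rewind guard can be illustrated with a small sketch. The interface `CacheSketch`, the class `StaticScatterCacheSketch`, the helper `RewindThrows`, and the error message are all hypothetical; the real classes in this PR have a larger interface. The point is only that an unsupported rewind throws instead of silently doing nothing.

```cpp
#include <cstddef>
#include <stdexcept>

// Hypothetical minimal cache interface, for illustration only.
struct CacheSketch {
  virtual ~CacheSketch() = default;
  virtual void RewindTo(std::size_t index) = 0;
};

// Deny-by-default: the index stream cannot be rewound, so fail loudly
// rather than silently leaving the tracker out of sync with the cache.
struct StaticScatterCacheSketch : CacheSketch {
  void RewindTo(std::size_t /*index*/) override {
    throw std::runtime_error("RewindTo is not supported by the static-scatter KV cache");
  }
};

// Returns true if RewindTo throws, mirroring what an EXPECT_THROW check asserts.
inline bool RewindThrows(CacheSketch& cache, std::size_t index) {
  try {
    cache.RewindTo(index);
  } catch (const std::runtime_error&) {
    return true;
  }
  return false;
}
```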

Scope / limits

  • batch == 1 (beam search unsupported, throws).
  • MEA ≈ GQA decode ceiling, not Flash: the arbitrary float bias precludes the Flash path.
  • Per-layer kv_hidden supported (Gemma-4 sliding/global mix).
  • The slice-A fixture is committed under test/models/ per the existing test-model
    convention, tracked via a .gitignore allowlist entry matching its siblings.

titaiwangms and others added 4 commits June 20, 2026 00:26
Implement Path B: a 3D in-place TensorScatter KV cache that consumes mobius
static-cache emission directly ([batch, max_seq_len, kv_hidden] FLOAT, with
per-layer-varying kv_hidden), avoiding a mobius 4D re-export.

- input_ids: add the write_indices / nonpad_kv_seqlen [batch] int64 producer,
  gated on the model declaring those inputs (batch==1, like the existing
  current/past_sequence_length producer). The per-step values come from a new
  dependency-free StaticScatterIndexTracker: write_index = valid tokens before
  the step (TensorScatter row offset), nonpad = valid tokens after. The first
  prefill step writes at row 0 (NOT -1); each later step appends at the previous
  step nonpad. This deliberately does not reuse past_sequence_length (init -1),
  whose semantics would yield nonpad = 2N-1 after a length-N prefill.

- kv_cache: add StaticScatterKeyValueCache, reusing the past_present_share_buffer
  lifecycle verbatim (Add binds key_cache.{i} and updated_key_cache.{i} to the
  same OrtValue; Update/RewindTo are no-ops). New is only a small 3D per-layer
  allocator that reads each layer declared shape, so per-layer kv_hidden
  variation (e.g. Gemma-4 sliding GQA 8*256 vs global MQA 1*512) is handled
  naturally. Auto-detected in CreateKeyValueCache via HasInput(write_indices),
  before the Default fallback, with no new search flag.

- config: add write_indices / nonpad_kv_seqlen Inputs fields + Defaults names +
  JSON parse. Cache name templates (key_cache.%d / updated_key_cache.%d) are set
  via genai_config.json, no struct change.

- test: unit-test the StaticScatterIndexTracker off-by-one / init / sequencing
  contract; scaffold (DISABLED) the e2e fixture-parity test owned by the
  build-test-genai task.

Built against onnxruntime-genai origin/main (6450690, includes the #2214 Gemma-4
per-layer KV fix). Library + unit_tests compile; 5 producer tests pass.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: titaiwang <titaiwang@microsoft.com>
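
To make the 3D in-place layout from the commit above concrete, here is a host-side model of a single layer's buffer. It is a sketch under stated assumptions: `LayerCacheSketch` and `ScatterRows` are hypothetical names, batch is fixed at 1, and the copy only imitates the effect of the in-graph scatter; it is not the ONNX TensorScatter operator.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model of one layer's cache: a flat buffer holding
// [batch = 1, max_seq_len, kv_hidden] floats, updated in place.
struct LayerCacheSketch {
  int64_t max_seq_len;
  int64_t kv_hidden;
  std::vector<float> data;

  LayerCacheSketch(int64_t max_len, int64_t hidden)
      : max_seq_len(max_len),
        kv_hidden(hidden),
        data(static_cast<std::size_t>(max_len * hidden), 0.0f) {}

  // Copy the new rows (update.size() / kv_hidden of them) to row `write_index`.
  void ScatterRows(const std::vector<float>& update, int64_t write_index) {
    const int64_t rows = static_cast<int64_t>(update.size()) / kv_hidden;
    assert(rows * kv_hidden == static_cast<int64_t>(update.size()));
    assert(write_index >= 0 && write_index + rows <= max_seq_len);
    std::copy(update.begin(), update.end(), data.begin() + write_index * kv_hidden);
  }
};
```

Because each layer owns its own object, per-layer variation in kv_hidden needs no special handling in this model: a layer with a different kv_hidden simply constructs a buffer of a different width.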
Replace the DISABLED_EndToEndFixtureParity scaffold with a real e2e test that
drives the mobius slice-A bias-aware external-KV static-cache fixture (#366,
Part of #349) through the public Oga generator API. AppendTokens runs the model,
DefaultInputIDs feeds write_indices/nonpad_kv_seqlen via StaticScatterIndexTracker,
and StaticScatterKeyValueCache binds the 3D in-place share-buffer KV cache.

Forces the golden prompt [1,2,3,4] (write_index 0, nonpad 4) then decode token 5
(write_index 4, nonpad 5) and asserts last-token logits argmax/sum and the
updated_key_cache.0 element-sum against the frozen golden (ORT CPU MEA, opset 24).
Runs on the CPU EP to match the CPU golden. Verified: 7/7 StaticScatter* gtests
pass; full all-tensor parity is bit-exact (maxabs=0.000e+00).

Fixture (model.onnx + genai_config.json) embedded under test/models/ and
force-added since test/models/* is gitignored (matches existing test fixtures).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: titaiwang <titaiwang@microsoft.com>
…ate, nits

Triple-review fixes on the genai Path B static-scatter KV cache:

- M1 (silent KV corruption): StaticScatterKeyValueCache::RewindTo() was a no-op
  whose comment falsely claimed it reset the index stream. RewindTo cannot reset
  the write_indices/nonpad_kv_seqlen stream (it lives in InputIDs with no rewind
  hook), and DecoderOnly_State::RewindTo + the public OgaGenerator::RewindTo
  reach it, so a no-op silently desynchronizes the tracker -> wrong scatter slots
  + over-reported nonpad => wrong logits. Make it throw, matching the LFM2Cache
  and WindowedKeyValueCache siblings. Remove the now-dead
  StaticScatterIndexTracker::Reset() (its only rationale was rewind).

- Minor (detection predicate mismatch): IsStaticScatterCache() required only
  write_indices, but the input_ids producer gate requires BOTH write_indices and
  nonpad_kv_seqlen. A write_indices-only model would get the cache created but
  the indices never bound (obscure unbound-input error). Require both, in lockstep
  with the producer.

- Readability: rename StaticScatterIndices::nonpad_seqlen -> nonpad_kv_seqlen to
  match the field name used everywhere else; annotate the retained input_index_/
  output_index_ members as parity-only/unused.

- Test: add DISABLED_RewindToThrows pinning the M1 fail-loud contract through the
  public generator API (enabled once the slice-A fixture is wired as a genai model
  dir under test/models/, shared with the e2e parity test). Remove the Reset test.

5 producer tracker tests stay green; library + unit_tests build clean.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: titaiwang <titaiwang@microsoft.com>
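
The detection-predicate fix in the commit above amounts to requiring both driver inputs. A minimal sketch, assuming the model's declared input names are available as a set: the function name `IsStaticScatterModelSketch` and its parameters are hypothetical, and the default names are the ones used throughout this PR.

```cpp
#include <set>
#include <string>

// Hypothetical predicate: select the static-scatter cache only when the model
// declares BOTH driver inputs, so a half-declared model never gets the cache
// created without its indices being produced.
inline bool IsStaticScatterModelSketch(
    const std::set<std::string>& model_inputs,
    const std::string& write_indices_name = "write_indices",
    const std::string& nonpad_name = "nonpad_kv_seqlen") {
  return model_inputs.count(write_indices_name) > 0 &&
         model_inputs.count(nonpad_name) > 0;
}
```

Keeping the producer gate and the factory on one shared predicate like this is one way to keep the two checks in lockstep.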
…throw test

The review-fix commit (d3a3837) rewrote the whole test file and unintentionally
clobbered qa's intervening e2e enablement (5ee880d), because 5ee880d landed on
the shared branch between this developer's read and write. 5ee880d's test also
predated the nonpad_seqlen->nonpad_kv_seqlen rename and the Reset() removal, so
it no longer compiled against the fixed source. This reconciles both:

- Restore qa's EndToEndFixtureParity test (enabled, real StaticScatterKeyValueCache
  C++ path against the #366 slice-A fixture via the public Oga generator API),
  adapted to the renamed field. Golden parity unchanged.
- Enable the M1 RewindToThrows test against the now-committed fixture
  (test/models/static-scatter-bias-decoder) instead of a DISABLED stub. The
  happy-path e2e never exercises rewind, so this pins the fail-loud contract with
  a running assertion (EXPECT_THROW on generator->RewindTo).
- Keep the 5 producer tracker tests (renamed field, Reset test dropped).
- Track the fixture dir via a .gitignore allowlist entry
  (!test/models/static-scatter-bias-decoder/*), matching the sibling test-model
  dirs, rather than relying on a force-add.

7/7 static-scatter tests pass (5 tracker + RewindToThrows + EndToEndFixtureParity);
unit_tests build clean.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: titaiwang <titaiwang@microsoft.com>
Copilot AI review requested due to automatic review settings June 20, 2026 00:51
@titaiwangms titaiwangms requested a review from a team as a code owner June 20, 2026 00:51

Copilot AI left a comment

Contributor

Pull request overview

Adds a new KV-cache implementation to let onnxruntime-genai drive mobius-exported “static-scatter” (opset-24 TensorScatter) external-KV decoders, including per-step production of write_indices and nonpad_kv_seqlen and an end-to-end fixture-backed parity test.

Changes:

  • Introduces StaticScatterKeyValueCache (3D [B, max_seq_len, kv_hidden] per-layer KV buffers updated in-place) and auto-detection via declared model inputs.
  • Adds StaticScatterIndexTracker plus DefaultInputIDs plumbing to produce/bind write_indices and nonpad_kv_seqlen.
  • Adds a committed slice-A test fixture + new unit tests (producer contract tests, rewind guard, and end-to-end parity).

Reviewed changes

Copilot reviewed 9 out of 11 changed files in this pull request and generated 5 comments.

| File | Description |
| --- | --- |
| test/static_scatter_kv_cache_test.cpp | New unit tests for index-tracker semantics and end-to-end parity against a committed fixture. |
| test/models/static-scatter-bias-decoder/genai_config.json | Adds a minimal test model config declaring write_indices/nonpad_kv_seqlen and cache I/O names. |
| src/models/static_scatter_indices.h | Defines the StaticScatterIndices pair and StaticScatterIndexTracker helper. |
| src/models/kv_cache.h | Declares StaticScatterKeyValueCache interface and members. |
| src/models/kv_cache.cpp | Implements StaticScatterKeyValueCache, shape discovery/allocation, and factory auto-detect. |
| src/models/input_ids.h | Extends DefaultInputIDs to hold static-scatter driver tensors + tracker. |
| src/models/input_ids.cpp | Creates/binds/updates write_indices and nonpad_kv_seqlen when declared by the model. |
| src/config.h | Adds default config names for write_indices and nonpad_kv_seqlen. |
| src/config.cpp | Parses the new decoder input names from genai_config.json. |
| .gitignore | Allowlists the new committed test fixture directory under test/models/. |

Comment thread test/static_scatter_kv_cache_test.cpp
Comment thread src/config.h
Comment thread src/models/kv_cache.cpp
Comment thread src/models/input_ids.cpp Outdated
Comment thread src/models/input_ids.cpp Outdated
titaiwangms and others added 3 commits June 22, 2026 21:03
Add the missing <algorithm> include (std::max_element at line 105 was
used without it, breaking all platform builds) and apply clang-format
to clear the lint-cpp violation on the comment alignment.

Test-only change; production sources and the committed fixture are
untouched. unit_tests builds clean and StaticScatter gtests are 7/7
green (5 tracker + RewindToThrows + EndToEndFixtureParity).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: titaiwang <titaiwang@microsoft.com>
Resolve the 5 Copilot inline review comments on PR #2235:

- test/static_scatter_kv_cache_test.cpp: add <functional> for
  std::multiplies (was relying on a transitive include).
- src/config.h: clarify that BOTH write_indices AND nonpad_kv_seqlen
  must be present to select StaticScatterKeyValueCache, matching the
  factory predicate and the input_ids producer gate.
- src/models/kv_cache.cpp: enforce BatchBeamSize() == 1 in the
  StaticScatterKeyValueCache constructor, in lockstep with the
  DefaultInputIDs producer gate (single-stream index tracker).
- src/models/input_ids.cpp: reword the two write_indices/nonpad_kv_seqlen
  guards to say "batch beam size (batch_size * num_beams) must be 1",
  since the check is BatchBeamSize() != 1, not batch size alone.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: titaiwang <titaiwang@microsoft.com>
The two StaticScatterKeyValueCache tests (RewindToThrows, EndToEndFixtureParity)
drive the real cache path end-to-end through the Oga generator, which executes
the fixture's TensorScatter(24) node. TensorScatter has CPU and CUDA kernels but
no DirectML implementation, so under a USE_DML build (which routes these models
onto the DML EP) session.run throws "Could not find an implementation for
TensorScatter(24)" and both tests fail at the Run-tests step.

This surfaced only after the <algorithm> include fix let the DirectML build
compile the test file (the earlier max_element compile error masked it). CUDA,
CPU and WebGPU builds are green.

Guard the two fixture-backed tests (and their helpers/golden) with #if !USE_DML,
matching the established pattern in c_api_tests.cpp. The static-cache Flash
feature is CPU/CUDA-targeted; the StaticScatterIndexTracker unit tests are pure
C++ and remain enabled on every EP.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: titaiwang <titaiwang@microsoft.com>

@titaiwangms titaiwangms left a comment

Contributor Author

Review summary — StaticScatterKeyValueCache (Path B consumer)

Reviewed at 2fd84e8 by a multi-model review team plus independent call-trace verification. Mergeable as scoped infrastructure; no Critical correctness defect found.

Verified correct ✅

  • Prefill index ordering is sound (I traced decoder_only.cpp): input_ids_.Update() → StaticScatterIndexTracker::Advance(N) runs before State::Run, so prefill binds write=0, nonpad=N. The ctor's 0/0 default is therefore dead, and the feared one-step lag is not present.
  • Index semantics agree with mobius: write=valid_before, nonpad=valid_after, satisfying mobius's nonpad == write + S_q for unpadded chunks. The tracker unit tests pin the off-by-one contract well.
  • RewindTo fail-loud throw, factory predicate in lockstep with the producer gate (both inputs required), per-layer-varying kv_hidden, and batch==1 enforced on both sides — all consistent.

Major — factory hard-fails instead of falling back

IsStaticScatterCache() selects on presence of both driver inputs before the other cache types in CreateKeyValueCache. A model that declares those inputs but uses 4D KV would hit the constructor's rank != 3 throw rather than falling back to Default/Windowed. Name collision is unlikely (low real risk), but consider tightening the predicate (require the rank-3 static layout) or making the factory attempt-and-fall-back. (kv_cache.cpp:711-719, 1069-1076)

Minor

  • Layer-index parse: std::stoi(idx_str) accepts malformed names ("0.bad" → 0) and doesn't reject duplicates. Use std::from_chars / stoi with pos == size and validate non-negative + unique. (kv_cache.cpp:740-746)
  • Test runtime probe: StaticScatterRuntimeAvailable() skips on any exception message containing "TensorScatter". A real TensorScatter regression could be misclassified as "kernel absent" and silently skip parity. Match the precise missing-kernel signature ("Could not find an implementation" + "TensorScatter(24)") and rethrow others. (test/static_scatter_kv_cache_test.cpp)
  • E2E parity strength: the fixture asserts argmax + sum (1e-2 tol). sum can pass wrong-but-sum-preserving outputs (permutation/cancellation). Add full-tensor or slice-wise parity against the golden.

Open question

  • Decode token equal to pad_token_id → Advance(0) leaves write_index/nonpad unchanged; two such steps alias a cache slot. Mirrors the existing current_sequence_length logic and is benign if pad==eos. Confirm pad_token_id can't be a legal mid-stream generated token for the targeted models.

Scope clarity (not a defect)

  • batch==1 is an enforced, documented MVP constraint (driver tensors allocated [1], throws on BatchBeamSize()!=1), even though the mobius graph + tests support [batch]. Intentional on both sides — worth tracking as a known limitation, not a break.
  • KV name mapping: this consumer reads names from genai_config.json, while mobius emits key_cache.{i} / updated_key_cache.{i}. The fixture maps them manually; production needs the genai_config generator to emit matching names.

Nits

  • input_index_ / output_index_ / layer_shapes_ are set in Add() but never read back — annotate clearly or remove.
  • kv_cache.h comment "differs only in the allocator" erases the throw-vs-silent RewindTo split vs DefaultKeyValueCache; the struct field write_index (singular) vs the tensor write_indices (plural) is a small cross-boundary translation.

Praise: static_scatter_indices.h is a textbook extraction of a tricky off-by-one contract into a minimal, unit-testable type — the "crux mobius and genai must agree on" framing is exactly right. The two-tier test layout (pure-C++ tracker tests on every EP + fixture-backed e2e guarded by #if !USE_DML + runtime probe) is thoughtfully done.

Triple-review finalize of the static-scatter (TensorScatter + nonpad_kv_seqlen)
KV cache review fixes:

- Strict KV layer-index parse: factor the discovery loop into a standalone,
  model-free DiscoverKvLayerIndices() helper (static_scatter_indices.h) that
  uses std::from_chars to require the index segment be a complete non-negative
  integer (std::stoi silently accepted trailing junk like "0.bad" -> 0) and
  rejects duplicate layer indices. Add targeted unit tests (well-formed sort,
  throw-on-malformed, throw-on-duplicate).
- Strengthen EndToEndFixtureParity from argmax+sum to full-tensor element-wise
  parity (ExpectAllClose @ 1e-3) against a generated golden header captured from
  the fixture's authoritative ORT CPU MEA run; the header documents its
  provenance and regeneration recipe and is marked do-not-edit-by-hand.
- Tighten StaticScatterRuntimeAvailable() to skip only on the precise
  missing-kernel signature ("Could not find an implementation for" AND
  "TensorScatter(24)") and rethrow anything else, so a real TensorScatter
  regression is not silently skipped.
- Remove dead StaticScatter fields (input_index_/output_index_/layer_shapes_)
  and correct the kv_cache.h comment to document the 3D-vs-4D layout and the
  RewindTo throws-vs-reshape split (not "differ only in the allocator").
- Document the pad_token aliasing assumption in StaticScatterIndexTracker
  (pad_token_id == eos_token_id ends the sequence before the stalled slot is
  reread) and record the factory rank-3 tightening as a tracked follow-up.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: titaiwang <titaiwang@microsoft.com>
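
The tightened runtime probe described in the commit above can be sketched as a message classifier. The function name `IsMissingTensorScatterKernel` is hypothetical; the two substrings are the ones quoted in the commit message. A caller would skip the test only when this returns true and rethrow otherwise.

```cpp
#include <string>

// Hypothetical classifier for an exception message raised while probing the runtime.
// Returns true only for the precise missing-kernel signature; any other message,
// including other TensorScatter failures, should be rethrown by the caller.
inline bool IsMissingTensorScatterKernel(const std::string& message) {
  return message.find("Could not find an implementation for") != std::string::npos &&
         message.find("TensorScatter(24)") != std::string::npos;
}
```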
@titaiwangms titaiwangms force-pushed the static-scatter-kv-cache branch from dc4326f to 72e7199 Compare June 23, 2026 18:30

@titaiwangms titaiwangms left a comment

Contributor Author

Re-review at 72e7199 — prior findings addressed; one deferred with rationale ✅

Re-checked the new commit "Address review feedback on static-scatter KV cache". I independently verified the new golden by running the committed fixture model.onnx fresh on ORT 1.27 CPU: it matches golden_io.npz (≤1e-4), the scalar sum/argmax constants match, and the C++ header arrays (static_scatter_golden.h) match the fresh run exactly (maxdiff = 0.0) for all four tensors (256-elem logits ×2, 512-elem cache ×2).

| Prior finding | Status |
| --- | --- |
| Minor: std::stoi accepted malformed layer indices ("0.bad" → 0), no dedup | ✅ Factored into a model-free DiscoverKvLayerIndices() using std::from_chars with a complete-consume check (parse_end == end), non-negative guard, and duplicate rejection. New unit tests cover well-formed/sort, throw-on-malformed, throw-on-duplicate. |
| Minor: runtime probe skipped on any "TensorScatter" substring (could mask a real regression) | ✅ Tightened to require both "Could not find an implementation for" and "TensorScatter(24)"; rethrows anything else. |
| Minor: e2e parity used only argmax + sum (passes wrong-but-sum-preserving outputs) | ✅ Added full-tensor ExpectAllClose @ 1e-3 against a generated, provenance-documented golden header (kept the sum/argmax as headline checks). Verified the golden is authentic and reproducible. |
| Nit: dead input_index_ / output_index_ / layer_shapes_ | ✅ Removed. |
| Nit: "differs only in the allocator" erased the RewindTo throw-vs-reshape split | ✅ Comment rewritten to document both the 3D-vs-4D layout and the RewindTo difference. |
| Open question: decode token == pad_token_id → Advance(0) slot aliasing | ✅ Documented the assumption in StaticScatterIndexTracker::Advance (safe because pad_token_id == eos_token_id ends the sequence before the stalled slot is reread; would only break if a model made pad a legal mid-stream token). |
| Major: factory hard-fails (rank != 3 throw) instead of falling back to Default/Windowed | Deferred as a tracked follow-up with an in-code rationale (the rank-3 tightening touches the sensitive factory-selection path and risks layer-discovery drift for sparse/hybrid models lacking a layer-0 KV input; collision considered unlikely). Reasonable disposition: acceptable to land now and follow up. |

DiscoverKvLayerIndices logic reviewed by reading: the size() > prefix+suffix guard keeps idx_str non-empty, and from_chars/negative/duplicate checks are correct. I did not rebuild the genai C++ (expensive) but verified the golden provenance directly, which is the strongest available check. LGTM with the factory fallback tracked separately.
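
For readers unfamiliar with the complete-consume pattern, the checks described above can be sketched as follows. This is not the PR's DiscoverKvLayerIndices(): the function `ParseLayerIndexSketch` and its parameters are hypothetical, it returns an optional instead of throwing, and duplicate rejection is left out.

```cpp
#include <charconv>
#include <optional>
#include <string_view>
#include <system_error>

// Hypothetical helper: extract the layer index from a name such as "key_cache.12"
// given a prefix ("key_cache.") and an optional suffix. Returns std::nullopt unless
// the middle segment is a complete, non-negative base-10 integer.
inline std::optional<int> ParseLayerIndexSketch(std::string_view name,
                                                std::string_view prefix,
                                                std::string_view suffix = {}) {
  // The size guard keeps the index segment non-empty.
  if (name.size() <= prefix.size() + suffix.size()) return std::nullopt;
  if (name.substr(0, prefix.size()) != prefix) return std::nullopt;
  if (name.substr(name.size() - suffix.size()) != suffix) return std::nullopt;

  const std::string_view idx =
      name.substr(prefix.size(), name.size() - prefix.size() - suffix.size());
  int value = 0;
  const char* begin = idx.data();
  const char* end = begin + idx.size();
  const auto [parse_end, ec] = std::from_chars(begin, end, value);
  // Require full consumption: "0.bad" stops at '.', so parse_end != end.
  if (ec != std::errc{} || parse_end != end || value < 0) return std::nullopt;
  return value;
}
```

Unlike std::stoi, std::from_chars skips no leading whitespace and reports where parsing stopped, so trailing junk is detected by comparing that position with the end of the segment.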

@titaiwangms titaiwangms requested a review from justinchuby June 23, 2026 20:38
Development

Successfully merging this pull request may close these issues.

Data-dependent bidirectional mask (Gemma-4 vision-block) forces standard Attention over GQA — split prefill/decode decoder graphs?

2 participants