Add StaticScatterKeyValueCache: drive mobius bias-aware external-KV (TensorScatter + nonpad_kv_seqlen) static cache#2235
Add StaticScatterKeyValueCache: drive mobius bias-aware external-KV (TensorScatter + nonpad_kv_seqlen) static cache#2235titaiwangms wants to merge 8 commits into
Conversation
Implement Path B: a 3D in-place TensorScatter KV cache that consumes mobius
static-cache emission directly ([batch, max_seq_len, kv_hidden] FLOAT, with
per-layer-varying kv_hidden), avoiding a mobius 4D re-export.
- input_ids: add the write_indices / nonpad_kv_seqlen [batch] int64 producer,
gated on the model declaring those inputs (batch==1, like the existing
current/past_sequence_length producer). The per-step values come from a new
dependency-free StaticScatterIndexTracker: write_index = valid tokens before
the step (TensorScatter row offset), nonpad = valid tokens after. The first
prefill step writes at row 0 (NOT -1); each later step appends at the previous
step nonpad. This deliberately does not reuse past_sequence_length (init -1),
whose semantics would yield nonpad = 2N-1 after a length-N prefill.
- kv_cache: add StaticScatterKeyValueCache, reusing the past_present_share_buffer
lifecycle verbatim (Add binds key_cache.{i} and updated_key_cache.{i} to the
same OrtValue; Update/RewindTo are no-ops). New is only a small 3D per-layer
allocator that reads each layer declared shape, so per-layer kv_hidden
variation (e.g. Gemma-4 sliding GQA 8*256 vs global MQA 1*512) is handled
naturally. Auto-detected in CreateKeyValueCache via HasInput(write_indices),
before the Default fallback, with no new search flag.
- config: add write_indices / nonpad_kv_seqlen Inputs fields + Defaults names +
JSON parse. Cache name templates (key_cache.%d / updated_key_cache.%d) are set
via genai_config.json, no struct change.
- test: unit-test the StaticScatterIndexTracker off-by-one / init / sequencing
contract; scaffold (DISABLED) the e2e fixture-parity test owned by the
build-test-genai task.
Built against onnxruntime-genai origin/main (6450690, includes the #2214 Gemma-4
per-layer KV fix). Library + unit_tests compile; 5 producer tests pass.
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: titaiwang <titaiwang@microsoft.com>
Replace the DISABLED_EndToEndFixtureParity scaffold with a real e2e test that drives the mobius slice-A bias-aware external-KV static-cache fixture (#366, Part of #349) through the public Oga generator API. AppendTokens runs the model, DefaultInputIDs feeds write_indices/nonpad_kv_seqlen via StaticScatterIndexTracker, and StaticScatterKeyValueCache binds the 3D in-place share-buffer KV cache. Forces the golden prompt [1,2,3,4] (write_index 0, nonpad 4) then decode token 5 (write_index 4, nonpad 5) and asserts last-token logits argmax/sum and the updated_key_cache.0 element-sum against the frozen golden (ORT CPU MEA, opset 24). Runs on the CPU EP to match the CPU golden. Verified: 7/7 StaticScatter* gtests pass; full all-tensor parity is bit-exact (maxabs=0.000e+00). Fixture (model.onnx + genai_config.json) embedded under test/models/ and force-added since test/models/* is gitignored (matches existing test fixtures). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: titaiwang <titaiwang@microsoft.com>
…ate, nits Triple-review fixes on the genai Path B static-scatter KV cache: - M1 (silent KV corruption): StaticScatterKeyValueCache::RewindTo() was a no-op whose comment falsely claimed it reset the index stream. RewindTo cannot reset the write_indices/nonpad_kv_seqlen stream (it lives in InputIDs with no rewind hook), and DecoderOnly_State::RewindTo + the public OgaGenerator::RewindTo reach it, so a no-op silently desynchronizes the tracker -> wrong scatter slots + over-reported nonpad => wrong logits. Make it throw, matching the LFM2Cache and WindowedKeyValueCache siblings. Remove the now-dead StaticScatterIndex- Tracker::Reset() (its only rationale was rewind). - Minor (detection predicate mismatch): IsStaticScatterCache() required only write_indices, but the input_ids producer gate requires BOTH write_indices and nonpad_kv_seqlen. A write_indices-only model would get the cache created but the indices never bound (obscure unbound-input error). Require both, in lockstep with the producer. - Readability: rename StaticScatterIndices::nonpad_seqlen -> nonpad_kv_seqlen to match the field name used everywhere else; annotate the retained input_index_/ output_index_ members as parity-only/unused. - Test: add DISABLED_RewindToThrows pinning the M1 fail-loud contract through the public generator API (enabled once the slice-A fixture is wired as a genai model dir under test/models/, shared with the e2e parity test). Remove the Reset test. 5 producer tracker tests stay green; library + unit_tests build clean. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: titaiwang <titaiwang@microsoft.com>
…throw test The review-fix commit (d3a3837) rewrote the whole test file and unintentionally clobbered qa's intervening e2e enablement (5ee880d), because 5ee880d landed on the shared branch between this developer's read and write. 5ee880d's test also predated the nonpad_seqlen->nonpad_kv_seqlen rename and the Reset() removal, so it no longer compiled against the fixed source. This reconciles both: - Restore qa's EndToEndFixtureParity test (enabled, real StaticScatterKeyValueCache C++ path against the #366 slice-A fixture via the public Oga generator API), adapted to the renamed field. Golden parity unchanged. - Enable the M1 RewindToThrows test against the now-committed fixture (test/models/static-scatter-bias-decoder) instead of a DISABLED stub. The happy-path e2e never exercises rewind, so this pins the fail-loud contract with a running assertion (EXPECT_THROW on generator->RewindTo). - Keep the 5 producer tracker tests (renamed field, Reset test dropped). - Track the fixture dir via a .gitignore allowlist entry (!test/models/static-scatter-bias-decoder/*), matching the sibling test-model dirs, rather than relying on a force-add. 7/7 static-scatter tests pass (5 tracker + RewindToThrows + EndToEndFixtureParity); unit_tests build clean. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: titaiwang <titaiwang@microsoft.com>
There was a problem hiding this comment.
Pull request overview
Adds a new KV-cache implementation to let onnxruntime-genai drive mobius-exported “static-scatter” (opset-24 TensorScatter) external-KV decoders, including per-step production of write_indices and nonpad_kv_seqlen and an end-to-end fixture-backed parity test.
Changes:
- Introduces
StaticScatterKeyValueCache(3D[B, max_seq_len, kv_hidden]per-layer KV buffers updated in-place) and auto-detection via declared model inputs. - Adds
StaticScatterIndexTrackerplusDefaultInputIDsplumbing to produce/bindwrite_indicesandnonpad_kv_seqlen. - Adds a committed slice-A test fixture + new unit tests (producer contract tests, rewind guard, and end-to-end parity).
Reviewed changes
Copilot reviewed 9 out of 11 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| test/static_scatter_kv_cache_test.cpp | New unit tests for index-tracker semantics and end-to-end parity against a committed fixture. |
| test/models/static-scatter-bias-decoder/genai_config.json | Adds a minimal test model config declaring write_indices/nonpad_kv_seqlen and cache I/O names. |
| src/models/static_scatter_indices.h | Defines the StaticScatterIndices pair and StaticScatterIndexTracker helper. |
| src/models/kv_cache.h | Declares StaticScatterKeyValueCache interface and members. |
| src/models/kv_cache.cpp | Implements StaticScatterKeyValueCache, shape discovery/allocation, and factory auto-detect. |
| src/models/input_ids.h | Extends DefaultInputIDs to hold static-scatter driver tensors + tracker. |
| src/models/input_ids.cpp | Creates/binds/updates write_indices and nonpad_kv_seqlen when declared by the model. |
| src/config.h | Adds default config names for write_indices and nonpad_kv_seqlen. |
| src/config.cpp | Parses the new decoder input names from genai_config.json. |
| .gitignore | Allowlists the new committed test fixture directory under test/models/. |
Add the missing <algorithm> include (std::max_element at line 105 was used without it, breaking all platform builds) and apply clang-format to clear the lint-cpp violation on the comment alignment. Test-only change; production sources and the committed fixture are untouched. unit_tests builds clean and StaticScatter gtests are 7/7 green (5 tracker + RewindToThrows + EndToEndFixtureParity). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: titaiwang <titaiwang@microsoft.com>
Resolve the 5 Copilot inline review comments on PR #2235: - test/static_scatter_kv_cache_test.cpp: add <functional> for std::multiplies (was relying on a transitive include). - src/config.h: clarify that BOTH write_indices AND nonpad_kv_seqlen must be present to select StaticScatterKeyValueCache, matching the factory predicate and the input_ids producer gate. - src/models/kv_cache.cpp: enforce BatchBeamSize() == 1 in the StaticScatterKeyValueCache constructor, in lockstep with the DefaultInputIDs producer gate (single-stream index tracker). - src/models/input_ids.cpp: reword the two write_indices/nonpad_kv_seqlen guards to say "batch beam size (batch_size * num_beams) must be 1", since the check is BatchBeamSize() != 1, not batch size alone. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: titaiwang <titaiwang@microsoft.com>
The two StaticScatterKeyValueCache tests (RewindToThrows, EndToEndFixtureParity) drive the real cache path end-to-end through the Oga generator, which executes the fixture's TensorScatter(24) node. TensorScatter has CPU and CUDA kernels but no DirectML implementation, so under a USE_DML build (which routes these models onto the DML EP) session.run throws "Could not find an implementation for TensorScatter(24)" and both tests fail at the Run-tests step. This surfaced only after the <algorithm> include fix let the DirectML build compile the test file (the earlier max_element compile error masked it). CUDA, CPU and WebGPU builds are green. Guard the two fixture-backed tests (and their helpers/golden) with #if !USE_DML, matching the established pattern in c_api_tests.cpp. The static-cache Flash feature is CPU/CUDA-targeted; the StaticScatterIndexTracker unit tests are pure C++ and remain enabled on every EP. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: titaiwang <titaiwang@microsoft.com>
titaiwangms
left a comment
There was a problem hiding this comment.
Review summary — StaticScatterKeyValueCache (Path B consumer)
Reviewed at 2fd84e8 by a multi-model review team plus independent call-trace verification. Mergeable as scoped infrastructure; no Critical correctness defect found.
Verified correct ✅
- Prefill index ordering is sound (I traced
decoder_only.cpp):input_ids_.Update()→StaticScatterIndexTracker::Advance(N)runs beforeState::Run, so prefill bindswrite=0, nonpad=N— the ctor's0/0default is dead, and the feared one-step lag is not present. - Index semantics agree with mobius:
write=valid_before,nonpad=valid_after, satisfying mobius'snonpad == write + S_qfor unpadded chunks. The tracker unit tests pin the off-by-one contract well. RewindTofail-loud throw, factory predicate in lockstep with the producer gate (both inputs required), per-layer-varyingkv_hidden, and batch==1 enforced on both sides — all consistent.
Major — factory hard-fails instead of falling back
IsStaticScatterCache() selects on presence of both driver inputs before the other cache types in CreateKeyValueCache. A model that declares those inputs but uses 4D KV would hit the constructor's rank != 3 throw rather than falling back to Default/Windowed. Name collision is unlikely (low real risk), but consider tightening the predicate (require the rank-3 static layout) or making the factory attempt-and-fall-back. (kv_cache.cpp:711-719, 1069-1076)
Minor
- Layer-index parse:
std::stoi(idx_str)accepts malformed names ("0.bad"→ 0) and doesn't reject duplicates. Usestd::from_chars/stoiwithpos == sizeand validate non-negative + unique. (kv_cache.cpp:740-746) - Test runtime probe:
StaticScatterRuntimeAvailable()skips on any exception message containing"TensorScatter". A real TensorScatter regression could be misclassified as "kernel absent" and silently skip parity. Match the precise missing-kernel signature ("Could not find an implementation"+"TensorScatter(24)") and rethrow others. (test/static_scatter_kv_cache_test.cpp) - E2E parity strength: the fixture asserts
argmax+sum(1e-2 tol).sumcan pass wrong-but-sum-preserving outputs (permutation/cancellation). Add full-tensor or slice-wise parity against the golden.
Open question
- Decode token equal to
pad_token_id→Advance(0)leaveswrite_index/nonpadunchanged; two such steps alias a cache slot. Mirrors the existingcurrent_sequence_lengthlogic and is benign ifpad==eos. Confirmpad_token_idcan't be a legal mid-stream generated token for the targeted models.
Scope clarity (not a defect)
- batch==1 is an enforced, documented MVP constraint (driver tensors allocated
[1], throws onBatchBeamSize()!=1), even though the mobius graph + tests support[batch]. Intentional on both sides — worth tracking as a known limitation, not a break. - KV name mapping: this consumer reads names from
genai_config.json, while mobius emitskey_cache.{i}/updated_key_cache.{i}. The fixture maps them manually; production needs the genai_config generator to emit matching names.
Nits
input_index_/output_index_/layer_shapes_are set inAdd()but never read back — annotate clearly or remove.kv_cache.hcomment "differs only in the allocator" erases the throw-vs-silentRewindTosplit vsDefaultKeyValueCache; the struct fieldwrite_index(singular) vs the tensorwrite_indices(plural) is a small cross-boundary translation.
Praise: static_scatter_indices.h is a textbook extraction of a tricky off-by-one contract into a minimal, unit-testable type — the "crux mobius and genai must agree on" framing is exactly right. The two-tier test layout (pure-C++ tracker tests on every EP + fixture-backed e2e guarded by #if !USE_DML + runtime probe) is thoughtfully done.
Triple-review finalize of the static-scatter (TensorScatter + nonpad_kv_seqlen)
KV cache review fixes:
- Strict KV layer-index parse: factor the discovery loop into a standalone,
model-free DiscoverKvLayerIndices() helper (static_scatter_indices.h) that
uses std::from_chars to require the index segment be a complete non-negative
integer (std::stoi silently accepted trailing junk like "0.bad" -> 0) and
rejects duplicate layer indices. Add targeted unit tests (well-formed sort,
throw-on-malformed, throw-on-duplicate).
- Strengthen EndToEndFixtureParity from argmax+sum to full-tensor element-wise
parity (ExpectAllClose @ 1e-3) against a generated golden header captured from
the fixture's authoritative ORT CPU MEA run; the header documents its
provenance and regeneration recipe and is marked do-not-edit-by-hand.
- Tighten StaticScatterRuntimeAvailable() to skip only on the precise
missing-kernel signature ("Could not find an implementation for" AND
"TensorScatter(24)") and rethrow anything else, so a real TensorScatter
regression is not silently skipped.
- Remove dead StaticScatter fields (input_index_/output_index_/layer_shapes_)
and correct the kv_cache.h comment to document the 3D-vs-4D layout and the
RewindTo throws-vs-reshape split (not "differ only in the allocator").
- Document the pad_token aliasing assumption in StaticScatterIndexTracker
(pad_token_id == eos_token_id ends the sequence before the stalled slot is
reread) and record the factory rank-3 tightening as a tracked follow-up.
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: titaiwang <titaiwang@microsoft.com>
dc4326f to
72e7199
Compare
titaiwangms
left a comment
There was a problem hiding this comment.
Re-review at 72e7199 — prior findings addressed; one deferred with rationale ✅
Re-checked the new commit "Address review feedback on static-scatter KV cache". I independently verified the new golden by running the committed fixture model.onnx fresh on ORT 1.27 CPU: it matches golden_io.npz (≤1e-4), the scalar sum/argmax constants match, and the C++ header arrays (static_scatter_golden.h) match the fresh run exactly (maxdiff = 0.0) for all four tensors (256-elem logits ×2, 512-elem cache ×2).
| Prior finding | Status |
|---|---|
Minor — std::stoi accepted malformed layer indices ("0.bad" → 0), no dedup |
✅ Factored into a model-free DiscoverKvLayerIndices() using std::from_chars with a complete-consume check (parse_end == end), non-negative guard, and duplicate rejection. New unit tests cover well-formed/sort, throw-on-malformed, throw-on-duplicate. |
Minor — runtime probe skipped on any "TensorScatter" substring (could mask a real regression) |
✅ Tightened to require both "Could not find an implementation for" and "TensorScatter(24)"; rethrows anything else. |
Minor — e2e parity used only argmax + sum (passes wrong-but-sum-preserving outputs) |
✅ Added full-tensor ExpectAllClose @ 1e-3 against a generated, provenance-documented golden header (kept the sum/argmax as headline checks). Verified the golden is authentic and reproducible. |
Nit — dead input_index_ / output_index_ / layer_shapes_ |
✅ Removed. |
| Nit — "differs only in the allocator" erased the RewindTo throw-vs-reshape split | ✅ Comment rewritten to document both the 3D-vs-4D layout and the RewindTo difference. |
Open question — decode token == pad_token_id → Advance(0) slot aliasing |
✅ Documented the assumption in StaticScatterIndexTracker::Advance (safe because pad_token_id == eos_token_id ends the sequence before the stalled slot is reread; would only break if a model made pad a legal mid-stream token). |
Major — factory hard-fails (rank != 3 throw) instead of falling back to Default/Windowed |
⏳ Deferred as a tracked follow-up with an in-code rationale (the rank-3 tightening touches the sensitive factory-selection path and risks layer-discovery drift for sparse/hybrid models lacking a layer-0 KV input; collision considered unlikely). Reasonable disposition — acceptable to land now and follow up. |
DiscoverKvLayerIndices logic reviewed by reading: the size() > prefix+suffix guard keeps idx_str non-empty, and from_chars/negative/duplicate checks are correct. I did not rebuild the genai C++ (expensive) but verified the golden provenance directly, which is the strongest available check. LGTM with the factory fallback tracked separately.
What
Adds
StaticScatterKeyValueCache, a new KV-cache variant that lets onnxruntime-genaidrive a mobius-exported bias-aware external-KV static-cache decoder directly. It:
past_present_share_bufferlifecycle: eachkey_cache.{i}/value_cache.{i}is a pre-allocated 3D[batch, max_seq_len, kv_hidden]buffer thatthe graph updates in place via opset-24
TensorScatter, so past and present aliasone
OrtValueand never need rebinding between steps.[B]int64write_indices/nonpad_kv_seqlenproducer ininput_ids(driven by a standaloneStaticScatterIndexTracker):write_indicesis thescatter row offset (valid tokens before this step),
nonpad_kv_seqlenis the validcached length after this step (Attention's per-batch
seqlens_k).CreateKeyValueCacheselects this variant when the modeldeclares both
write_indicesandnonpad_kv_seqleninputs (kept in lockstep withthe producer gate so a half-declared model can't create the cache with unbound indices).
[batch, num_kv_heads, seq, head_dim]default layout), and supports per-layer-varyingkv_hidden(e.g. Gemma-4 sliding GQA vs global MQA), read from each layer's owndeclared input shape. The batch dim is taken from
BatchBeamSize()(symbolic in thegraph); only
max_seq_len/kv_hiddenmust be static.Why
genai's existing GQA
past_present_share_bufferpath cannot carry an arbitrary floatattention bias. mobius now emits float-bias decoders (causal + sliding-window built
in-graph) whose KV is an external
TensorScattercache — that path can carry the bias.This PR is the host-side consumer for that graph, recovering near-GQA decode behaviour
on the memory-efficient-attention (MEA) path for bias-aware models like Gemma-4.
Relationship to onnxruntime-genai#2204
#2204 proposed Path A = a prefill/decode GQA graph-split. This PR implements Path B
= a single
TensorScatterexternal-KV graph, which sidesteps the cross-stage KV handoffthat #2204 flagged as the crux. Referenced as related; this PR does not close #2204.
Mobius side
The bias-aware external-KV graph this cache consumes is emitted by mobius
(onnxruntime/mobius#367, issue onnxruntime/mobius#366). This PR is the genai-side consumer.
Testing
StaticScatterKeyValueCache.EndToEndFixtureParity— bit-exact vs an ORT-CPU-MEA goldenon a committed slice-A fixture (prefill seq=4 @ write 0/nonpad 4, then decode @ write
4/nonpad 5): logits argmax/sum +
updated_key_cache.0sums.StaticScatterKeyValueCache.RewindToThrows— deny-by-default guard:RewindTo()throws(the
write_indices/nonpad_kv_seqlenstream has no rewind hook, so a silent no-op woulddesync the tracker; matches the
LFM2Cache/WindowedKeyValueCachesiblings).StaticScatterIndexTrackerproducer tests pinning the off-by-one / init contract.Scope / limits
batch == 1(beam search unsupported, throws).kv_hiddensupported (Gemma-4 sliding/global mix).test/models/per the existing test-modelconvention, tracked via a
.gitignoreallowlist entry matching its siblings.