Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ examples/csharp/ModelChat/models
!test/models/qwen2-5-vl/*
!test/models/qwen3-5/*
!test/models/qwen3-vl/*
!test/models/static-scatter-bias-decoder/*
!test/models/whisper/*

.ipynb_checkpoints/
Expand Down
4 changes: 4 additions & 0 deletions src/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,10 @@ struct DecoderInputs_Element : JSON::Element {
v_.current_sequence_length = JSON::Get<std::string_view>(value);
} else if (name == "total_sequence_length") {
v_.total_sequence_length = JSON::Get<std::string_view>(value);
} else if (name == "write_indices") {
v_.write_indices = JSON::Get<std::string_view>(value);
} else if (name == "nonpad_kv_seqlen") {
v_.nonpad_kv_seqlen = JSON::Get<std::string_view>(value);
} else if (name == "encoder_hidden_states") {
v_.encoder_hidden_states = JSON::Get<std::string_view>(value);
} else if (name == "encoder_attention_mask") {
Expand Down
8 changes: 8 additions & 0 deletions src/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ struct Config {
static constexpr std::string_view PastSequenceLengthName = "past_sequence_length";
static constexpr std::string_view CurrentSequenceLengthName = "current_sequence_length";
static constexpr std::string_view TotalSequenceLengthName = "total_sequence_length";
static constexpr std::string_view WriteIndicesName = "write_indices";
static constexpr std::string_view NonpadKvSeqlenName = "nonpad_kv_seqlen";
static constexpr std::string_view CacheIndirectionName = "cache_indirection";
static constexpr std::string_view AlignmentHeadsName = "alignment_heads";
static constexpr std::string_view TokenTypeIdsName = "token_type_ids";
Expand Down Expand Up @@ -344,6 +346,12 @@ struct Config {
std::string past_sequence_length{Defaults::PastSequenceLengthName};
std::string current_sequence_length{Defaults::CurrentSequenceLengthName};
std::string total_sequence_length{Defaults::TotalSequenceLengthName};
// Static-scatter (TensorScatter) KV cache driver inputs ([batch] int64).
// Both write_indices AND nonpad_kv_seqlen must be present to select
// StaticScatterKeyValueCache: IsStaticScatterCache() (the factory
// predicate) and the DefaultInputIDs producer both require both inputs.
std::string write_indices{Defaults::WriteIndicesName};
std::string nonpad_kv_seqlen{Defaults::NonpadKvSeqlenName};
std::string cache_indirection{Defaults::CacheIndirectionName};
std::string encoder_hidden_states{Defaults::EncoderHiddenStatesName};
std::string rnn_prev_states{Defaults::RnnStatesPrevName};
Expand Down
33 changes: 33 additions & 0 deletions src/models/input_ids.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,22 @@ DefaultInputIDs::DefaultInputIDs(State& state)
*past_sequence_length_->GetTensorMutableData<int32_t>() = -1;
}

if (model_.session_info_.HasInput(model_.config_->model.decoder.inputs.write_indices) &&
model_.session_info_.HasInput(model_.config_->model.decoder.inputs.nonpad_kv_seqlen)) {
if (state_.params_->BatchBeamSize() != 1) {
throw std::runtime_error("Batch beam size (batch_size * num_beams) must be 1 for write_indices and nonpad_kv_seqlen inputs");
}
if (model_.session_info_.GetInputDataType(model_.config_->model.decoder.inputs.write_indices) != Ort::TypeToTensorType<int64_t> ||
model_.session_info_.GetInputDataType(model_.config_->model.decoder.inputs.nonpad_kv_seqlen) != Ort::TypeToTensorType<int64_t>)
throw std::runtime_error("write_indices and nonpad_kv_seqlen must be int64");

const std::array<int64_t, 1> static_scatter_shape{1};
write_indices_ = OrtValue::CreateTensor(model_.allocator_cpu_, static_scatter_shape, Ort::TypeToTensorType<int64_t>);
nonpad_kv_seqlen_ = OrtValue::CreateTensor(model_.allocator_cpu_, static_scatter_shape, Ort::TypeToTensorType<int64_t>);
*write_indices_->GetTensorMutableData<int64_t>() = 0;
*nonpad_kv_seqlen_->GetTensorMutableData<int64_t>() = 0;
}

value_ = std::make_unique<Tensor>(model_.p_device_inputs_, Ort::TypeToTensorType<int32_t>);
cast_value_ = std::make_unique<Tensor>(model_.p_device_inputs_, Ort::TypeToTensorType<int64_t>);
}
Expand All @@ -45,6 +61,13 @@ void DefaultInputIDs::Add() {
state_.input_names_.push_back(model_.config_->model.decoder.inputs.past_sequence_length.c_str());
state_.inputs_.push_back(past_sequence_length_.get());
}

if (write_indices_ && nonpad_kv_seqlen_) {
state_.input_names_.push_back(model_.config_->model.decoder.inputs.write_indices.c_str());
state_.inputs_.push_back(write_indices_.get());
state_.input_names_.push_back(model_.config_->model.decoder.inputs.nonpad_kv_seqlen.c_str());
state_.inputs_.push_back(nonpad_kv_seqlen_.get());
}
}

void DefaultInputIDs::Update(DeviceSpan<int32_t> new_tokens) {
Expand All @@ -67,6 +90,16 @@ void DefaultInputIDs::Update(DeviceSpan<int32_t> new_tokens) {
*past_sequence_length_->GetTensorMutableData<int32_t>() += new_sequence_length;
}

if (write_indices_ && nonpad_kv_seqlen_) {
if (state_.params_->BatchBeamSize() != 1) {
throw std::runtime_error("Batch beam size (batch_size * num_beams) must be 1 for write_indices and nonpad_kv_seqlen inputs");
}
auto new_sequence_length = get_unpadded_sequence_length(new_tokens_cpu, model_.config_->model.pad_token_id);
const StaticScatterIndices indices = static_scatter_indices_.Advance(new_sequence_length);
*write_indices_->GetTensorMutableData<int64_t>() = indices.write_index;
*nonpad_kv_seqlen_->GetTensorMutableData<int64_t>() = indices.nonpad_kv_seqlen;
}

// For beam search, resize input_ids shape based on new_tokens
size_t sequence_length = static_cast<size_t>(new_tokens.size()) / state_.params_->BatchBeamSize();
if (is_prompt_ && state_.params_->search.num_beams > 1)
Expand Down
9 changes: 9 additions & 0 deletions src/models/input_ids.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

#include "static_scatter_indices.h"

namespace Generators {

struct InputIDs {
Expand Down Expand Up @@ -41,6 +43,13 @@ struct DefaultInputIDs : InputIDs {

std::unique_ptr<OrtValue> current_sequence_length_;
std::unique_ptr<OrtValue> past_sequence_length_;

// Static-scatter (TensorScatter) KV cache driver inputs, created only when the
// model declares write_indices + nonpad_kv_seqlen. Both are [batch] int64 CPU
// tensors; their per-step values come from static_scatter_indices_.
std::unique_ptr<OrtValue> write_indices_;
std::unique_ptr<OrtValue> nonpad_kv_seqlen_;
StaticScatterIndexTracker static_scatter_indices_;
};

// Certain models can only process a fixed number of tokens at a time.
Expand Down
134 changes: 134 additions & 0 deletions src/models/kv_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "../generators.h"
#include "model.h"
#include "kv_cache.h"
#include "static_scatter_indices.h"
#include "windowed_kv_cache.h"
#include "../openvino/interface.h"
#include "../qnn/interface.h"
Expand Down Expand Up @@ -708,6 +709,130 @@ void ModelManagedKeyValueCache::RewindTo(size_t index) {
state_.ep_dynamic_options_next_run_.push_back({"kvcache_rewind", std::to_string(index)});
}

bool StaticScatterKeyValueCache::IsStaticScatterCache(const Model& model) {
// Both driver inputs must be present. input_ids.cpp only produces the indices
// when it sees write_indices AND nonpad_kv_seqlen, so requiring just one here
// would create the cache for a model whose indices never get bound, surfacing
// as an obscure unbound-input error at Run. Keep this predicate in lockstep
// with the producer gate in input_ids.cpp.
//
// NAME-COLLISION ASSUMPTION (rank-3 tightening = TRACKED FOLLOW-UP, not done
// here): selection is purely on the presence of both driver inputs, BEFORE the
// rank-3 static layout is checked. A model that declares both write_indices and
// nonpad_kv_seqlen but uses a 4D KV layout would be routed here and then hit the
// ctor's `rank != 3` throw, rather than falling back to Default/Windowed.
//
// A proper fix would factor KV-input discovery into a shared helper (the index
// parse already lives in DiscoverKvLayerIndices(); the missing piece is locating
// a KV input name and reading its rank at selection time) so that this predicate
// returns false on a rank-4 layout and the factory falls back to Default/
// Windowed. That is DEFERRED for this PR: it touches the sensitive factory-
// selection path and risks layer-discovery drift for sparse/hybrid models that
// may lack a layer-0 KV input. The collision is considered unlikely (the two
// driver inputs are specific to this layout), so we accept the ctor throw over a
// risky attempt-and-fall-back restructuring until the follow-up lands.
return model.session_info_.HasInput(model.config_->model.decoder.inputs.write_indices) &&
model.session_info_.HasInput(model.config_->model.decoder.inputs.nonpad_kv_seqlen);
}

StaticScatterKeyValueCache::StaticScatterKeyValueCache(State& state)
: state_{state},
layer_count_{model_.config_->model.decoder.num_hidden_layers} {
if (state_.params_->search.num_beams != 1) {
throw std::runtime_error("Beam search (num_beams > 1) is not supported by the static-scatter KV cache.");
}
Comment thread
titaiwangms marked this conversation as resolved.
// The index tracker advances a single slot per step and cannot be forked
// across batch or beam dimensions; keep this aligned with the DefaultInputIDs
// producer gate, which also requires BatchBeamSize()==1.
if (state_.params_->BatchBeamSize() != 1) {
throw std::runtime_error("The static-scatter KV cache requires batch beam size (batch_size * num_beams) == 1.");
}

// Auto-discover which layer indices have KV cache inputs (mirrors
// DefaultKeyValueCache so sparse/hybrid layouts work the same way). The strict
// parse / dedup lives in DiscoverKvLayerIndices (static_scatter_indices.h) so
// it can be unit-tested without standing up a Model.
{
const auto& key_template = model_.config_->model.decoder.inputs.past_key_names;
auto prefix = key_template.substr(0, key_template.find('%'));
auto suffix = key_template.substr(key_template.find('%') + 2);
kv_layer_indices_ = DiscoverKvLayerIndices(model_.session_info_.GetInputNames(), prefix, suffix);
}

if (!kv_layer_indices_.empty()) {
layer_count_ = static_cast<int>(kv_layer_indices_.size());
}

for (int i = 0; i < layer_count_; ++i) {
int layer_idx = kv_layer_indices_.empty() ? i : kv_layer_indices_[i];
input_name_strings_.emplace_back(ComposeKeyValueName(model_.config_->model.decoder.inputs.past_key_names, layer_idx));
input_name_strings_.emplace_back(ComposeKeyValueName(model_.config_->model.decoder.inputs.past_value_names, layer_idx));
output_name_strings_.emplace_back(ComposeKeyValueName(model_.config_->model.decoder.outputs.present_key_names, layer_idx));
output_name_strings_.emplace_back(ComposeKeyValueName(model_.config_->model.decoder.outputs.present_value_names, layer_idx));
}

type_ = model_.session_info_.GetInputDataType(input_name_strings_[0]);

// Each KV input declares shape [batch, max_seq_len, kv_hidden]. The batch dim
// is a runtime property (symbolic in the graph), so take it from the params
// like DefaultKeyValueCache; only max_seq_len and kv_hidden must be static.
// kv_hidden = num_kv_heads * head_dim and may vary per layer, so read each
// layer's own declared shape rather than assuming a uniform value.
const int64_t batch_size = state_.params_->BatchBeamSize();
caches_.reserve(layer_count_ * 2);
for (int i = 0; i < layer_count_ * 2; ++i) {
const auto input_shape = model_.session_info_.GetInputShape(input_name_strings_[i]);
if (input_shape.size() != 3) {
throw std::runtime_error(
"StaticScatterKeyValueCache expects 3D [batch, max_seq_len, kv_hidden] KV inputs, but '" +
input_name_strings_[i] + "' has rank " + std::to_string(input_shape.size()) + ".");
}
// max_seq_len (axis 1) and kv_hidden (axis 2) size the fixed buffer and must
// be concrete; the batch dim (axis 0) is allowed to be symbolic.
for (size_t axis = 1; axis < 3; ++axis) {
if (input_shape[axis] <= 0) {
throw std::runtime_error(
"StaticScatterKeyValueCache requires a static max_seq_len and kv_hidden, but '" +
input_name_strings_[i] + "' has a non-concrete dim at axis " + std::to_string(axis) + ".");
}
}
std::array<int64_t, 3> tensor_shape{batch_size, input_shape[1], input_shape[2]};

caches_.push_back(OrtValue::CreateTensor(Allocator(), tensor_shape, type_));
if (Device().GetType() != DeviceType::WEBGPU) {
ByteWrapTensor(Device(), *caches_.back()).Zero();
}
}
}

void StaticScatterKeyValueCache::Add() {
// Past and present share one buffer: TensorScatter writes new rows in place,
// so key_cache.{i} (input) and updated_key_cache.{i} (output) point at the
// same OrtValue and never need rebinding between steps.
for (int i = 0; i < layer_count_ * 2; ++i) {
state_.inputs_.push_back(caches_[i].get());
state_.input_names_.push_back(input_name_strings_[i].c_str());
state_.outputs_.push_back(caches_[i].get());
state_.output_names_.push_back(output_name_strings_[i].c_str());
}
}

void StaticScatterKeyValueCache::Update(DeviceSpan<int32_t> /*beam_indices*/, int /*total_length*/) {
// No-op: the shared buffer is updated in place by the graph's TensorScatter,
// and the write offset / valid length are carried by the write_indices /
// nonpad_kv_seqlen inputs (see input_ids.cpp), not by rebinding tensors here.
}

void StaticScatterKeyValueCache::RewindTo(size_t /*index*/) {
// Fail loud: rewind is NOT wired for the static-scatter cache. The
// write_indices/nonpad_kv_seqlen stream lives in InputIDs and has no RewindTo
// hook, so a silent no-op here would leave the index tracker stale -> wrong
// scatter slots and an over-reported nonpad => silently wrong logits with no
// error. Throw until rewind is properly wired, matching the LFM2Cache and
// WindowedKeyValueCache siblings.
throw std::runtime_error("StaticScatterKeyValueCache does not support RewindTo.");
}

LFM2Cache::LFM2Cache(State& state)
: state_{state},
layer_types_{model_.config_->model.decoder.layer_types},
Expand Down Expand Up @@ -947,6 +1072,15 @@ std::unique_ptr<KeyValueCache> CreateKeyValueCache(State& state) {
return nullptr;
}

// mobius static-cache decoders drive an in-place TensorScatter KV buffer via
// the write_indices input; auto-detect that (no user-visible search flag,
// mirroring DetectAndConfigureFixedKvShape) before the default fallback.
if (StaticScatterKeyValueCache::IsStaticScatterCache(state.model_)) {
if (g_log.enabled)
Log("info", "CreateKeyValueCache: Creating StaticScatterKeyValueCache");
return std::make_unique<StaticScatterKeyValueCache>(state);
}

if (state.model_.p_device_->GetType() != DeviceType::NvTensorRtRtx &&
state.model_.config_->model.decoder.sliding_window &&
state.model_.config_->model.decoder.sliding_window->slide_key_value_cache) {
Expand Down
48 changes: 48 additions & 0 deletions src/models/kv_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,54 @@ struct LFM2Cache : KeyValueCache {

std::string ComposeKeyValueName(const std::string& template_string, int index);

// A static-scatter KV cache for mobius-exported static-cache decoders.
//
// The model pre-allocates each layer's KV as a fixed 3D buffer
// [batch, max_seq_len, kv_hidden] and writes new rows in place via opset-24
// TensorScatter (driven by the write_indices / nonpad_kv_seqlen inputs produced
// in input_ids.cpp), reading them back through Attention. Because the scatter is
// in place, past and present share one buffer: Add() binds key_cache.{i} and
// updated_key_cache.{i} to the same OrtValue, and Update() is a no-op (mirroring
// DefaultKeyValueCache's past_present_share_buffer path). RewindTo() is NOT
// supported and throws: the write_indices/nonpad_kv_seqlen index stream lives in
// InputIDs with no rewind hook, so rewinding would silently desynchronize it.
//
// Distinct from DefaultKeyValueCache in two ways. (1) Layout: it consumes
// mobius's 3D emission directly (vs the 4D [batch, num_kv_heads, seq, head_dim]
// layout), and kv_hidden may vary per layer (e.g. Gemma-4 sliding GQA 8*256 vs
// global MQA 1*512), read from each layer's own declared input shape.
// (2) RewindTo: DefaultKeyValueCache rewinds by reshaping its buffers, whereas
// this cache THROWS (rewind is unsupported), because the write_indices /
// nonpad_kv_seqlen index stream lives in InputIDs with no rewind hook and a
// silent no-op would desynchronize it.
struct StaticScatterKeyValueCache : KeyValueCache {
StaticScatterKeyValueCache(State& state);

// True if the model declares BOTH static-scatter driver inputs (write_indices
// and nonpad_kv_seqlen); kept in lockstep with the producer gate in input_ids.
static bool IsStaticScatterCache(const Model& model);

void Add() override;
void Update(DeviceSpan<int32_t> beam_indices, int total_length) override;
void RewindTo(size_t index) override;

private:
DeviceInterface& Device() { return *model_.p_device_kvcache_; }
Ort::Allocator& Allocator() { return model_.p_device_kvcache_->GetAllocator(); }

State& state_;
const Model& model_{state_.model_};
int layer_count_;

// Auto-discovered KV layer indices (sparse for hybrid models).
std::vector<int> kv_layer_indices_;
ONNXTensorElementDataType type_;

// One shared past/present buffer per key and per value tensor (2 per layer).
std::vector<std::unique_ptr<OrtValue>> caches_;
std::vector<std::string> input_name_strings_, output_name_strings_;
};

std::unique_ptr<KeyValueCache> CreateKeyValueCache(State& state);

} // namespace Generators
Loading
Loading