Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ examples/csharp/ModelChat/models
!test/models/qwen2-5-vl/*
!test/models/qwen3-5/*
!test/models/qwen3-vl/*
!test/models/static-scatter-bias-decoder/*
!test/models/whisper/*

.ipynb_checkpoints/
Expand Down
4 changes: 4 additions & 0 deletions src/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,10 @@ struct DecoderInputs_Element : JSON::Element {
v_.current_sequence_length = JSON::Get<std::string_view>(value);
} else if (name == "total_sequence_length") {
v_.total_sequence_length = JSON::Get<std::string_view>(value);
} else if (name == "write_indices") {
v_.write_indices = JSON::Get<std::string_view>(value);
} else if (name == "nonpad_kv_seqlen") {
v_.nonpad_kv_seqlen = JSON::Get<std::string_view>(value);
} else if (name == "encoder_hidden_states") {
v_.encoder_hidden_states = JSON::Get<std::string_view>(value);
} else if (name == "encoder_attention_mask") {
Expand Down
6 changes: 6 additions & 0 deletions src/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ struct Config {
static constexpr std::string_view PastSequenceLengthName = "past_sequence_length";
static constexpr std::string_view CurrentSequenceLengthName = "current_sequence_length";
static constexpr std::string_view TotalSequenceLengthName = "total_sequence_length";
static constexpr std::string_view WriteIndicesName = "write_indices";
static constexpr std::string_view NonpadKvSeqlenName = "nonpad_kv_seqlen";
static constexpr std::string_view CacheIndirectionName = "cache_indirection";
static constexpr std::string_view AlignmentHeadsName = "alignment_heads";
static constexpr std::string_view TokenTypeIdsName = "token_type_ids";
Expand Down Expand Up @@ -344,6 +346,10 @@ struct Config {
std::string past_sequence_length{Defaults::PastSequenceLengthName};
std::string current_sequence_length{Defaults::CurrentSequenceLengthName};
std::string total_sequence_length{Defaults::TotalSequenceLengthName};
// Static-scatter (TensorScatter) KV cache driver inputs ([batch] int64).
// Presence of write_indices selects StaticScatterKeyValueCache.
std::string write_indices{Defaults::WriteIndicesName};
std::string nonpad_kv_seqlen{Defaults::NonpadKvSeqlenName};
Comment thread
titaiwangms marked this conversation as resolved.
std::string cache_indirection{Defaults::CacheIndirectionName};
std::string encoder_hidden_states{Defaults::EncoderHiddenStatesName};
std::string rnn_prev_states{Defaults::RnnStatesPrevName};
Expand Down
33 changes: 33 additions & 0 deletions src/models/input_ids.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,22 @@ DefaultInputIDs::DefaultInputIDs(State& state)
*past_sequence_length_->GetTensorMutableData<int32_t>() = -1;
}

if (model_.session_info_.HasInput(model_.config_->model.decoder.inputs.write_indices) &&
model_.session_info_.HasInput(model_.config_->model.decoder.inputs.nonpad_kv_seqlen)) {
if (state_.params_->BatchBeamSize() != 1) {
throw std::runtime_error("Batch size must be 1 for write_indices and nonpad_kv_seqlen inputs");
Comment thread
titaiwangms marked this conversation as resolved.
Outdated
}
if (model_.session_info_.GetInputDataType(model_.config_->model.decoder.inputs.write_indices) != Ort::TypeToTensorType<int64_t> ||
model_.session_info_.GetInputDataType(model_.config_->model.decoder.inputs.nonpad_kv_seqlen) != Ort::TypeToTensorType<int64_t>)
throw std::runtime_error("write_indices and nonpad_kv_seqlen must be int64");

const std::array<int64_t, 1> static_scatter_shape{1};
write_indices_ = OrtValue::CreateTensor(model_.allocator_cpu_, static_scatter_shape, Ort::TypeToTensorType<int64_t>);
nonpad_kv_seqlen_ = OrtValue::CreateTensor(model_.allocator_cpu_, static_scatter_shape, Ort::TypeToTensorType<int64_t>);
*write_indices_->GetTensorMutableData<int64_t>() = 0;
*nonpad_kv_seqlen_->GetTensorMutableData<int64_t>() = 0;
}

value_ = std::make_unique<Tensor>(model_.p_device_inputs_, Ort::TypeToTensorType<int32_t>);
cast_value_ = std::make_unique<Tensor>(model_.p_device_inputs_, Ort::TypeToTensorType<int64_t>);
}
Expand All @@ -45,6 +61,13 @@ void DefaultInputIDs::Add() {
state_.input_names_.push_back(model_.config_->model.decoder.inputs.past_sequence_length.c_str());
state_.inputs_.push_back(past_sequence_length_.get());
}

if (write_indices_ && nonpad_kv_seqlen_) {
state_.input_names_.push_back(model_.config_->model.decoder.inputs.write_indices.c_str());
state_.inputs_.push_back(write_indices_.get());
state_.input_names_.push_back(model_.config_->model.decoder.inputs.nonpad_kv_seqlen.c_str());
state_.inputs_.push_back(nonpad_kv_seqlen_.get());
}
}

void DefaultInputIDs::Update(DeviceSpan<int32_t> new_tokens) {
Expand All @@ -67,6 +90,16 @@ void DefaultInputIDs::Update(DeviceSpan<int32_t> new_tokens) {
*past_sequence_length_->GetTensorMutableData<int32_t>() += new_sequence_length;
}

if (write_indices_ && nonpad_kv_seqlen_) {
if (state_.params_->BatchBeamSize() != 1) {
throw std::runtime_error("Batch size must be 1 for write_indices and nonpad_kv_seqlen inputs");
Comment thread
titaiwangms marked this conversation as resolved.
Outdated
}
auto new_sequence_length = get_unpadded_sequence_length(new_tokens_cpu, model_.config_->model.pad_token_id);
const StaticScatterIndices indices = static_scatter_indices_.Advance(new_sequence_length);
*write_indices_->GetTensorMutableData<int64_t>() = indices.write_index;
*nonpad_kv_seqlen_->GetTensorMutableData<int64_t>() = indices.nonpad_kv_seqlen;
}

// For beam search, resize input_ids shape based on new_tokens
size_t sequence_length = static_cast<size_t>(new_tokens.size()) / state_.params_->BatchBeamSize();
if (is_prompt_ && state_.params_->search.num_beams > 1)
Expand Down
9 changes: 9 additions & 0 deletions src/models/input_ids.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

#include "static_scatter_indices.h"

namespace Generators {

struct InputIDs {
Expand Down Expand Up @@ -41,6 +43,13 @@ struct DefaultInputIDs : InputIDs {

std::unique_ptr<OrtValue> current_sequence_length_;
std::unique_ptr<OrtValue> past_sequence_length_;

// Static-scatter (TensorScatter) KV cache driver inputs, created only when the
// model declares write_indices + nonpad_kv_seqlen. Both are [batch] int64 CPU
// tensors; their per-step values come from static_scatter_indices_.
std::unique_ptr<OrtValue> write_indices_;
std::unique_ptr<OrtValue> nonpad_kv_seqlen_;
StaticScatterIndexTracker static_scatter_indices_;
};

// Certain models can only process a fixed number of tokens at a time.
Expand Down
122 changes: 122 additions & 0 deletions src/models/kv_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,119 @@ void ModelManagedKeyValueCache::RewindTo(size_t index) {
state_.ep_dynamic_options_next_run_.push_back({"kvcache_rewind", std::to_string(index)});
}

// Returns true only when the session declares BOTH static-scatter driver
// inputs (write_indices and nonpad_kv_seqlen), using the names configured for
// the decoder.
//
// The index values are only produced (in input_ids.cpp) when both inputs are
// present, so accepting a model that declares just one of them would build a
// cache whose driver inputs are never bound. Keep this predicate identical to
// that producer-side check.
bool StaticScatterKeyValueCache::IsStaticScatterCache(const Model& model) {
  const auto& decoder_inputs = model.config_->model.decoder.inputs;
  const auto& session = model.session_info_;

  if (!session.HasInput(decoder_inputs.write_indices)) {
    return false;
  }
  return session.HasInput(decoder_inputs.nonpad_kv_seqlen);
}

// Builds the static-scatter KV cache for `state`.
//
// Steps:
//   1. Reject beam search (num_beams != 1).
//   2. Discover which layer indices have a past-key input by matching the
//      session's input names against the configured past_key_names template.
//      If nothing is discovered, fall back to layers [0, num_hidden_layers).
//   3. Compose the past/present key and value tensor names for every layer.
//   4. Allocate one buffer per key and per value tensor, shaped
//      [BatchBeamSize, dim1, dim2] where dim1/dim2 are read from that input's
//      own declared shape.
//
// Throws std::runtime_error when beam search is requested, when no KV layer can
// be determined, when a KV input is not rank 3, or when axis 1 or 2 of a KV
// input is not a concrete positive size.
StaticScatterKeyValueCache::StaticScatterKeyValueCache(State& state)
    : state_{state},
      layer_count_{model_.config_->model.decoder.num_hidden_layers} {
  if (state_.params_->search.num_beams != 1) {
    throw std::runtime_error("Beam search (num_beams > 1) is not supported by the static-scatter KV cache.");
  }

  // Auto-discover which layer indices have KV cache inputs so that sparse or
  // hybrid layouts (where only some layers carry a KV cache) are supported.
  {
    const auto& key_template = model_.config_->model.decoder.inputs.past_key_names;
    const auto placeholder = key_template.find('%');
    // The template is expected to hold a two-character placeholder (e.g. "%d").
    // Without this check a missing '%' makes `placeholder + 2` wrap around and
    // a trailing '%' makes substr() throw std::out_of_range; in both cases skip
    // discovery and use the num_hidden_layers fallback below.
    if (placeholder != std::string::npos && placeholder + 2 <= key_template.size()) {
      const auto prefix = key_template.substr(0, placeholder);
      const auto suffix = key_template.substr(placeholder + 2);
      for (const auto& name : model_.session_info_.GetInputNames()) {
        const bool matches_template =
            name.size() > prefix.size() + suffix.size() &&
            name.compare(0, prefix.size(), prefix) == 0 &&
            name.compare(name.size() - suffix.size(), suffix.size(), suffix) == 0;
        if (!matches_template) {
          continue;
        }
        const auto idx_str = name.substr(prefix.size(), name.size() - prefix.size() - suffix.size());
        // Accept only a pure decimal layer index. std::stoi would throw on
        // "abc" and would silently parse "0.extra" as layer 0, producing a
        // bogus or duplicated layer entry.
        const bool all_digits = std::all_of(idx_str.begin(), idx_str.end(),
                                            [](char c) { return c >= '0' && c <= '9'; });
        if (!all_digits) {
          continue;
        }
        kv_layer_indices_.push_back(std::stoi(idx_str));
      }
      std::sort(kv_layer_indices_.begin(), kv_layer_indices_.end());
    }
  }

  if (!kv_layer_indices_.empty()) {
    layer_count_ = static_cast<int>(kv_layer_indices_.size());
  }

  // input_name_strings_[0] is read below; without at least one layer that would
  // be an out-of-bounds access, so fail with a clear message instead.
  if (layer_count_ <= 0) {
    throw std::runtime_error("StaticScatterKeyValueCache found no KV cache layers to allocate.");
  }

  const int tensor_count = layer_count_ * 2;  // one key + one value tensor per layer
  input_name_strings_.reserve(static_cast<size_t>(tensor_count));
  output_name_strings_.reserve(static_cast<size_t>(tensor_count));
  for (int i = 0; i < layer_count_; ++i) {
    const int layer_idx = kv_layer_indices_.empty() ? i : kv_layer_indices_[i];
    input_name_strings_.emplace_back(ComposeKeyValueName(model_.config_->model.decoder.inputs.past_key_names, layer_idx));
    input_name_strings_.emplace_back(ComposeKeyValueName(model_.config_->model.decoder.inputs.past_value_names, layer_idx));
    output_name_strings_.emplace_back(ComposeKeyValueName(model_.config_->model.decoder.outputs.present_key_names, layer_idx));
    output_name_strings_.emplace_back(ComposeKeyValueName(model_.config_->model.decoder.outputs.present_value_names, layer_idx));
  }

  // The element type is taken from the first KV input and used for every buffer.
  type_ = model_.session_info_.GetInputDataType(input_name_strings_[0]);

  // Each KV input declares shape [batch, max_seq_len, kv_hidden]. The batch dim
  // is a runtime property, so it comes from the params; only max_seq_len and
  // kv_hidden must be static. kv_hidden may differ between layers, so every
  // tensor's own declared shape is read rather than assuming a uniform value.
  const int64_t batch_size = state_.params_->BatchBeamSize();
  layer_shapes_.resize(static_cast<size_t>(tensor_count));
  caches_.reserve(static_cast<size_t>(tensor_count));
  for (int i = 0; i < tensor_count; ++i) {
    const auto input_shape = model_.session_info_.GetInputShape(input_name_strings_[i]);
    if (input_shape.size() != 3) {
      throw std::runtime_error(
          "StaticScatterKeyValueCache expects 3D [batch, max_seq_len, kv_hidden] KV inputs, but '" +
          input_name_strings_[i] + "' has rank " + std::to_string(input_shape.size()) + ".");
    }
    // max_seq_len (axis 1) and kv_hidden (axis 2) size the fixed buffer and must
    // be concrete; the batch dim (axis 0) is allowed to be symbolic.
    for (size_t axis = 1; axis < 3; ++axis) {
      if (input_shape[axis] <= 0) {
        throw std::runtime_error(
            "StaticScatterKeyValueCache requires a static max_seq_len and kv_hidden, but '" +
            input_name_strings_[i] + "' has a non-concrete dim at axis " + std::to_string(axis) + ".");
      }
    }
    std::array<int64_t, 3> tensor_shape{batch_size, input_shape[1], input_shape[2]};
    layer_shapes_[i] = tensor_shape;

    caches_.push_back(OrtValue::CreateTensor(Allocator(), tensor_shape, type_));
    // Zero-fill the fresh buffer on every device type except WEBGPU.
    if (Device().GetType() != DeviceType::WEBGPU) {
      ByteWrapTensor(Device(), *caches_.back()).Zero();
    }
  }
}

// Appends this cache's tensors to the state's input and output binding lists.
//
// Every cache buffer is registered twice: once under its past (input) name and
// once under its present (output) name. Both bindings point at the same
// OrtValue because the graph updates the buffer in place, so nothing has to be
// rebound between steps.
void StaticScatterKeyValueCache::Add() {
  // Record where this cache's entries begin in the state's binding lists.
  input_index_ = state_.inputs_.size();
  output_index_ = state_.outputs_.size();

  const int tensor_count = layer_count_ * 2;  // key + value per layer
  for (int slot = 0; slot < tensor_count; ++slot) {
    auto* const shared_buffer = caches_[slot].get();

    state_.inputs_.push_back(shared_buffer);
    state_.input_names_.push_back(input_name_strings_[slot].c_str());

    state_.outputs_.push_back(shared_buffer);
    state_.output_names_.push_back(output_name_strings_[slot].c_str());
  }
}

void StaticScatterKeyValueCache::Update(DeviceSpan<int32_t> /*beam_indices*/, int /*total_length*/) {
// No-op: the shared buffer is updated in place by the graph's TensorScatter,
// and the write offset / valid length are carried by the write_indices /
// nonpad_kv_seqlen inputs (see input_ids.cpp), not by rebinding tensors here.
}

// Always throws std::runtime_error: rewinding is not implemented for this cache.
//
// The write_indices / nonpad_kv_seqlen values are tracked in InputIDs, which
// has no rewind hook. Quietly ignoring the request would leave that tracker out
// of date, so later steps would scatter into the wrong rows and report too
// large a valid length without any error. Failing loudly is the safe choice
// until rewind is wired through.
void StaticScatterKeyValueCache::RewindTo([[maybe_unused]] size_t index) {
  throw std::runtime_error("StaticScatterKeyValueCache does not support RewindTo.");
}

LFM2Cache::LFM2Cache(State& state)
: state_{state},
layer_types_{model_.config_->model.decoder.layer_types},
Expand Down Expand Up @@ -947,6 +1060,15 @@ std::unique_ptr<KeyValueCache> CreateKeyValueCache(State& state) {
return nullptr;
}

// mobius static-cache decoders drive an in-place TensorScatter KV buffer via
// the write_indices input; auto-detect that (no user-visible search flag,
// mirroring DetectAndConfigureFixedKvShape) before the default fallback.
if (StaticScatterKeyValueCache::IsStaticScatterCache(state.model_)) {
if (g_log.enabled)
Log("info", "CreateKeyValueCache: Creating StaticScatterKeyValueCache");
return std::make_unique<StaticScatterKeyValueCache>(state);
}

if (state.model_.p_device_->GetType() != DeviceType::NvTensorRtRtx &&
state.model_.config_->model.decoder.sliding_window &&
state.model_.config_->model.decoder.sliding_window->slide_key_value_cache) {
Expand Down
49 changes: 49 additions & 0 deletions src/models/kv_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,55 @@ struct LFM2Cache : KeyValueCache {

std::string ComposeKeyValueName(const std::string& template_string, int index);

// A static-scatter KV cache for mobius-exported static-cache decoders.
//
// The model pre-allocates each layer's KV as a fixed 3D buffer
// [batch, max_seq_len, kv_hidden] and writes new rows in place via opset-24
// TensorScatter (driven by the write_indices / nonpad_kv_seqlen inputs produced
// in input_ids.cpp), reading them back through Attention. Because the scatter is
// in place, past and present share one buffer: Add() binds key_cache.{i} and
// updated_key_cache.{i} to the same OrtValue, and Update() is a no-op (mirroring
// DefaultKeyValueCache's past_present_share_buffer path). RewindTo() is NOT
// supported and throws: the write_indices/nonpad_kv_seqlen index stream lives in
// InputIDs with no rewind hook, so rewinding would silently desynchronize it.
//
// Differs from DefaultKeyValueCache only in how the buffers are allocated: it consumes mobius's
// 3D emission directly (vs the 4D [batch, num_kv_heads, seq, head_dim] layout),
// and kv_hidden may vary per layer (e.g. Gemma-4 sliding GQA 8*256 vs global
// MQA 1*512), read from each layer's own declared input shape.
struct StaticScatterKeyValueCache : KeyValueCache {
  StaticScatterKeyValueCache(State& state);

  // Reports whether the model declares both static-scatter driver inputs
  // (write_indices and nonpad_kv_seqlen). Must stay in step with the check
  // that produces those inputs in input_ids.
  static bool IsStaticScatterCache(const Model& model);

  void Add() override;
  void Update(DeviceSpan<int32_t> beam_indices, int total_length) override;
  void RewindTo(size_t index) override;

 private:
  // Allocator and device used for the KV cache buffers.
  Ort::Allocator& Allocator() { return model_.p_device_kvcache_->GetAllocator(); }
  DeviceInterface& Device() { return *model_.p_device_kvcache_; }

  State& state_;
  const Model& model_{state_.model_};

  // Number of layers that carry a KV cache (two tensors are kept per layer).
  int layer_count_;

  // Positions in the state's input/output lists at which Add() started
  // appending. Recorded for parity with the other KV caches; nothing reads
  // them here because the shared buffers are never rebound.
  size_t input_index_{~0U};
  size_t output_index_{~0U};

  // Layer indices found by scanning the model inputs (may be sparse for
  // hybrid models).
  std::vector<int> kv_layer_indices_;

  // Static shape [batch, max_seq_len, kv_hidden] of each cache tensor;
  // kv_hidden is allowed to differ between layers.
  std::vector<std::array<int64_t, 3>> layer_shapes_;

  // Element type used for the cache tensors.
  ONNXTensorElementDataType type_;

  // The buffers themselves: one per key and one per value tensor, each shared
  // between its past input and present output.
  std::vector<std::unique_ptr<OrtValue>> caches_;

  // Tensor names, indexed in parallel with caches_.
  std::vector<std::string> input_name_strings_;
  std::vector<std::string> output_name_strings_;
};

std::unique_ptr<KeyValueCache> CreateKeyValueCache(State& state);

} // namespace Generators
57 changes: 57 additions & 0 deletions src/models/static_scatter_indices.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <cstdint>

namespace Generators {

// The per-step index pair a static-scatter (TensorScatter) KV cache needs.
//
// A mobius-exported static-cache decoder consumes a pre-allocated KV buffer of
// shape [batch, max_seq_len, kv_hidden] and writes each step's new key/value
// rows into it in place via TensorScatter, then reads them back through
// Attention. Two int64 [batch] inputs drive that:
// * write_indices - the cache row offset TensorScatter writes this step's
// rows at (i.e. how many valid tokens are already cached
// BEFORE this step).
// * nonpad_kv_seqlen - the number of valid cached tokens AFTER this step,
// which Attention reads as the per-batch seqlens_k.
struct StaticScatterIndices {
  // Count of valid cached tokens before this step; used as the scatter offset.
  int64_t write_index;
  // Count of valid cached tokens after this step.
  int64_t nonpad_kv_seqlen;
};

// Running token counter that yields the StaticScatterIndices for each step of
// a single generation stream. It has no dependencies beyond <cstdint>, so its
// initial value and off-by-one behaviour can be unit-tested in isolation.
//
// Sequencing contract:
//   * The first Advance() returns write_index 0 and nonpad_kv_seqlen equal to
//     the number of tokens passed in (the prefill length).
//   * Every later Advance() returns a write_index equal to the previous call's
//     nonpad_kv_seqlen, so rows are appended contiguously with no gap or
//     overlap.
// The counter starts at 0 and is kept separate from the past_sequence_length
// scalar, which starts at -1; the two must not be mixed.
class StaticScatterIndexTracker {
 public:
  // Consumes one generation step that added `new_unpadded_tokens` valid tokens
  // (the prompt length on prefill, normally 1 per decode step). Returns the
  // index pair to bind for THIS step and stores the new running total for the
  // next call.
  StaticScatterIndices Advance(int64_t new_unpadded_tokens) {
    StaticScatterIndices step{};
    step.write_index = valid_tokens_;
    step.nonpad_kv_seqlen = valid_tokens_ + new_unpadded_tokens;
    valid_tokens_ = step.nonpad_kv_seqlen;
    return step;
  }

  // Number of valid cached tokens accumulated so far, i.e. the write_index the
  // next Advance() will return. Zero until Advance() has been called.
  int64_t valid_tokens() const { return valid_tokens_; }

 private:
  int64_t valid_tokens_{0};
};

} // namespace Generators
Loading
Loading