Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/models/onnxruntime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,25 @@ p_session_->Run(nullptr, input_names, inputs, std::size(inputs), output_names, o
#define ORT_GENAI_HAS_EXPERIMENTAL_C_API 1
#endif

// Expose Ort::Experimental::Get_*_Fn accessors used by onnxruntime_inline.h without pulling
// in onnxruntime_experimental_cxx_api.h (which transitively includes onnxruntime_cxx_api.h
// and conflicts with genai's vendored Ort wrappers). Only the C typedefs from
// onnxruntime_experimental_c_api.h (already included above) are needed here.
#if ORT_GENAI_HAS_EXPERIMENTAL_C_API
namespace Ort {
namespace Experimental {
#define ORT_EXPERIMENTAL_API(VER, RET, NAME, ...) \
inline OrtExperimental_##NAME##_SinceV##VER##_Fn Get_##NAME##_SinceV##VER##_Fn( \
const OrtApi* api) { \
return reinterpret_cast<OrtExperimental_##NAME##_SinceV##VER##_Fn>( \
api->GetExperimentalFunction(kOrtExperimental_##NAME##_SinceV##VER##_FnName)); \
}
#include "onnxruntime_experimental_c_api.inc"
#undef ORT_EXPERIMENTAL_API
} // namespace Experimental
} // namespace Ort
#endif

// Single gate for OrtModelPackageApi support: API version 28+ and the experimental header
// available. The OrtModelPackage* wrappers below and model_package.{h,cpp} key off this.
#if defined(ORT_API_VERSION) && ORT_API_VERSION >= 28 && ORT_GENAI_HAS_EXPERIMENTAL_C_API
Expand Down
14 changes: 4 additions & 10 deletions src/models/recurrent_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,20 +88,14 @@ RecurrentState::RecurrentState(State& state)

const int num_layers = static_cast<int>(layer_indices_.size());

if (!state_.params_->IsPastPresentShareBufferEnabled(model_.config_->model.type)) {
share_buffers_ = state_.params_->IsPastPresentShareBufferEnabled(model_.config_->model.type);

Comment thread
qjia7 marked this conversation as resolved.
if (state_.params_->use_graph_capture && !share_buffers_) {
throw std::runtime_error(
"RecurrentState requires past_present_share_buffer=true. "
"Graph capture requires past_present_share_buffer=true for models with recurrent state. "
"Set past_present_share_buffer to true in genai_config.json.");
}
Comment thread
qjia7 marked this conversation as resolved.

// WebGPU prohibits binding the same buffer as both read-only (input) and
// read-write (output) storage in the same compute pass, so it must use
// separate past/present buffers with swap. All other EPs share buffers
// for stable addresses (required by TRT-RTX graph replay, beneficial elsewhere).
// TODO: Remove WebGPU special case once the ORT WebGPU EP adds a
// LinearAttention kernel with native past/present buffer sharing support.
share_buffers_ = model_.p_device_kvcache_->GetType() != DeviceType::WEBGPU;

if (!share_buffers_) {
pasts_.resize(num_layers * 2);
}
Expand Down
3 changes: 2 additions & 1 deletion src/models/recurrent_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ struct RecurrentState {
std::vector<std::unique_ptr<OrtValue>> pasts_;
std::vector<std::unique_ptr<OrtValue>> presents_;

// WebGPU cannot alias input/output buffers, so it uses separate past/present\n // with swap. All other EPs share buffers for stable addresses.
// Mirrors past_present_share_buffer config: true means inputs alias outputs (same allocation,
// stable handles for graph capture). False uses separate past/present buffers with per-step swap.
bool share_buffers_{false};
size_t input_index_{~0U};
size_t output_index_{~0U};
Expand Down
Loading