webgpu: fix RecurrentState graph capture with shared buffer aliasing #2244

Open
qjia7 wants to merge 2 commits into main from fix/webgpu-recurrent-state-graph-capture
@qjia7 qjia7 commented Jun 26, 2026

Summary

  • RecurrentState::share_buffers_ now derives from past_present_share_buffer config (same pattern as DefaultKeyValueCache) instead of being hardcoded to false for WebGPU
  • Add a graph-capture guard that throws when share_buffers_ is false and graph capture is requested, matching the existing guard in kv_cache.cpp
  • Add Ort::Experimental::Get_*_Fn accessor shim to onnxruntime_api.h so genai builds against ORT builds where onnxruntime_experimental_cxx_api.h cannot be included directly

Motivation

WebGPU graph capture freezes GPU buffer handles at the first captured decode step. RecurrentState was using separate past/present buffers for WebGPU with a per-step C++ pointer swap, but graph capture replay ignores those swaps, so every decode step after the first read a stale buffer and produced garbage output.
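The failure mode can be reproduced with a toy model that involves no GPU at all. The sketch below is illustrative only: `ToyCapturedStep`, `RunWithSwap`, and `RunShared` are hypothetical names, a "buffer" is a single float, and a "handle" is an index into a pool. It assumes, as described above, that replay always uses the handles recorded at capture time.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Toy model only: none of these types exist in ORT or GenAI.
struct ToyCapturedStep {
  std::size_t past_handle;     // handle frozen at capture time
  std::size_t present_handle;  // handle frozen at capture time

  // Replay always uses the frozen handles: present = past + 1.
  void Replay(std::vector<float>& pool) const {
    pool[present_handle] = pool[past_handle] + 1.0f;
  }
};

// Separate buffers plus a host-side swap. The swap is invisible to the
// replayed step, so the state stops advancing after the first step.
inline float RunWithSwap(int steps) {
  std::vector<float> pool{0.0f, 0.0f};
  std::size_t past = 0, present = 1;
  ToyCapturedStep captured{past, present};  // captured at the first decode step
  for (int i = 0; i < steps; ++i) {
    captured.Replay(pool);
    std::swap(past, present);  // host-side swap, ignored on replay
  }
  return pool[past];  // what the host believes is the latest state
}

// One aliased buffer: the frozen handle stays valid on every replay.
inline float RunShared(int steps) {
  std::vector<float> pool{0.0f};
  ToyCapturedStep captured{0, 0};
  for (int i = 0; i < steps; ++i) {
    captured.Replay(pool);
  }
  return pool[0];
}
```

In this model, three steps with the swap leave the state at 1 instead of 3, while the shared-buffer variant reaches 3.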

With past_present_share_buffer: true, the ORT LinearAttention and CausalConvWithState kernels detect the aliased buffers via initial_state_in_present_state / conv_state_in_present_state and switch to a single read_write binding, which satisfies the WebGPU spec constraint that originally motivated the per-device split. Buffer handles stay stable across graph capture replays.
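A minimal sketch of the aliasing check described above, using hypothetical types rather than the ORT WebGPU EP's real interfaces. It assumes the constraint in question is that one buffer cannot be bound as both a read-only input and a writable output in the same dispatch, so an aliased state buffer has to be bound once as read_write.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical types for illustration; not the ORT WebGPU EP's real API.
enum class ToyBindingMode { kSeparateReadAndWrite, kSingleReadWrite };

struct ToyStateBuffers {
  std::uintptr_t past_handle;     // stands in for the input state GPU buffer
  std::uintptr_t present_handle;  // stands in for the output state GPU buffer
};

// Echoes the idea behind initial_state_in_present_state and
// conv_state_in_present_state: true when input and output are one buffer.
inline bool StateInPresentState(const ToyStateBuffers& b) {
  return b.past_handle == b.present_handle;
}

// Aliased buffers get a single read_write binding; distinct buffers keep
// one read binding and one write binding.
inline ToyBindingMode ChooseBinding(const ToyStateBuffers& b) {
  return StateInPresentState(b) ? ToyBindingMode::kSingleReadWrite
                                : ToyBindingMode::kSeparateReadAndWrite;
}
```

How the real kernels compare buffers and build their bind groups is not shown in this PR description, so the handle comparison here is only a stand-in.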

The original code comment carried a TODO to remove the WebGPU special case once the ORT WebGPU EP kernels natively supported past/present buffer sharing. That support is now present, so this PR resolves the TODO.

Test plan

  • Qwen3.5-0.8B-webgpu-fused correctness test: all queries pass with enableGraphCapture=1
  • Qwen3.5-0.8B-webgpu-fused multi-gen test: sequential and overlapping generators pass with enableGraphCapture=1
  • Both tests also pass with enableGraphCapture=0 (no regression on non-graph-capture path)
  • clang-format: no issues on changed files
  • Build: succeeds against local ORT WebGPU EP build

RecurrentState used separate past/present buffers for WebGPU and swapped
them per step. WebGPU graph capture freezes GPU buffer handles at the
first captured decode step; subsequent replays ignored the C++ pointer
swap and kept reading a stale buffer, producing garbage output on every
decode step after the first.

Fix: derive share_buffers_ from the past_present_share_buffer config
(matching the DefaultKeyValueCache pattern) and add a graph-capture guard
that throws when share_buffers_ is false. When share_buffers_ is true,
the ORT LinearAttention and CausalConvWithState kernels detect the aliased
buffers via initial_state_in_present_state / conv_state_in_present_state
and use a single read_write binding, satisfying the WebGPU spec constraint
that was the original reason for the per-device split.

Also add an Ort::Experimental::Get_*_Fn accessor shim to onnxruntime_api.h
so genai builds against ORT builds where onnxruntime_experimental_cxx_api.h
cannot be included directly due to a transitive onnxruntime_cxx_api.h
conflict with genai's vendored Ort wrappers.
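
For readers unfamiliar with the pattern, here is a generic sketch of an accessor shim that hands out typed function pointers without including the header that declares them. Everything in it is hypothetical (`toy_runtime`, `toy_shim`, `LookupByName`, `Get_ModelPackageVersion_Fn`); the actual shim in onnxruntime_api.h may resolve its functions differently.

```cpp
#include <cassert>
#include <cstring>

namespace toy_runtime {
// Stands in for an entry point normally declared in a header we cannot include.
inline int ModelPackageVersion() { return 7; }

// Stands in for a by-name lookup offered by the runtime (hypothetical).
using GenericFn = void (*)();
inline GenericFn LookupByName(const char* name) {
  if (std::strcmp(name, "ModelPackageVersion") == 0) {
    return reinterpret_cast<GenericFn>(&ModelPackageVersion);
  }
  return nullptr;
}
}  // namespace toy_runtime

namespace toy_shim {
// The shim re-declares only the signature it needs.
using ModelPackageVersionFn = int (*)();

// Accessor in the spirit of a Get_*_Fn shim: resolve once, cache the
// result, and return a typed pointer (nullptr if the lookup failed).
inline ModelPackageVersionFn Get_ModelPackageVersion_Fn() {
  static const ModelPackageVersionFn fn = reinterpret_cast<ModelPackageVersionFn>(
      toy_runtime::LookupByName("ModelPackageVersion"));
  return fn;
}
}  // namespace toy_shim
```

Callers check the returned pointer before invoking it, which keeps the conflicting header out of every translation unit that only needs the function.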

Verified: Qwen3.5-0.8B-webgpu-fused correctness and multi-gen tests pass
with enableGraphCapture=1 and with enableGraphCapture=0.
@qjia7 qjia7 marked this pull request as ready for review June 26, 2026 08:54
@qjia7 qjia7 requested a review from a team as a code owner June 26, 2026 08:54
Copilot AI review requested due to automatic review settings June 26, 2026 08:54

Copilot AI left a comment

Pull request overview

This PR fixes incorrect WebGPU graph-capture behavior for models that use RecurrentState by making recurrent past/present buffer sharing follow the past_present_share_buffer configuration (so GPU buffer handles remain stable across graph-capture replays). It also adds a small ONNX Runtime experimental-API accessor shim so GenAI can build in environments where the experimental C++ header cannot be included.

Changes:

  • Derive RecurrentState::share_buffers_ from GeneratorParams::IsPastPresentShareBufferEnabled(...) (instead of a WebGPU hardcoded behavior).
  • Add a runtime guard that rejects graph capture when effective past/present sharing is disabled.
  • Add Ort::Experimental::Get_*_Fn accessor shims in onnxruntime_api.h to avoid including onnxruntime_experimental_cxx_api.h.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| src/models/recurrent_state.h | Updates the member comment to describe config-driven past/present sharing semantics. |
| src/models/recurrent_state.cpp | Switches share_buffers_ to be config-derived and adds a graph-capture precondition check. |
| src/models/onnxruntime_api.h | Adds an experimental accessor shim to support model-package functions without including ORT experimental C++ headers. |

Comment thread src/models/recurrent_state.cpp
Comment thread src/models/recurrent_state.cpp
- Error message now mentions num_beams=1 constraint since beam search
  also disables past/present buffer sharing via IsPastPresentShareBufferEnabled.
- Replace "WebGPU: separate past/present buffers" comment in Add() with
  EP-neutral description; the separate-buffer path applies to any EP when
  share_buffers_=false.
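
The guard and the num_beams condition mentioned in this commit can be sketched as follows. `ToyConfig`, `ToyIsShareBufferEnabled`, and `ToyResolveShareBuffers` are hypothetical stand-ins, the error text is illustrative, and the real IsPastPresentShareBufferEnabled may check more conditions than the two assumed here.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical config; field names echo the PR text, but this is not
// GenAI's actual configuration struct.
struct ToyConfig {
  bool past_present_share_buffer = true;
  int num_beams = 1;
  bool graph_capture = false;
};

// Assumed rule, per the commit note: beam search disables effective sharing.
inline bool ToyIsShareBufferEnabled(const ToyConfig& c) {
  return c.past_present_share_buffer && c.num_beams == 1;
}

// Returns the effective share flag, or throws when graph capture is
// requested without effective sharing.
inline bool ToyResolveShareBuffers(const ToyConfig& c) {
  const bool share = ToyIsShareBufferEnabled(c);
  if (c.graph_capture && !share) {
    throw std::runtime_error(
        "Graph capture requires past_present_share_buffer=true and num_beams=1");
  }
  return share;
}
```

Naming both preconditions in the message helps a user who enabled sharing but also requested beam search see why capture was rejected.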