webgpu: fix RecurrentState graph capture with shared buffer aliasing #2244
Open
qjia7 wants to merge 2 commits into
RecurrentState used separate past/present buffers for WebGPU and swapped them per step. WebGPU graph capture freezes GPU buffer handles at the first captured decode step; subsequent replays ignored the C++ pointer swap and kept reading a stale buffer, producing garbage output on every decode step after the first.
Fix: derive share_buffers_ from the past_present_share_buffer config (matching the DefaultKeyValueCache pattern) and add a graph-capture guard that throws when share_buffers_ is false. When share_buffers_ is true, the ORT LinearAttention and CausalConvWithState kernels detect the aliased buffers via initial_state_in_present_state / conv_state_in_present_state and use a single read_write binding, satisfying the WebGPU spec constraint that was the original reason for the per-device split.
Also add an Ort::Experimental::Get_*_Fn accessor shim to onnxruntime_api.h so genai builds against ORT builds where onnxruntime_experimental_cxx_api.h cannot be included directly due to a transitive onnxruntime_cxx_api.h conflict with genai's vendored Ort wrappers.
Verified: Qwen3.5-0.8B-webgpu-fused correctness and multi-gen tests pass with enableGraphCapture=1 and with enableGraphCapture=0.
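The failure mode described above can be reproduced in a toy model. The sketch below is not genai or ORT code: buffers are plain integers, `CapturedStep` and `RunDecode` are hypothetical names, and "replay" simply reuses the buffer indices recorded at the first step, which is the assumed behavior of graph capture here.

```cpp
#include <array>
#include <utility>
#include <vector>

// Toy model, not the genai or ORT API: a "GPU buffer" is just an int, and a
// "captured graph" remembers the buffer indices bound at capture time.
struct CapturedStep {
  int read_index;   // buffer bound as past state when the graph was captured
  int write_index;  // buffer bound as present state when the graph was captured
};

// Runs `steps` decode steps where each step computes present = past + 1.
// Returns the state value observed after each step.
std::vector<int> RunDecode(int steps, bool share_buffers) {
  std::array<int, 2> buffers{0, 0};
  int past = 0;
  int present = share_buffers ? 0 : 1;         // shared: past and present alias
  const CapturedStep captured{past, present};  // handles frozen at the first step

  std::vector<int> observed;
  for (int i = 0; i < steps; ++i) {
    // Replay always uses the captured handles, never the current host-side ones.
    buffers[captured.write_index] = buffers[captured.read_index] + 1;
    observed.push_back(buffers[captured.write_index]);
    if (!share_buffers) {
      std::swap(past, present);  // host-side swap, invisible to the replay
    }
  }
  return observed;
}
```

In this model, `RunDecode(3, false)` keeps recomputing from the same stale input, while `RunDecode(3, true)` advances the state in place on every step.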
Pull request overview
This PR fixes incorrect WebGPU graph-capture behavior for models that use RecurrentState by making recurrent past/present buffer sharing follow the past_present_share_buffer configuration (so GPU buffer handles remain stable across graph-capture replays). It also adds a small ONNX Runtime experimental-API accessor shim so GenAI can build in environments where the experimental C++ header cannot be included.
Changes:
- Derive `RecurrentState::share_buffers_` from `GeneratorParams::IsPastPresentShareBufferEnabled(...)` (instead of a WebGPU hardcoded behavior).
- Add a runtime guard that rejects graph capture when effective past/present sharing is disabled.
- Add `Ort::Experimental::Get_*_Fn` accessor shims in `onnxruntime_api.h` to avoid including `onnxruntime_experimental_cxx_api.h`.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| src/models/recurrent_state.h | Updates the member comment to describe config-driven past/present sharing semantics. |
| src/models/recurrent_state.cpp | Switches share_buffers_ to be config-derived and adds a graph-capture precondition check. |
| src/models/onnxruntime_api.h | Adds an experimental accessor shim to support model-package functions without including ORT experimental C++ headers. |
- Error message now mentions num_beams=1 constraint since beam search also disables past/present buffer sharing via IsPastPresentShareBufferEnabled.
- Replace "WebGPU: separate past/present buffers" comment in Add() with EP-neutral description; the separate-buffer path applies to any EP when share_buffers_=false.
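As a rough illustration of the guard discussed in this review, the sketch below derives a sharing flag from configuration and rejects graph capture when sharing is not effective. `DecodeConfig`, `IsSharingEffective`, and `ResolveShareBuffers` are hypothetical names, and the error text is invented for the example; the only rules taken from the PR are that sharing follows `past_present_share_buffer`, that beam search disables it, and that graph capture without sharing throws.

```cpp
#include <stdexcept>

// Hypothetical stand-in for the relevant configuration; the field names are
// illustrative and are not the genai structs.
struct DecodeConfig {
  bool past_present_share_buffer = true;
  int num_beams = 1;
  bool graph_capture_enabled = false;
};

// Assumed rule, based on the PR discussion: sharing is only effective when the
// config flag is set and beam search is off (num_beams == 1).
bool IsSharingEffective(const DecodeConfig& config) {
  return config.past_present_share_buffer && config.num_beams == 1;
}

// Returns the value a share_buffers_-style member would take, and throws if
// graph capture is requested without effective sharing.
bool ResolveShareBuffers(const DecodeConfig& config) {
  const bool share_buffers = IsSharingEffective(config);
  if (config.graph_capture_enabled && !share_buffers) {
    throw std::runtime_error(
        "Graph capture requires past_present_share_buffer=true and num_beams=1.");
  }
  return share_buffers;
}
```

Failing early at state construction, rather than letting a replay read stale data, turns a silent correctness bug into an explicit configuration error.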
Summary
- `RecurrentState::share_buffers_` now derives from the `past_present_share_buffer` config (same pattern as `DefaultKeyValueCache`) instead of being hardcoded to `false` for WebGPU
- Adds a guard that throws when `share_buffers_` is false and graph capture is requested, matching the existing guard in `kv_cache.cpp`
- Adds an `Ort::Experimental::Get_*_Fn` accessor shim to `onnxruntime_api.h` so genai builds against ORT builds where `onnxruntime_experimental_cxx_api.h` cannot be included directly
Motivation
WebGPU graph capture freezes GPU buffer handles at the first captured decode step. `RecurrentState` was using separate past/present buffers for WebGPU with a per-step C++ pointer swap, but graph capture ignores those swaps on replay, so every decode step after the first read a stale buffer, producing garbage output.
With `past_present_share_buffer: true`, the ORT `LinearAttention` and `CausalConvWithState` kernels detect the aliased buffers via `initial_state_in_present_state` / `conv_state_in_present_state` and switch to a single `read_write` binding, which satisfies the WebGPU spec constraint that originally motivated the per-device split. Buffer handles stay stable across graph capture replays.
The original comment in the code noted a TODO to remove the WebGPU special case once the ORT WebGPU EP kernels natively supported past/present buffer sharing. That support is already present, so this PR resolves the TODO.
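The binding choice described above can be sketched as a small planning function. This is a toy illustration, not the ORT WebGPU EP implementation: `Binding`, `Access`, and `PlanStateBindings` are hypothetical, buffer handles are plain integers, and the aliasing check is assumed to be a simple handle comparison.

```cpp
#include <cstdint>
#include <vector>

// Toy types, not the ORT WebGPU EP API.
enum class Access { kReadOnly, kReadWrite };

struct Binding {
  std::uintptr_t buffer;  // stands in for a GPU buffer handle
  Access access;
};

// Chooses storage bindings for a state-updating kernel. If the input state and
// output state are the same buffer, bind it once as read_write; otherwise bind
// the input read-only and the output read_write.
std::vector<Binding> PlanStateBindings(std::uintptr_t state_in,
                                       std::uintptr_t state_out) {
  const bool state_in_present_state = (state_in == state_out);
  if (state_in_present_state) {
    return {{state_out, Access::kReadWrite}};
  }
  return {{state_in, Access::kReadOnly}, {state_out, Access::kReadWrite}};
}
```

With aliased buffers the plan contains one binding, so the same handle is never bound twice with conflicting access in a single dispatch, and it stays the same from step to step.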
Test plan
- `Qwen3.5-0.8B-webgpu-fused` correctness test: all queries pass with `enableGraphCapture=1`
- `Qwen3.5-0.8B-webgpu-fused` multi-gen test: sequential and overlapping generators pass with `enableGraphCapture=1`
- Both tests also pass with `enableGraphCapture=0` (no regression on non-graph-capture path)