Route init-session provider-option shaping through DeviceInterface#2232
Route init-session provider-option shaping through DeviceInterface#2232qjia7 wants to merge 7 commits into
Conversation
EnsureDeviceOrtInit previously contained inline EP-specific branches for WebGPU (whitelist forwarding of global/singleton options) and QNN (injection of enable_htp_shared_memory_allocator=1). This change adds a new optional virtual ShapeInitSessionProviderOptions on DeviceInterface so each EP owns its own init-session shaping. The base EnsureDeviceOrtInit becomes EP-agnostic: it looks up the user-supplied provider options once, then dispatches to the EP override. Most EPs inherit the default no-op. Also expand the WebGPU global-options whitelist from 7 to 14 entries to match all options consumed by ParseWebGpuContextConfig: preserveDevice, maxStorageBufferBindingSize, maxNumPendingDispatches, and the four *BufferCacheMode keys were previously dropped before reaching the process-wide WebGpuContext singleton. config.h becomes self-sufficient (adds the standard and project headers it already implicitly depended on) so it can be included from smartptrs.h without a transitive chain. A warning comment prevents future contributors from introducing a circular include. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
There was a problem hiding this comment.
⚠️ Not ready to approve
config.h still relies on transitive standard-library includes for types it uses (e.g., uint32_t/std::unordered_map/std::byte), which undermines the stated goal of being self-sufficient.
Pull request overview
This PR refactors EnsureDeviceOrtInit so execution-provider-specific shaping of the trivial init-session provider options is routed through a new optional DeviceInterface virtual, with WebGPU and QNN providing overrides. It also expands the WebGPU init-session option forwarding whitelist and makes config.h more directly includable from lower-level headers by adding explicit includes and warning about include cycles.
Changes:
- Add
DeviceInterface::ShapeInitSessionProviderOptions(...)(default no-op) and routeEnsureDeviceOrtInitthrough it. - WebGPU: forward a larger set of global/singleton WebGPU provider options into the trivial init session.
- QNN: inject
enable_htp_shared_memory_allocator=1via the new shaping hook; updateconfig.hinclude set and document include-cycle constraints.
File summaries
| File | Description |
|---|---|
| src/webgpu/interface.cpp | Implements WebGPU override to forward process-wide WebGPU context options into the init session. |
| src/smartptrs.h | Adds the new optional shaping virtual to DeviceInterface and includes config.h for the signature. |
| src/qnn/interface.cpp | Implements QNN override to inject the shared-memory allocator gating provider option. |
| src/models/model.cpp | Replaces hard-coded WebGPU/QNN shaping with provider-option lookup + dispatch through DeviceInterface. |
| src/config.h | Adds explicit includes and a warning comment to avoid include cycles with smartptrs.h/generators.h. |
Copilot's findings
- Files reviewed: 5/5 changed files
- Comments generated: 1
Note
Your feedback helps us improve the quality of this feature.
Please use 👍 or 👎 to tell us whether this assessment is correct.
Add <cstddef>, <cstdint>, <span>, and <unordered_map> so config.h provides every type it uses (uint32_t, size_t, std::byte, std::span<const std::byte>, std::unordered_map) without depending on transitive includes pulled in through onnxruntime_api.h or <vector>.
There was a problem hiding this comment.
⚠️ Not ready to approve
src/config.h now includes <span>, which can break builds that define USE_CXX17 (the repo’s span.h provides the C++17 polyfill), so the include should use the project wrapper instead.
Copilot's findings
- Files reviewed: 5/5 changed files
- Comments generated: 1
Note
Your feedback helps us improve the quality of this feature.
Please use 👍 or 👎 to tell us whether this assessment is correct.
…t lint - Replace standard <span> include with project span.h wrapper in config.h to maintain C++17 polyfill compatibility (fixes Copilot review concern) - Fix include comment formatting in smartptrs.h to align with clang-format style (fixes lint-cpp CI failure) These changes address the PR #2232 Copilot review comments and lint check failures.
There was a problem hiding this comment.
⚠️ Not ready to approve
Device filtering options can be silently dropped during init-session allocator initialization for non-WebGPU EPs, which can initialize the global allocator against the wrong device.
Copilot's findings
- Files reviewed: 5/5 changed files
- Comments generated: 2
Note
Your feedback helps us improve the quality of this feature.
Please use 👍 or 👎 to tell us whether this assessment is correct.
…artptrs.h DeviceInterface::ShapeInitSessionProviderOptions needs ProviderOptions in its signature. Including config.h from smartptrs.h required making config.h self-sufficient (12 new standard headers), which caused the C++17 span incompatibility and downstream lint issues. Fix: move DeviceFilteringOptions, NamedString, and ProviderOptions to a new self-sufficient provider_options.h. smartptrs.h includes the lightweight new header; config.h adds one #include and replaces the three definitions with transparent type aliases, so all Config::ProviderOptions call sites compile unchanged.
- Add <algorithm> and <type_traits> to smartptrs.h for std::copy and std::remove_const_t which were relying on transitive includes. - Propagate device_filtering_options unconditionally in EnsureDeviceOrtInit before the virtual ShapeInitSessionProviderOptions call so all EPs (CUDA, QNN, DML, etc.) honour the user's device selection for the allocator-init session, not just WebGPU. - Remove the now-redundant device_filtering_options assignment from the WebGPU override.
| #include <assert.h> | ||
| #include <atomic> | ||
| #include <memory> | ||
| #include <type_traits> |
There was a problem hiding this comment.
Why is this import needed?
There was a problem hiding this comment.
<type_traits> is needed for std::remove_const_t used at DeviceInterface::WrapMemory (line 117 of this file) to strip const before passing to WrapMemoryBase(void*, ...). It was previously pulled in transitively, which Copilot flagged as fragile, so I added it explicitly in commit 2821298. Added a // for std::remove_const_t marker in d792135 to match the existing convention.
| // | ||
| // Modifications Copyright(C) 2026 Advanced Micro Devices, Inc. All rights reserved. | ||
| #pragma once | ||
| #include <algorithm> |
There was a problem hiding this comment.
Why is this import needed?
There was a problem hiding this comment.
<algorithm> is needed for std::copy used in the free function Generators::copy at the bottom of this file. It was previously pulled in transitively, which Copilot flagged as fragile, so I added it explicitly in commit 2821298. Added a // for std::copy marker in d792135 to match the existing convention used by the other includes in this file.
Add inline 'for X' markers next to the two includes to match the existing convention in this file (see span.h, onnxruntime_api.h, provider_options.h above). <algorithm> is used by std::copy in Generators::copy; <type_traits> is used by std::remove_const_t in DeviceInterface::WrapMemory.
Summary
EnsureDeviceOrtInittrivial-model session options out ofmodel.cppand onto a new optional virtualDeviceInterface::ShapeInitSessionProviderOptions. WebGPU and QNN provide overrides; all other EPs inherit the default no-op. The base path becomes EP-agnostic: it looks up the user-supplied entry once, then dispatches.ParseWebGpuContextConfig(preserveDevice,maxStorageBufferBindingSize,maxNumPendingDispatches, and the four*BufferCacheModekeys) actually reaches the process-wideWebGpuContextsingleton.config.hself-sufficient (adds the standard and project headers it already implicitly depended on) so it can be included fromsmartptrs.hwithout a transitive chain. A one-line comment inconfig.hwarns against introducing a circular include back tosmartptrs.h/generators.h.Test plan
onnxruntime-genaiwithUSE_WEBGPU=ONagainst a local ORT (onnxruntime-genai.dll,*.pyd, wheel,unit_tests.exe,model_benchmark.exeall produced).verify_model_correctness.pyandverify_multi_gen.py(covers the original reasonEnsureDeviceOrtInitexists).enable_htp_shared_memory_allocator=1is still injected.