Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 4 additions & 12 deletions src/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// Modifications Copyright (C) 2026 Advanced Micro Devices, Inc. All rights reserved.
// Portions of this file consist of AI generated content.
#pragma once
#include "provider_options.h"

namespace Generators {

Expand Down Expand Up @@ -86,18 +87,9 @@ struct Config {

fs::path config_path; // Path of the config directory

using NamedString = std::pair<std::string, std::string>;
struct DeviceFilteringOptions {
std::optional<OrtHardwareDeviceType> hardware_device_type; // OrtHardwareDeviceType_CPU, OrtHardwareDeviceType_GPU, OrtHardwareDeviceType_NPU
std::optional<uint32_t> hardware_device_id;
std::optional<uint32_t> hardware_vendor_id;
};

struct ProviderOptions {
std::string name;
std::vector<NamedString> options;
std::optional<DeviceFilteringOptions> device_filtering_options;
};
using NamedString = Generators::NamedString;
using DeviceFilteringOptions = Generators::DeviceFilteringOptions;
using ProviderOptions = Generators::ProviderOptions;

struct SessionOptions {
std::optional<int> intra_op_num_threads;
Expand Down
44 changes: 13 additions & 31 deletions src/models/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -420,39 +420,21 @@ void EnsureDeviceOrtInit(DeviceInterface& device, const Config& config) {
const char* provider_name = device_type_names[static_cast<int>(type)];
Config::ProviderOptions init_session_provider_options{provider_name, {}};

// Forward only global/singleton WebGPU options to the init session so that the
// process-wide WebGpuContext singleton is initialized with the correct settings.
// Per-session options (enableGraphCapture, enableInt64, etc.) are excluded
// because they are meaningless for the trivial initialization model.
if (type == DeviceType::WEBGPU) {
constexpr std::array<std::string_view, 7> kWebGpuGlobalOptions = {
"deviceId",
"webgpuInstance",
"webgpuDevice",
"dawnBackendType",
"powerPreference",
"validationMode",
"dawnProcTable",
};
for (const auto& user_po : config.model.decoder.session_options.provider_options) {
if (user_po.name == provider_name) {
for (const auto& opt : user_po.options) {
if (std::find(kWebGpuGlobalOptions.begin(), kWebGpuGlobalOptions.end(), opt.first) != kWebGpuGlobalOptions.end()) {
init_session_provider_options.options.emplace_back(opt);
}
}
init_session_provider_options.device_filtering_options = user_po.device_filtering_options;
break;
}
}
}
// Look up the user-supplied provider options entry for this provider (if any),
// then let the EP shape the trivial-model init session options. Most EPs use
// the default no-op; WebGPU forwards global/singleton options and QNN injects
// the QnnHtpShared allocator gating option.
const auto& user_provider_options_list = config.model.decoder.session_options.provider_options;
const auto user_provider_options_it = std::find_if(
user_provider_options_list.begin(), user_provider_options_list.end(),
[provider_name](const Config::ProviderOptions& po) { return po.name == provider_name; });
const Config::ProviderOptions* user_provider_options =
user_provider_options_it != user_provider_options_list.end() ? &*user_provider_options_it : nullptr;
if (user_provider_options)
init_session_provider_options.device_filtering_options = user_provider_options->device_filtering_options;
device.ShapeInitSessionProviderOptions(init_session_provider_options, user_provider_options);

provider_options_list.emplace_back(std::move(init_session_provider_options));
// QnnHtpShared is a special case. This allocator is only made available when the provider option
// 'enable_htp_shared_memory_allocator' is set to 1.
if (type == DeviceType::QNN) {
provider_options_list.back().options.emplace_back("enable_htp_shared_memory_allocator", "1");
}
const std::vector<std::string> providers{device_type_names[static_cast<int>(type)]};
SetProviderSessionOptions(*session_options, providers, provider_options_list, true, config);
session_options->SetLogSeverityLevel(ORT_LOGGING_LEVEL_ERROR); // Errors only here, as warnings are not useful to the user
Expand Down
27 changes: 27 additions & 0 deletions src/provider_options.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <cstdint>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "models/onnxruntime_api.h" // for OrtHardwareDeviceType

namespace Generators {

using NamedString = std::pair<std::string, std::string>;

struct DeviceFilteringOptions {
std::optional<OrtHardwareDeviceType> hardware_device_type; // OrtHardwareDeviceType_CPU, OrtHardwareDeviceType_GPU, OrtHardwareDeviceType_NPU
std::optional<uint32_t> hardware_device_id;
std::optional<uint32_t> hardware_vendor_id;
};

struct ProviderOptions {
std::string name;
std::vector<NamedString> options;
std::optional<DeviceFilteringOptions> device_filtering_options;
};

} // namespace Generators
7 changes: 7 additions & 0 deletions src/qnn/interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,13 @@ struct InterfaceImpl : DeviceInterface {
std::unique_ptr<Search> CreateBeam(const GeneratorParams& params) override { return std::make_unique<BeamSearch_Cpu>(params); }

void Synchronize() override {} // Nothing to do

void ShapeInitSessionProviderOptions(Config::ProviderOptions& init_options,
const Config::ProviderOptions* /*user_options*/) const override {
// QnnHtpShared is a special case. This allocator is only made available when the provider
// option 'enable_htp_shared_memory_allocator' is set to 1.
init_options.options.emplace_back("enable_htp_shared_memory_allocator", "1");
}
};

} // namespace QNN
Expand Down
11 changes: 11 additions & 0 deletions src/smartptrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
//
// Modifications Copyright(C) 2026 Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <algorithm> // for std::copy
#include <assert.h>
#include <atomic>
#include <memory>
#include <type_traits> // for std::remove_const_t
#include "span.h"
#include "models/onnxruntime_api.h" // for ONNXTensorElementDataType
#include "provider_options.h" // for ProviderOptions
namespace Ort {
struct Allocator;
}
Expand Down Expand Up @@ -133,6 +136,14 @@ struct DeviceInterface {
virtual void FinalizeCrossQK(int /*iteration_number*/, int /*context_decoding_len*/, int /*batch_size*/, int /*num_beams*/, int /*max_length*/, int /*num_alignment_heads*/, int /*frames_of_k*/, const Ort::Float16_t* /*cross_qk_buffer_data*/, Ort::Float16_t* /*cross_qk_output*/, int /*num_return_sequences*/, const int* /*cache_indir_data*/) { assert(false); }
virtual void GetAvailableMemory(size_t& /* free_bytes */, size_t& /* total_bytes */) { assert(false); }

// Allow each EP to shape the trivial init-session ProviderOptions used by EnsureDeviceOrtInit.
// The default does nothing; EPs that need global singletons configured (e.g. WebGPU) or
// allocator gating options (e.g. QNN) override this. `user_options` is the user-supplied entry
// for this provider from config.model.decoder.session_options.provider_options, or nullptr if
// the user did not provide one.
virtual void ShapeInitSessionProviderOptions(ProviderOptions& /*init_options*/,
const ProviderOptions* /*user_options*/) const {}

virtual void* GetCudaStream() {
assert(false);
return nullptr;
Expand Down
34 changes: 34 additions & 0 deletions src/webgpu/interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,40 @@ struct InterfaceImpl : DeviceInterface {

return true;
}

void ShapeInitSessionProviderOptions(Config::ProviderOptions& init_options,
const Config::ProviderOptions* user_options) const override {
if (!user_options) return;

// Forward only global/singleton WebGPU options to the init session so that the
// process-wide WebGpuContext singleton is initialized with the correct settings.
// Per-session options (preferredLayout, enableGraphCapture, sessionBufferPoolGenerations,
// enableInt64, multiRotaryCacheConcatOffset, forceCpuNodeNames, enablePIXCapture) are
// excluded because they are meaningless for the trivial initialization model.
// Keep this list in sync with ParseWebGpuContextConfig in
// onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc.
constexpr std::array<std::string_view, 14> kWebGpuGlobalOptions = {
"deviceId",
"webgpuInstance",
"webgpuDevice",
"dawnProcTable",
"dawnBackendType",
"powerPreference",
"validationMode",
"preserveDevice",
"maxStorageBufferBindingSize",
"maxNumPendingDispatches",
"storageBufferCacheMode",
"uniformBufferCacheMode",
"queryResolveBufferCacheMode",
"defaultBufferCacheMode",
};
for (const auto& opt : user_options->options) {
if (std::find(kWebGpuGlobalOptions.begin(), kWebGpuGlobalOptions.end(), opt.first) != kWebGpuGlobalOptions.end()) {
init_options.options.emplace_back(opt);
}
}
}
};

} // namespace WebGPU
Expand Down
Loading