Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,20 @@
// Portions of this file consist of AI generated content.
#pragma once

// Do not include "smartptrs.h" or "generators.h" from this header: smartptrs.h
// includes this file, so the reverse direction would form a cycle.
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <string_view>
#include <unordered_map>
#include <utility>
#include <vector>
#include "filesystem.h"
#include "models/onnxruntime_api.h"
#include "span.h"

namespace Generators {

struct RuntimeSettings;
Expand Down
42 changes: 11 additions & 31 deletions src/models/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -411,39 +411,19 @@ void EnsureDeviceOrtInit(DeviceInterface& device, const Config& config) {
const char* provider_name = device_type_names[static_cast<int>(type)];
Config::ProviderOptions init_session_provider_options{provider_name, {}};

// Forward only global/singleton WebGPU options to the init session so that the
// process-wide WebGpuContext singleton is initialized with the correct settings.
// Per-session options (enableGraphCapture, enableInt64, etc.) are excluded
// because they are meaningless for the trivial initialization model.
if (type == DeviceType::WEBGPU) {
constexpr std::array<std::string_view, 7> kWebGpuGlobalOptions = {
"deviceId",
"webgpuInstance",
"webgpuDevice",
"dawnBackendType",
"powerPreference",
"validationMode",
"dawnProcTable",
};
for (const auto& user_po : config.model.decoder.session_options.provider_options) {
if (user_po.name == provider_name) {
for (const auto& opt : user_po.options) {
if (std::find(kWebGpuGlobalOptions.begin(), kWebGpuGlobalOptions.end(), opt.first) != kWebGpuGlobalOptions.end()) {
init_session_provider_options.options.emplace_back(opt);
}
}
init_session_provider_options.device_filtering_options = user_po.device_filtering_options;
break;
}
}
}
// Look up the user-supplied provider options entry for this provider (if any),
// then let the EP shape the trivial-model init session options. Most EPs use
// the default no-op; WebGPU forwards global/singleton options and QNN injects
// the QnnHtpShared allocator gating option.
const auto& user_provider_options_list = config.model.decoder.session_options.provider_options;
const auto user_provider_options_it = std::find_if(
user_provider_options_list.begin(), user_provider_options_list.end(),
[provider_name](const Config::ProviderOptions& po) { return po.name == provider_name; });
const Config::ProviderOptions* user_provider_options =
user_provider_options_it != user_provider_options_list.end() ? &*user_provider_options_it : nullptr;
device.ShapeInitSessionProviderOptions(init_session_provider_options, user_provider_options);
Comment thread
qjia7 marked this conversation as resolved.

provider_options_list.emplace_back(std::move(init_session_provider_options));
// QnnHtpShared is a special case. This allocator is only made available when the provider option
// 'enable_htp_shared_memory_allocator' is set to 1.
if (type == DeviceType::QNN) {
provider_options_list.back().options.emplace_back("enable_htp_shared_memory_allocator", "1");
}
const std::vector<std::string> providers{device_type_names[static_cast<int>(type)]};
SetProviderSessionOptions(*session_options, providers, provider_options_list, true, config);
session_options->SetLogSeverityLevel(ORT_LOGGING_LEVEL_ERROR); // Errors only here, as warnings are not useful to the user
Expand Down
7 changes: 7 additions & 0 deletions src/qnn/interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,13 @@ struct InterfaceImpl : DeviceInterface {
std::unique_ptr<Search> CreateBeam(const GeneratorParams& params) override { return std::make_unique<BeamSearch_Cpu>(params); }

void Synchronize() override {} // Nothing to do

void ShapeInitSessionProviderOptions(Config::ProviderOptions& init_options,
const Config::ProviderOptions* /*user_options*/) const override {
// QnnHtpShared is a special case. This allocator is only made available when the provider
// option 'enable_htp_shared_memory_allocator' is set to 1.
init_options.options.emplace_back("enable_htp_shared_memory_allocator", "1");
}
};

} // namespace QNN
Expand Down
9 changes: 9 additions & 0 deletions src/smartptrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <memory>
#include "span.h"
#include "models/onnxruntime_api.h" // for ONNXTensorElementDataType
#include "config.h" // for Config::ProviderOptions
Comment thread
qjia7 marked this conversation as resolved.
Outdated
namespace Ort {
struct Allocator;
}
Expand Down Expand Up @@ -133,6 +134,14 @@ struct DeviceInterface {
virtual void FinalizeCrossQK(int /*iteration_number*/, int /*context_decoding_len*/, int /*batch_size*/, int /*num_beams*/, int /*max_length*/, int /*num_alignment_heads*/, int /*frames_of_k*/, const Ort::Float16_t* /*cross_qk_buffer_data*/, Ort::Float16_t* /*cross_qk_output*/, int /*num_return_sequences*/, const int* /*cache_indir_data*/) { assert(false); }
virtual void GetAvailableMemory(size_t& /* free_bytes */, size_t& /* total_bytes */) { assert(false); }

// Allow each EP to shape the trivial init-session ProviderOptions used by EnsureDeviceOrtInit.
// The default does nothing; EPs that need global singletons configured (e.g. WebGPU) or
// allocator gating options (e.g. QNN) override this. `user_options` is the user-supplied entry
// for this provider from config.model.decoder.session_options.provider_options, or nullptr if
// the user did not provide one.
virtual void ShapeInitSessionProviderOptions(Config::ProviderOptions& /*init_options*/,
const Config::ProviderOptions* /*user_options*/) const {}

virtual void* GetCudaStream() {
assert(false);
return nullptr;
Expand Down
35 changes: 35 additions & 0 deletions src/webgpu/interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,41 @@ struct InterfaceImpl : DeviceInterface {

return true;
}

void ShapeInitSessionProviderOptions(Config::ProviderOptions& init_options,
const Config::ProviderOptions* user_options) const override {
if (!user_options) return;

// Forward only global/singleton WebGPU options to the init session so that the
// process-wide WebGpuContext singleton is initialized with the correct settings.
// Per-session options (preferredLayout, enableGraphCapture, sessionBufferPoolGenerations,
// enableInt64, multiRotaryCacheConcatOffset, forceCpuNodeNames, enablePIXCapture) are
// excluded because they are meaningless for the trivial initialization model.
// Keep this list in sync with ParseWebGpuContextConfig in
// onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc.
constexpr std::array<std::string_view, 14> kWebGpuGlobalOptions = {
"deviceId",
"webgpuInstance",
"webgpuDevice",
"dawnProcTable",
"dawnBackendType",
"powerPreference",
"validationMode",
"preserveDevice",
"maxStorageBufferBindingSize",
"maxNumPendingDispatches",
"storageBufferCacheMode",
"uniformBufferCacheMode",
"queryResolveBufferCacheMode",
"defaultBufferCacheMode",
};
for (const auto& opt : user_options->options) {
if (std::find(kWebGpuGlobalOptions.begin(), kWebGpuGlobalOptions.end(), opt.first) != kWebGpuGlobalOptions.end()) {
init_options.options.emplace_back(opt);
}
}
init_options.device_filtering_options = user_options->device_filtering_options;
}
};

} // namespace WebGPU
Expand Down
Loading