Skip to content

Commit b6c2a9f

Browse files
wliuyxmeta-codesync[bot]
authored andcommitted
Thread kernel_registry through Module::load_method (#19641)
Summary: Pull Request resolved: #19641 D98080033 added a method-scoped kernel registry to Program::load_method and Method, allowing callers to override specific kernels for a single method without affecting the global registry. However, the Module facade class did not expose this parameter, forcing consumers to bypass Module and manage memory manually. This adds an optional `Span<const Kernel> kernel_registry` parameter (defaulting to empty) to Module::load_method and Module::load_forward, and forwards it to Program::load_method. Existing callers are completely unaffected — the default empty span causes the runtime to fall back to the global kernel registry, exactly as before. Differential Revision: D104433196
1 parent a142873 commit b6c2a9f

2 files changed

Lines changed: 16 additions & 5 deletions

File tree

extension/module/module.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ namespace extension {
2020
namespace ET_MODULE_NAMESPACE {
2121

2222
using ET_MERGED_DATA_MAP_NAMESPACE::MergedDataMap;
23+
using ET_RUNTIME_NAMESPACE::Kernel;
2324
using ET_RUNTIME_NAMESPACE::MethodMeta;
2425
using ET_RUNTIME_NAMESPACE::Program;
2526

@@ -365,7 +366,8 @@ runtime::Error Module::load_method(
365366
const std::string& method_name,
366367
runtime::HierarchicalAllocator* planned_memory,
367368
torch::executor::EventTracer* event_tracer,
368-
const LoadBackendOptionsMap* backend_options) {
369+
const LoadBackendOptionsMap* backend_options,
370+
std::vector<Kernel> kernel_registry) {
369371
if (!is_method_loaded(method_name)) {
370372
ET_CHECK_OK_OR_RETURN_ERROR(load());
371373

@@ -402,12 +404,16 @@ runtime::Error Module::load_method(
402404

403405
method_holder.memory_manager = std::make_unique<runtime::MemoryManager>(
404406
memory_allocator_.get(), planned_memory, temp_allocator_.get());
407+
method_holder.kernel_registry = std::move(kernel_registry);
405408
auto res_method = program_->load_method(
406409
method_name.c_str(),
407410
method_holder.memory_manager.get(),
408411
event_tracer ? event_tracer : this->event_tracer(),
409412
merged_data_map_.get(),
410-
effective_backend_options);
413+
effective_backend_options,
414+
runtime::Span<const Kernel>(
415+
method_holder.kernel_registry.data(),
416+
method_holder.kernel_registry.size()));
411417
if (!res_method.ok()) {
412418
return res_method.error();
413419
}

extension/module/module.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
namespace executorch {
2626
namespace extension {
2727

28+
using ET_RUNTIME_NAMESPACE::Kernel;
2829
using ET_RUNTIME_NAMESPACE::Method;
2930
using ET_RUNTIME_NAMESPACE::MethodMeta;
3031
using ET_RUNTIME_NAMESPACE::NamedDataMap;
@@ -255,7 +256,8 @@ class Module {
255256
const std::string& method_name,
256257
runtime::HierarchicalAllocator* planned_memory = nullptr,
257258
torch::executor::EventTracer* event_tracer = nullptr,
258-
const LoadBackendOptionsMap* backend_options = nullptr);
259+
const LoadBackendOptionsMap* backend_options = nullptr,
260+
std::vector<Kernel> kernel_registry = {});
259261

260262
ET_DEPRECATED ET_NODISCARD runtime::Error inline load_method(
261263
const std::string& method_name,
@@ -303,9 +305,11 @@ class Module {
303305
ET_NODISCARD inline runtime::Error load_forward(
304306
runtime::HierarchicalAllocator* planned_memory = nullptr,
305307
torch::executor::EventTracer* event_tracer = nullptr,
306-
const LoadBackendOptionsMap* backend_options = nullptr) {
308+
const LoadBackendOptionsMap* backend_options = nullptr,
309+
std::vector<Kernel> kernel_registry = {}) {
307310
return load_method(
308-
"forward", planned_memory, event_tracer, backend_options);
311+
"forward", planned_memory, event_tracer, backend_options,
312+
std::move(kernel_registry));
309313
}
310314

311315
ET_DEPRECATED ET_NODISCARD inline runtime::Error load_forward(
@@ -698,6 +702,7 @@ class Module {
698702
std::unique_ptr<PlannedMemory> planned_memory;
699703
std::unique_ptr<runtime::MemoryManager> memory_manager;
700704
std::unique_ptr<Method> method;
705+
std::vector<Kernel> kernel_registry;
701706
};
702707

703708
std::string file_path_;

0 commit comments

Comments
 (0)