Skip to content

Commit a9c517c

Browse files
committed
cuda: add per-session mutable state rebinding
Local agent serving needs to host multiple logical conversations on one CUDA-resident model without multiplying the model weights. Loading one AOTI module per conversation is not viable for large local models, while sharing the default mutable state across conversations would let KV/recurrent/conv buffers bleed between users. This adds the CUDA-private foundation for separating those concerns: weights remain owned by the loaded AOTI container, while mutable buffer FQNs can be registered as per-session state and rebound before execution. The path is fail-closed and dormant until a model opts in by creating a mutable-state context and validating coverage, so existing CUDA models keep their current behavior. The branch also wires the new source and unit coverage into both Buck and CMake so the primitive can land independently before any model-specific engine consumes it.
1 parent d7ca5db commit a9c517c

6 files changed

Lines changed: 829 additions & 1 deletion

File tree

backends/cuda/CMakeLists.txt

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,9 @@ install(
184184
)
185185

186186
# CUDA backend implementation
187-
set(_aoti_cuda_backend_sources runtime/cuda_backend.cpp)
187+
set(_aoti_cuda_backend_sources runtime/cuda_backend.cpp
188+
runtime/cuda_mutable_state.cpp
189+
)
188190
if(_cuda_is_msvc_toolchain)
189191
# MSVC links aoti_cuda_backend into portable_lib without relying on C++
190192
# symbols exported from aoti_cuda_shims.dll.
@@ -236,3 +238,13 @@ install(
236238
EXPORT ExecuTorchTargets
237239
DESTINATION lib
238240
)
241+
242+
if(BUILD_TESTING)
243+
include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake)
244+
245+
et_cxx_test(
246+
test_cuda_mutable_state SOURCES runtime/test/test_cuda_mutable_state.cpp
247+
EXTRA_LIBS aoti_cuda_backend
248+
)
249+
target_compile_definitions(test_cuda_mutable_state PRIVATE CUDA_AVAILABLE=1)
250+
endif()

backends/cuda/runtime/TARGETS

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest")
23
load("//tools/build/buck:nvcc_flags.bzl", "get_nvcc_arch_args")
34

45
oncall("executorch")
@@ -105,9 +106,11 @@ runtime.cxx_library(
105106
name = "cuda_backend",
106107
srcs = [
107108
"cuda_backend.cpp",
109+
"cuda_mutable_state.cpp",
108110
],
109111
headers = [
110112
"cuda_delegate_handle.h",
113+
"cuda_mutable_state.h",
111114
],
112115
# @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
113116
link_whole = True,
@@ -135,3 +138,19 @@ runtime.cxx_library(
135138
("cuda", None, "cuda-lazy"),
136139
],
137140
)
141+
142+
cpp_unittest(
143+
name = "test_cuda_mutable_state",
144+
srcs = [
145+
"test/test_cuda_mutable_state.cpp",
146+
],
147+
deps = [
148+
":cuda_backend",
149+
"//executorch/runtime/core:core",
150+
"//executorch/runtime/platform:platform",
151+
],
152+
external_deps = [
153+
("cuda", None, "cuda-lazy"),
154+
],
155+
preprocessor_flags = ["-DCUDA_AVAILABLE=1"],
156+
)

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
#include <executorch/backends/aoti/utils.h>
4545
#include <executorch/backends/cuda/runtime/cuda_allocator.h>
4646
#include <executorch/backends/cuda/runtime/cuda_delegate_handle.h>
47+
#include <executorch/backends/cuda/runtime/cuda_mutable_state.h>
4748
#include <executorch/backends/cuda/runtime/platform/platform.h>
4849
#include <executorch/backends/cuda/runtime/shims/memory.h>
4950
#include <executorch/backends/cuda/runtime/utils.h>
@@ -436,6 +437,10 @@ class ET_EXPERIMENTAL CudaBackend final
436437
kCudaGraphWarmupSteps);
437438
}
438439

440+
// Record whether this AOTI build exposes the constant-management symbols
441+
// needed for per-session mutable-buffer rebinding (CUDA V2 multi-session).
442+
mutable_state_note_handle(handle);
443+
439444
return (DelegateHandle*)handle; // Return the handle post-processing
440445
}
441446

@@ -539,6 +544,12 @@ class ET_EXPERIMENTAL CudaBackend final
539544
}
540545
}
541546

547+
// CUDA V2 multi-session: if a logical session is active on this thread,
548+
// rebind this container's mutable constants (KV/conv/recurrent) to the
549+
// session's own GPU buffers before running. No-op for
550+
// single-session/legacy.
551+
ET_CHECK_OK_OR_RETURN_ERROR(mutable_state_rebind_for_execute(handle));
552+
542553
// ---------------------------------------------------------------
543554
// CUDA graph REPLAY path — skip all tensor setup and just replay
544555
// ---------------------------------------------------------------
@@ -826,6 +837,8 @@ class ET_EXPERIMENTAL CudaBackend final
826837
}
827838
cuda::CudaDelegateHandle* handle = (cuda::CudaDelegateHandle*)handle_;
828839

840+
mutable_state_forget_handle(handle);
841+
829842
// The CUDA stream is managed by shared_ptr in the handle.
830843
// It will be automatically destroyed when the last handle using it
831844
// is destroyed. Just reset our reference.

0 commit comments

Comments
 (0)