Skip to content

Commit c27cc5d

Browse files
authored
[ET Device Support] CudaAllocator: device memory allocator for CUDA backend (#19747)
Cloned from #18477 due to a bot crash.
1 parent 12f62f2 commit c27cc5d

6 files changed

Lines changed: 395 additions & 30 deletions

File tree

backends/aoti/slim/core/storage.h

Lines changed: 14 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifdef CUDA_AVAILABLE
1414
#include <executorch/backends/aoti/slim/c10/cuda/Exception.h>
1515
#include <executorch/backends/aoti/slim/cuda/guard.h>
16+
#include <executorch/backends/cuda/runtime/cuda_allocator.h>
1617
#endif
1718

1819
#include <executorch/backends/aoti/slim/c10/core/Device.h>
@@ -107,9 +108,6 @@ struct DeviceTraits<c10::DeviceType::CUDA> {
107108
/// @param device The target CUDA device (used to get the stream).
108109
/// @return Pointer to allocated device memory.
109110
static void* allocate(size_t nbytes, const c10::Device& device) {
110-
// Get the current stream for this device (set by CUDAStreamGuard if any)
111-
// This follows PyTorch's pattern where the allocator assumes the caller
112-
// has already set the correct device via CUDAStreamGuard.
113111
auto stream_result =
114112
executorch::backends::cuda::getCurrentCUDAStream(device.index());
115113
ET_CHECK_MSG(
@@ -118,31 +116,23 @@ struct DeviceTraits<c10::DeviceType::CUDA> {
118116
static_cast<int>(device.index()));
119117

120118
cudaStream_t stream = stream_result.get();
121-
void* data = nullptr;
122-
ET_CUDA_CHECK(cudaMallocAsync(&data, nbytes, stream));
123-
return data;
119+
auto result = executorch::backends::cuda::CudaAllocator::allocate_async(
120+
nbytes, device.index(), stream);
121+
ET_CHECK_MSG(
122+
result.ok(),
123+
"CudaAllocator::allocate_async failed for %zu bytes on device %d",
124+
nbytes,
125+
static_cast<int>(device.index()));
126+
return result.get();
124127
}
125128

126-
/// Frees CUDA device memory on the current stream.
127-
/// @param ptr Pointer to device memory to free.
128129
static void free(void* ptr) {
129-
// Get the current stream for the current device
130-
// Currently all cuda slimtensors should be on the same device same stream,
131-
// so we can just use the stream on current device.
132-
// TODO(gasoonjia): add cuda stream as a member of MaybeOwningStorage to
133-
// support multiple devices.
134130
auto stream_result = executorch::backends::cuda::getCurrentCUDAStream(-1);
135131
ET_CHECK_MSG(stream_result.ok(), "Failed to get current CUDA stream");
136-
ET_CUDA_LOG_WARN(cudaFreeAsync(ptr, stream_result.get()));
132+
executorch::backends::cuda::CudaAllocator::deallocate_async(
133+
ptr, -1, stream_result.get());
137134
}
138135

139-
/// Copies memory between CPU and CUDA or CUDA and CUDA asynchronously.
140-
/// @param dst Destination pointer.
141-
/// @param src Source pointer.
142-
/// @param nbytes Number of bytes to copy.
143-
/// @param dst_device Destination device.
144-
/// @param src_device Source device.
145-
/// @param stream CUDA stream for async copy.
146136
static void memcpy_async(
147137
void* dst,
148138
const void* src,
@@ -151,7 +141,6 @@ struct DeviceTraits<c10::DeviceType::CUDA> {
151141
const c10::Device& src_device,
152142
cudaStream_t stream) {
153143
cudaMemcpyKind direction = cudaMemcpyDeviceToDevice;
154-
155144
if (src_device.is_cpu()) {
156145
direction = cudaMemcpyHostToDevice;
157146
} else if (dst_device.is_cpu()) {
@@ -164,23 +153,18 @@ struct DeviceTraits<c10::DeviceType::CUDA> {
164153
static_cast<int>(dst_device.index()));
165154
}
166155

167-
ET_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, direction, stream));
156+
auto err = executorch::backends::cuda::CudaAllocator::memcpy_async(
157+
dst, src, nbytes, direction, stream);
158+
ET_CHECK_MSG(err == executorch::runtime::Error::Ok, "memcpy_async failed");
168159
}
169160

170-
/// Copies memory between CPU and CUDA or CUDA and CUDA synchronously.
171-
/// @param dst Destination pointer.
172-
/// @param src Source pointer.
173-
/// @param nbytes Number of bytes to copy.
174-
/// @param dst_device Destination device.
175-
/// @param src_device Source device.
176161
static void memcpy(
177162
void* dst,
178163
const void* src,
179164
size_t nbytes,
180165
const c10::Device& dst_device,
181166
const c10::Device& src_device) {
182167
cudaMemcpyKind direction = cudaMemcpyDeviceToDevice;
183-
184168
if (src_device.is_cpu()) {
185169
direction = cudaMemcpyHostToDevice;
186170
} else if (dst_device.is_cpu()) {

backends/aoti/slim/core/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def define_common_targets():
1919
"//executorch/runtime/platform:platform",
2020
"//executorch/backends/aoti/slim/c10/cuda:exception",
2121
"//executorch/backends/aoti/slim/cuda:guard",
22+
"//executorch/backends/cuda/runtime:cuda_allocator",
2223
],
2324
)
2425

backends/cuda/runtime/TARGETS

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,33 @@ runtime.cxx_library(
7474
],
7575
)
7676

77+
runtime.cxx_library(
78+
name = "cuda_allocator",
79+
srcs = [
80+
"cuda_allocator.cpp",
81+
],
82+
headers = [
83+
"cuda_allocator.h",
84+
],
85+
# @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
86+
link_whole = True,
87+
supports_python_dlopen = True,
88+
visibility = ["PUBLIC"],
89+
exported_deps = [
90+
"//executorch/runtime/core:device_allocator",
91+
],
92+
deps = [
93+
"//executorch/runtime/platform:platform",
94+
],
95+
nvcc_flags = get_nvcc_arch_args() + [
96+
"-_NVCC_HOST_COMPILER_FLAG_",
97+
"gcc",
98+
],
99+
external_deps = [
100+
("cuda", None, "cuda-lazy"),
101+
],
102+
)
103+
77104
runtime.cxx_library(
78105
name = "cuda_backend",
79106
srcs = [
@@ -92,6 +119,8 @@ runtime.cxx_library(
92119
deps = [
93120
":cuda_platform",
94121
":runtime_shims",
122+
":cuda_allocator",
123+
":cuda_platform",
95124
"//executorch/backends/aoti:aoti_common_slim",
96125
"//executorch/backends/aoti/slim/core:slimtensor",
97126
"//executorch/backends/aoti/slim/factory:empty",

0 commit comments

Comments
 (0)