1313#ifdef CUDA_AVAILABLE
1414#include < executorch/backends/aoti/slim/c10/cuda/Exception.h>
1515#include < executorch/backends/aoti/slim/cuda/guard.h>
16+ #include < executorch/backends/cuda/runtime/cuda_allocator.h>
1617#endif
1718
1819#include < executorch/backends/aoti/slim/c10/core/Device.h>
@@ -107,9 +108,6 @@ struct DeviceTraits<c10::DeviceType::CUDA> {
107108 // / @param device The target CUDA device (used to get the stream).
108109 // / @return Pointer to allocated device memory.
109110 static void * allocate (size_t nbytes, const c10::Device& device) {
110- // Get the current stream for this device (set by CUDAStreamGuard if any)
111- // This follows PyTorch's pattern where the allocator assumes the caller
112- // has already set the correct device via CUDAStreamGuard.
113111 auto stream_result =
114112 executorch::backends::cuda::getCurrentCUDAStream (device.index ());
115113 ET_CHECK_MSG (
@@ -118,31 +116,23 @@ struct DeviceTraits<c10::DeviceType::CUDA> {
118116 static_cast <int >(device.index ()));
119117
120118 cudaStream_t stream = stream_result.get ();
121- void * data = nullptr ;
122- ET_CUDA_CHECK (cudaMallocAsync (&data, nbytes, stream));
123- return data;
119+ auto result = executorch::backends::cuda::CudaAllocator::allocate_async (
120+ nbytes, device.index (), stream);
121+ ET_CHECK_MSG (
122+ result.ok (),
123+ " CudaAllocator::allocate_async failed for %zu bytes on device %d" ,
124+ nbytes,
125+ static_cast <int >(device.index ()));
126+ return result.get ();
124127 }
125128
126- // / Frees CUDA device memory on the current stream.
127- // / @param ptr Pointer to device memory to free.
128129 static void free (void * ptr) {
129- // Get the current stream for the current device
130- // Currently all cuda slimtensors should be on the same device same stream,
131- // so we can just use the stream on current device.
132- // TODO(gasoonjia): add cuda stream as a member of MaybeOwningStorage to
133- // support multiple devices.
134130 auto stream_result = executorch::backends::cuda::getCurrentCUDAStream (-1 );
135131 ET_CHECK_MSG (stream_result.ok (), " Failed to get current CUDA stream" );
136- ET_CUDA_LOG_WARN (cudaFreeAsync (ptr, stream_result.get ()));
132+ executorch::backends::cuda::CudaAllocator::deallocate_async (
133+ ptr, -1 , stream_result.get ());
137134 }
138135
139- // / Copies memory between CPU and CUDA or CUDA and CUDA asynchronously.
140- // / @param dst Destination pointer.
141- // / @param src Source pointer.
142- // / @param nbytes Number of bytes to copy.
143- // / @param dst_device Destination device.
144- // / @param src_device Source device.
145- // / @param stream CUDA stream for async copy.
146136 static void memcpy_async (
147137 void * dst,
148138 const void * src,
@@ -151,7 +141,6 @@ struct DeviceTraits<c10::DeviceType::CUDA> {
151141 const c10::Device& src_device,
152142 cudaStream_t stream) {
153143 cudaMemcpyKind direction = cudaMemcpyDeviceToDevice;
154-
155144 if (src_device.is_cpu ()) {
156145 direction = cudaMemcpyHostToDevice;
157146 } else if (dst_device.is_cpu ()) {
@@ -164,23 +153,18 @@ struct DeviceTraits<c10::DeviceType::CUDA> {
164153 static_cast <int >(dst_device.index ()));
165154 }
166155
167- ET_CUDA_CHECK (cudaMemcpyAsync (dst, src, nbytes, direction, stream));
156+ auto err = executorch::backends::cuda::CudaAllocator::memcpy_async (
157+ dst, src, nbytes, direction, stream);
158+ ET_CHECK_MSG (err == executorch::runtime::Error::Ok, " memcpy_async failed" );
168159 }
169160
170- // / Copies memory between CPU and CUDA or CUDA and CUDA synchronously.
171- // / @param dst Destination pointer.
172- // / @param src Source pointer.
173- // / @param nbytes Number of bytes to copy.
174- // / @param dst_device Destination device.
175- // / @param src_device Source device.
176161 static void memcpy (
177162 void * dst,
178163 const void * src,
179164 size_t nbytes,
180165 const c10::Device& dst_device,
181166 const c10::Device& src_device) {
182167 cudaMemcpyKind direction = cudaMemcpyDeviceToDevice;
183-
184168 if (src_device.is_cpu ()) {
185169 direction = cudaMemcpyHostToDevice;
186170 } else if (dst_device.is_cpu ()) {
0 commit comments