-
Notifications
You must be signed in to change notification settings - Fork 2
Make a first pass at using NVSHMEM4Py for host-side library management, etc. #4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from 5 commits
36e4eb9
82ba99c
959f467
4706f16
1b2776c
1fd2280
b548347
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,6 +31,8 @@ | |
|
|
||
| import torch | ||
| import torch.distributed as dist | ||
| from nvshmem import core as nvshmem | ||
| from cuda.core.experimental import Device | ||
|
|
||
| try: | ||
| from nvrar import nvshmem_comm_cuda, resolve_params | ||
|
|
@@ -381,12 +383,22 @@ def main(): | |
| # Initialize NVSHMEM communicator if available | ||
| comm_wrapper = None | ||
| if nvshmem_comm_cuda is not None: | ||
| uid_bytes = nvshmem_comm_cuda.NVSHMEMCommWrapper.get_unique_id_bytes() | ||
| uid_gpu = uid_bytes.to(device) | ||
| dist.broadcast(uid_gpu, src=0) | ||
| # Set device current | ||
| cuda_dev = Device(local_device_idx) | ||
| cuda_dev.set_current() | ||
| # Rank 0 obtains UID; broadcast via object list | ||
| uniqueid = nvshmem.get_unique_id(empty=True) | ||
| if rank == 0: | ||
| uniqueid = nvshmem.get_unique_id() | ||
| obj = [uniqueid] | ||
| else: | ||
| obj = [None] | ||
| dist.broadcast_object_list(obj, src=0) | ||
| dist.barrier() | ||
| uid_cpu = uid_gpu.to("cpu") | ||
| comm_wrapper = nvshmem_comm_cuda.NVSHMEMCommWrapper(rank, world_size, local_device_idx, uid_cpu) | ||
| # Initialize nvshmem4py | ||
| nvshmem.init(device=cuda_dev, uid=obj[0], rank=rank, nranks=world_size, initializer_method="uid") | ||
| # Construct wrapper without UID (nvshmem already initialized) | ||
| comm_wrapper = nvshmem_comm_cuda.NVSHMEMCommWrapper(rank, world_size, local_device_idx) | ||
|
|
||
| # Use default stream | ||
| stream = torch.cuda.Stream(device=local_device_idx) | ||
|
|
@@ -411,7 +423,9 @@ def main(): | |
| nvrar_tensor_id = None | ||
| algorithm = "recursive" | ||
| if comm_wrapper is not None: | ||
| nvrar_tensor, nvrar_tensor_id = comm_wrapper.allocate_tensor(num_elems, dtype, device, nvshmem_comm_cuda.Protocol.LL8) | ||
| # Allocate symmetric tensor via nvshmem4py and register with wrapper | ||
| nvrar_tensor = nvshmem.tensor((num_elems,), dtype=dtype) | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is the first major difference. I couldn't think of a good way to handle the
|
||
| nvrar_tensor_id = comm_wrapper.register_tensor(nvrar_tensor, nvshmem_comm_cuda.Protocol.LL8) | ||
|
|
||
| # Choose kernel params | ||
| if params_resolver is not None: | ||
|
|
@@ -513,7 +527,8 @@ def main(): | |
|
|
||
| # Cleanup | ||
| if comm_wrapper is not None and nvrar_tensor_id is not None: | ||
| comm_wrapper.free_tensor(nvrar_tensor_id) | ||
| comm_wrapper.deregister_tensor(nvrar_tensor_id) | ||
| nvshmem.free_tensor(nvrar_tensor) | ||
|
|
||
| dist.barrier() | ||
| if dist.is_initialized(): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,6 +6,8 @@ | |
|
|
||
| import torch | ||
| import torch.distributed as dist | ||
| import nvshmem.core | ||
| from cuda.core.experimental import Device | ||
| from . import nvshmem_comm_cuda | ||
|
|
||
| class NVRARCommunicator: | ||
|
|
@@ -17,16 +19,34 @@ def __init__(self, process_group: torch.distributed.ProcessGroup): | |
| return | ||
|
|
||
| device = torch.cuda.current_device() | ||
|
|
||
| unique_id = nvshmem_comm_cuda.NVSHMEMCommWrapper.get_unique_id_bytes() | ||
| uid_gpu = unique_id.to("cuda") | ||
| ranks = torch.distributed.get_process_group_ranks(process_group) | ||
| torch.distributed.broadcast(uid_gpu, src=ranks[0], group=process_group) | ||
| torch.distributed.barrier(group=process_group) | ||
|
|
||
| unique_id = uid_gpu.to("cpu") | ||
|
|
||
| self.comm_wrapper = nvshmem_comm_cuda.NVSHMEMCommWrapper(rank, nranks, device, unique_id) | ||
| cuda_dev = Device(device) | ||
| # This should be idempotent | ||
| cuda_dev.set_current() | ||
| stream = torch.cuda.current_stream() | ||
|
|
||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Here's the same boilerplate |
||
| # Fetch NVSHMEM unique ID via nvshmem4py and broadcast via torch.distributed | ||
| uniqueid = nvshmem.core.get_unique_id(empty=True) | ||
| if rank == 0: | ||
| # Rank 0 gets a real uniqueid | ||
| uniqueid = nvshmem.core.get_unique_id() | ||
| broadcast_objects = [uniqueid] | ||
| else: | ||
| broadcast_objects = [None] | ||
|
|
||
| # We use torch.distributed.broadcast_object_list to send the UID to all ranks | ||
| dist.broadcast_object_list(broadcast_objects, src=0, group=process_group) | ||
| dist.barrier(group=process_group) | ||
|
|
||
| nvshmem.core.init( | ||
| device=cuda_dev, | ||
| uid=broadcast_objects[0], | ||
| rank=rank, | ||
| nranks=nranks, | ||
| initializer_method="uid", | ||
| ) | ||
|
|
||
|
|
||
| self.comm_wrapper = nvshmem_comm_cuda.NVSHMEMCommWrapper(rank, nranks, device) | ||
| print(f"NVRARCommunicator created for process group {process_group} with rank {rank} and nranks {nranks}") | ||
|
|
||
| self.comm_wrapper.set_kernel_params(4, 256, 16384) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -72,6 +72,33 @@ target_include_directories(nvshmem_comm_cuda PRIVATE | |
| ${CMAKE_CURRENT_SOURCE_DIR}/include # Our headers | ||
| ) | ||
|
|
||
| # 9.1) CUDA 13 CCCL headers live under <CUDA_ROOT>/include/cccl | ||
| # Try to locate CUDA root from CUDA_HOME/CUDA_PATH or from nvcc | ||
| # Allow user override via -DCUDA_CCCL_INCLUDE_DIR | ||
| set(CUDA_CCCL_INCLUDE_DIR "" CACHE PATH "Path to CUDA CCCL include directory (contains cuda/std)") | ||
| set(_CUDA_ROOT "") | ||
| if(DEFINED ENV{CUDA_HOME}) | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is hacky and terrible and there is a better way to do it. In NVSHMEM's source, we handle it like this: https://github.com/NVIDIA/nvshmem/blob/2d7d25f0816235e3c2b51779571ec032606ea0dd/src/device/CMakeLists.txt#L188 |
||
| set(_CUDA_ROOT $ENV{CUDA_HOME}) | ||
| elseif(DEFINED ENV{CUDA_PATH}) | ||
| set(_CUDA_ROOT $ENV{CUDA_PATH}) | ||
| elseif(CMAKE_CUDA_COMPILER) | ||
| get_filename_component(_CUDA_BIN_DIR "${CMAKE_CUDA_COMPILER}" DIRECTORY) | ||
| get_filename_component(_CUDA_ROOT "${_CUDA_BIN_DIR}" DIRECTORY) | ||
| endif() | ||
| if(CUDA_CCCL_INCLUDE_DIR) | ||
| if(EXISTS "${CUDA_CCCL_INCLUDE_DIR}/cuda/std/tuple" OR EXISTS "${CUDA_CCCL_INCLUDE_DIR}/cuda/std") | ||
| message(STATUS "Using user-provided CUDA CCCL include path: ${CUDA_CCCL_INCLUDE_DIR}") | ||
| target_include_directories(nvshmem_comm_cuda PRIVATE "${CUDA_CCCL_INCLUDE_DIR}") | ||
| else() | ||
| message(WARNING "CUDA_CCCL_INCLUDE_DIR does not contain cuda/std: ${CUDA_CCCL_INCLUDE_DIR}") | ||
| endif() | ||
| elseif(_CUDA_ROOT AND EXISTS "${_CUDA_ROOT}/include/cccl") | ||
| message(STATUS "Adding CUDA CCCL include path: ${_CUDA_ROOT}/include/cccl") | ||
| target_include_directories(nvshmem_comm_cuda PRIVATE "${_CUDA_ROOT}/include/cccl") | ||
| else() | ||
| message(STATUS "CUDA CCCL include path not found; set -DCUDA_CCCL_INCLUDE_DIR=<path to .../include/cccl>") | ||
| endif() | ||
|
|
||
| # Torch Python shim: prefer target if present, else raw library | ||
| set(_need_torch_python ON) | ||
| if(TARGET Torch::Python) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,9 +30,11 @@ class IColl { | |
| virtual void init(int num_blocks, int threads_per_block, | ||
| size_t chunk_size) = 0; | ||
|
|
||
| virtual std::tuple<torch::Tensor, uint64_t> allocate_tensor( | ||
| size_t size, torch::Dtype dt, torch::Device dev) = 0; | ||
| virtual void free_tensor(uint64_t id) = 0; | ||
| // Register an externally-allocated symmetric tensor (e.g., via nvshmem4py) | ||
| // Returns a newly assigned tensor id | ||
| virtual uint64_t register_external_tensor(torch::Tensor& t) = 0; | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here's the renaming I mentioned above. |
||
| // Deregister a previously registered tensor without freeing memory | ||
| virtual void deregister_tensor(uint64_t id) = 0; | ||
|
|
||
| virtual void dispatch_allreduce_preallocated(torch::Tensor& t, uint64_t id, | ||
| cudaStream_t s, | ||
|
|
@@ -47,10 +49,7 @@ template <class Derived> | |
| class CollBase : public IColl { | ||
| public: | ||
| ~CollBase() noexcept override { | ||
| for (auto [id, ptr] : allocated_tensors_) { | ||
| nvshmem_free(ptr); | ||
| } | ||
| allocated_tensors_.clear(); | ||
| // Nothing to free: external tensors are owned by nvshmem4py | ||
| } | ||
|
|
||
| void init(int num_blocks, int threads_per_block, size_t chunk_size) override { | ||
|
|
@@ -67,34 +66,23 @@ class CollBase : public IColl { | |
| derived()->initialize(num_blocks, threads_per_block, chunk_size); | ||
| } | ||
|
|
||
| std::tuple<torch::Tensor, uint64_t> allocate_tensor( | ||
| size_t size, torch::Dtype dt, torch::Device dev) override { | ||
| void* ptr = nvshmem_malloc(size * torch::elementSize(dt)); | ||
|
|
||
| if (!ptr) { | ||
| throw std::runtime_error("Failed to allocate tensor memory"); | ||
| uint64_t register_external_tensor(torch::Tensor& t) override { | ||
| // Accept a pre-allocated symmetric tensor; we do not own its memory | ||
| void* ptr = t.data_ptr(); | ||
| if (ptr == nullptr) { | ||
| throw std::runtime_error("register_external_tensor: null data_ptr"); | ||
| } | ||
|
|
||
| uint64_t id = next_id_.fetch_add(1); | ||
| allocated_tensors_[id] = ptr; | ||
|
|
||
| // Register the tensor with the derived class which can maintain its own | ||
| // scratch memory | ||
| // Let derived class register scratch/meta using size/dtype/device | ||
| const size_t size = static_cast<size_t>(t.numel()); | ||
| const torch::Dtype dt = t.scalar_type(); | ||
| const torch::Device dev = t.device(); | ||
| derived()->register_tensor(id, size, dt, dev); | ||
|
|
||
| auto tensor = torch::from_blob(ptr, {static_cast<int64_t>(size)}, | ||
| torch::dtype(dt).device(dev)); | ||
| return std::make_tuple(tensor, id); | ||
| return id; | ||
| } | ||
|
|
||
| void free_tensor(uint64_t id) override { | ||
| if (allocated_tensors_.find(id) == allocated_tensors_.end()) { | ||
| throw std::runtime_error("Invalid tensor ID"); | ||
| } | ||
| nvshmem_free(allocated_tensors_[id]); | ||
| void deregister_tensor(uint64_t id) override { | ||
| derived()->deregister_tensor(id); | ||
|
|
||
| allocated_tensors_.erase(id); | ||
| } | ||
|
|
||
| void dispatch_allreduce_preallocated(torch::Tensor& t, uint64_t id, | ||
|
|
@@ -124,9 +112,6 @@ class CollBase : public IColl { | |
| Derived* derived() { return static_cast<Derived*>(this); } | ||
|
|
||
| protected: | ||
| // Memory Pools for allocated tensors | ||
| std::unordered_map<uint64_t, void*> allocated_tensors_; | ||
|
|
||
| // Next ID for allocated tensors | ||
| std::atomic<uint64_t> next_id_; | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -303,16 +303,18 @@ void RecursiveLL8Coll::register_tensor(uint64_t id, size_t size, | |
| nvshmem_barrier_all(); | ||
|
|
||
| // Create signal tensors that hold the peer sequence numbers to wait on | ||
| uint64_t* seq_num_signal = | ||
| (uint64_t*)nvshmem_calloc(steps_inter_, sizeof(uint64_t)); | ||
| if (!seq_num_signal) { | ||
| throw std::runtime_error("Failed to allocate signal memory"); | ||
| uint64_t* seq_num_signal = nullptr; | ||
| // TODO: | ||
| if (steps_inter_ > 0) { | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is just here so my tests would pass on 1 node. If it's 1 node but we don't have this check, the calloc will fail because steps_inter_ is 0 so we allocate nothing. |
||
| seq_num_signal = (uint64_t*)nvshmem_calloc(steps_inter_, sizeof(uint64_t)); | ||
| if (!seq_num_signal) { | ||
| throw std::runtime_error("Failed to allocate signal memory"); | ||
| } | ||
| init_signal_kernel<<<1, steps_inter_>>>(seq_num_signal, steps_inter_); | ||
| cudaDeviceSynchronize(); | ||
| nvshmem_barrier_all(); | ||
| } | ||
|
|
||
| init_signal_kernel<<<1, steps_inter_>>>(seq_num_signal, steps_inter_); | ||
| cudaDeviceSynchronize(); | ||
| nvshmem_barrier_all(); | ||
|
|
||
| allocated_scratch_send_[id] = send_scratch; | ||
| allocated_scratch_recv_[id] = recv_scratch; | ||
| seq_nums_[id] = seq_num; | ||
|
|
@@ -321,13 +323,17 @@ void RecursiveLL8Coll::register_tensor(uint64_t id, size_t size, | |
| } | ||
|
|
||
| void RecursiveLL8Coll::deregister_tensor(uint64_t id) { | ||
| // TODO: Implement | ||
| // TODO: Adding this so that I can test on 1-node. Is this valuable? | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here |
||
| if (allocated_scratch_send_.find(id) == allocated_scratch_send_.end()) { | ||
| throw std::runtime_error("Invalid tensor ID"); | ||
| } | ||
| nvshmem_free(allocated_scratch_send_[id]); | ||
| nvshmem_free(allocated_scratch_recv_[id]); | ||
| nvshmem_free(seq_num_signals_[id]); | ||
| // Free only if allocated (steps_inter_ > 0) | ||
| if (seq_num_signals_[id]) { | ||
| nvshmem_free(seq_num_signals_[id]); | ||
| } | ||
| cudaFree(seq_nums_[id]); | ||
| allocated_scratch_send_.erase(id); | ||
| allocated_scratch_recv_.erase(id); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should really be a helper function because it's used in the benchmarks and the library itself. I couldn't think of where the best place to put it would be.