Changes from 5 commits
29 changes: 22 additions & 7 deletions benchmarks/benchmark_compare_allreduce.py
@@ -31,6 +31,8 @@

import torch
import torch.distributed as dist
from nvshmem import core as nvshmem
from cuda.core.experimental import Device

try:
from nvrar import nvshmem_comm_cuda, resolve_params
@@ -381,12 +383,22 @@ def main():
# Initialize NVSHMEM communicator if available
comm_wrapper = None
if nvshmem_comm_cuda is not None:
uid_bytes = nvshmem_comm_cuda.NVSHMEMCommWrapper.get_unique_id_bytes()
uid_gpu = uid_bytes.to(device)
dist.broadcast(uid_gpu, src=0)
# Set device current
Author: This should really be a helper function because it's used in the benchmarks and the library itself. I couldn't think of where the best place to put it would be.
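A minimal sketch of what such a helper might look like (the name and where it lives are hypothetical; it assumes the nvshmem4py get_unique_id/init API used in this diff):

    import torch.distributed as dist
    import nvshmem.core as nvshmem
    from cuda.core.experimental import Device

    def init_nvshmem_from_uid(rank, world_size, device_idx, group=None):
        # Hypothetical helper: set the CUDA device, broadcast the UID from
        # rank 0 over torch.distributed, then initialize nvshmem4py.
        dev = Device(device_idx)
        dev.set_current()
        uniqueid = nvshmem.get_unique_id(empty=True)
        if rank == 0:
            uniqueid = nvshmem.get_unique_id()
            obj = [uniqueid]
        else:
            obj = [None]
        dist.broadcast_object_list(obj, src=0, group=group)
        dist.barrier(group=group)
        nvshmem.init(device=dev, uid=obj[0], rank=rank, nranks=world_size,
                     initializer_method="uid")
        return dev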

cuda_dev = Device(local_device_idx)
cuda_dev.set_current()
# Rank 0 obtains UID; broadcast via object list
uniqueid = nvshmem.get_unique_id(empty=True)
if rank == 0:
uniqueid = nvshmem.get_unique_id()
obj = [uniqueid]
else:
obj = [None]
dist.broadcast_object_list(obj, src=0)
dist.barrier()
uid_cpu = uid_gpu.to("cpu")
comm_wrapper = nvshmem_comm_cuda.NVSHMEMCommWrapper(rank, world_size, local_device_idx, uid_cpu)
# Initialize nvshmem4py
nvshmem.init(device=cuda_dev, uid=obj[0], rank=rank, nranks=world_size, initializer_method="uid")
# Construct wrapper without UID (nvshmem already initialized)
comm_wrapper = nvshmem_comm_cuda.NVSHMEMCommWrapper(rank, world_size, local_device_idx)

# Use default stream
stream = torch.cuda.Stream(device=local_device_idx)
@@ -411,7 +423,9 @@ def main():
nvrar_tensor_id = None
algorithm = "recursive"
if comm_wrapper is not None:
nvrar_tensor, nvrar_tensor_id = comm_wrapper.allocate_tensor(num_elems, dtype, device, nvshmem_comm_cuda.Protocol.LL8)
# Allocate symmetric tensor via nvshmem4py and register with wrapper
nvrar_tensor = nvshmem.tensor((num_elems,), dtype=dtype)
Author: This is the first major difference. I couldn't think of a good way to handle the tensor_id stuff purely in Python, so what I did is:

  • Replace tensor allocation with the nvshmem.core wrapper
  • Keep the other parts of the process in your C code (and rename it to register_tensor instead of allocate_tensor)

The resulting lifecycle is sketched below.

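Roughly, the full lifecycle after this change looks like the following sketch (num_elems, dtype, comm_wrapper, and stream come from the surrounding benchmark; the allreduce_preallocated call and its stream argument are assumed from the wrapper API in nvshmem_comm.h):

    # Allocate symmetric memory with nvshmem4py, register it with the wrapper,
    # run the collective, then tear down in the reverse order.
    buf = nvshmem.tensor((num_elems,), dtype=dtype)
    buf_id = comm_wrapper.register_tensor(buf, nvshmem_comm_cuda.Protocol.LL8)
    comm_wrapper.allreduce_preallocated(buf, buf_id, stream.cuda_stream)
    comm_wrapper.deregister_tensor(buf_id)   # frees wrapper scratch only
    nvshmem.free_tensor(buf)                 # memory is owned by nvshmem4py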
nvrar_tensor_id = comm_wrapper.register_tensor(nvrar_tensor, nvshmem_comm_cuda.Protocol.LL8)

# Choose kernel params
if params_resolver is not None:
@@ -513,7 +527,8 @@ def main():

# Cleanup
if comm_wrapper is not None and nvrar_tensor_id is not None:
comm_wrapper.free_tensor(nvrar_tensor_id)
comm_wrapper.deregister_tensor(nvrar_tensor_id)
nvshmem.free_tensor(nvrar_tensor)

dist.barrier()
if dist.is_initialized():
40 changes: 30 additions & 10 deletions nvrar/comm.py
@@ -6,6 +6,8 @@

import torch
import torch.distributed as dist
import nvshmem.core
from cuda.core.experimental import Device
from . import nvshmem_comm_cuda

class NVRARCommunicator:
@@ -17,16 +19,34 @@ def __init__(self, process_group: torch.distributed.ProcessGroup):
return

device = torch.cuda.current_device()

unique_id = nvshmem_comm_cuda.NVSHMEMCommWrapper.get_unique_id_bytes()
uid_gpu = unique_id.to("cuda")
ranks = torch.distributed.get_process_group_ranks(process_group)
torch.distributed.broadcast(uid_gpu, src=ranks[0], group=process_group)
torch.distributed.barrier(group=process_group)

unique_id = uid_gpu.to("cpu")

self.comm_wrapper = nvshmem_comm_cuda.NVSHMEMCommWrapper(rank, nranks, device, unique_id)
cuda_dev = Device(device)
# This should be idempotent
cuda_dev.set_current()
stream = torch.cuda.current_stream()

Author: Here's the same boilerplate (another candidate for the helper suggested above).

# Fetch NVSHMEM unique ID via nvshmem4py and broadcast via torch.distributed
uniqueid = nvshmem.core.get_unique_id(empty=True)
if rank == 0:
# Rank 0 gets a real uniqueid
uniqueid = nvshmem.core.get_unique_id()
broadcast_objects = [uniqueid]
else:
broadcast_objects = [None]

# We use torch.distributed.broadcast_object_list to send the UID to all ranks
dist.broadcast_object_list(broadcast_objects, src=0, group=process_group)
dist.barrier(group=process_group)

nvshmem.core.init(
device=cuda_dev,
uid=broadcast_objects[0],
rank=rank,
nranks=nranks,
initializer_method="uid",
)


self.comm_wrapper = nvshmem_comm_cuda.NVSHMEMCommWrapper(rank, nranks, device)
print(f"NVRARCommunicator created for process group {process_group} with rank {rank} and nranks {nranks}")

self.comm_wrapper.set_kernel_params(4, 256, 16384)
27 changes: 27 additions & 0 deletions nvrar/csrc/CMakeLists.txt
@@ -72,6 +72,33 @@ target_include_directories(nvshmem_comm_cuda PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/include # Our headers
)

# 9.1) CUDA 13 CCCL headers live under <CUDA_ROOT>/include/cccl
# Try to locate CUDA root from CUDA_HOME/CUDA_PATH or from nvcc
# Allow user override via -DCUDA_CCCL_INCLUDE_DIR
set(CUDA_CCCL_INCLUDE_DIR "" CACHE PATH "Path to CUDA CCCL include directory (contains cuda/std)")
set(_CUDA_ROOT "")
if(DEFINED ENV{CUDA_HOME})
Author: This is hacky and terrible, and there is a better way to do it. In NVSHMEM's source, we handle it like this: https://github.com/NVIDIA/nvshmem/blob/2d7d25f0816235e3c2b51779571ec032606ea0dd/src/device/CMakeLists.txt#L188. A rough sketch of that style of detection follows this block.

set(_CUDA_ROOT $ENV{CUDA_HOME})
elseif(DEFINED ENV{CUDA_PATH})
set(_CUDA_ROOT $ENV{CUDA_PATH})
elseif(CMAKE_CUDA_COMPILER)
get_filename_component(_CUDA_BIN_DIR "${CMAKE_CUDA_COMPILER}" DIRECTORY)
get_filename_component(_CUDA_ROOT "${_CUDA_BIN_DIR}" DIRECTORY)
endif()
if(CUDA_CCCL_INCLUDE_DIR)
if(EXISTS "${CUDA_CCCL_INCLUDE_DIR}/cuda/std/tuple" OR EXISTS "${CUDA_CCCL_INCLUDE_DIR}/cuda/std")
message(STATUS "Using user-provided CUDA CCCL include path: ${CUDA_CCCL_INCLUDE_DIR}")
target_include_directories(nvshmem_comm_cuda PRIVATE "${CUDA_CCCL_INCLUDE_DIR}")
else()
message(WARNING "CUDA_CCCL_INCLUDE_DIR does not contain cuda/std: ${CUDA_CCCL_INCLUDE_DIR}")
endif()
elseif(_CUDA_ROOT AND EXISTS "${_CUDA_ROOT}/include/cccl")
message(STATUS "Adding CUDA CCCL include path: ${_CUDA_ROOT}/include/cccl")
target_include_directories(nvshmem_comm_cuda PRIVATE "${_CUDA_ROOT}/include/cccl")
else()
message(STATUS "CUDA CCCL include path not found; set -DCUDA_CCCL_INCLUDE_DIR=<path to .../include/cccl>")
endif()
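A rough sketch of the find_path style used there (illustrative, not the upstream code; it assumes find_package(CUDAToolkit) has run so CUDAToolkit_INCLUDE_DIRS is set):

    # Probe for CCCL's cuda/std headers instead of hand-deriving the CUDA root
    find_path(CUDA_CCCL_INCLUDE_DIR
      NAMES cuda/std/tuple
      HINTS ${CUDAToolkit_INCLUDE_DIRS} ENV CUDA_HOME ENV CUDA_PATH
      PATH_SUFFIXES cccl include/cccl)
    if(CUDA_CCCL_INCLUDE_DIR)
      target_include_directories(nvshmem_comm_cuda PRIVATE "${CUDA_CCCL_INCLUDE_DIR}")
    endif()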

# Torch Python shim: prefer target if present, else raw library
set(_need_torch_python ON)
if(TARGET Torch::Python)
49 changes: 17 additions & 32 deletions nvrar/csrc/include/coll.h
@@ -30,9 +30,11 @@ class IColl {
virtual void init(int num_blocks, int threads_per_block,
size_t chunk_size) = 0;

virtual std::tuple<torch::Tensor, uint64_t> allocate_tensor(
size_t size, torch::Dtype dt, torch::Device dev) = 0;
virtual void free_tensor(uint64_t id) = 0;
// Register an externally-allocated symmetric tensor (e.g., via nvshmem4py)
// Returns a newly assigned tensor id
virtual uint64_t register_external_tensor(torch::Tensor& t) = 0;
Author: Here's the renaming I mentioned above.
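With the rename, the CRTP contract a Derived coll implements is roughly this (a sketch of the callbacks CollBase invokes, not the full class):

    // Called by CollBase::register_external_tensor with the tensor's shape info;
    // the derived class allocates its own scratch/signal state for this id.
    void register_tensor(uint64_t id, size_t size, torch::Dtype dt, torch::Device dev);
    // Called by CollBase::deregister_tensor; frees scratch only, never user memory.
    void deregister_tensor(uint64_t id);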

// Deregister a previously registered tensor without freeing memory
virtual void deregister_tensor(uint64_t id) = 0;

virtual void dispatch_allreduce_preallocated(torch::Tensor& t, uint64_t id,
cudaStream_t s,
@@ -47,10 +49,7 @@
class CollBase : public IColl {
public:
~CollBase() noexcept override {
for (auto [id, ptr] : allocated_tensors_) {
nvshmem_free(ptr);
}
allocated_tensors_.clear();
// Nothing to free: external tensors are owned by nvshmem4py
}

void init(int num_blocks, int threads_per_block, size_t chunk_size) override {
@@ -67,34 +66,23 @@
derived()->initialize(num_blocks, threads_per_block, chunk_size);
}

std::tuple<torch::Tensor, uint64_t> allocate_tensor(
size_t size, torch::Dtype dt, torch::Device dev) override {
void* ptr = nvshmem_malloc(size * torch::elementSize(dt));

if (!ptr) {
throw std::runtime_error("Failed to allocate tensor memory");
uint64_t register_external_tensor(torch::Tensor& t) override {
// Accept a pre-allocated symmetric tensor; we do not own its memory
void* ptr = t.data_ptr();
if (ptr == nullptr) {
throw std::runtime_error("register_external_tensor: null data_ptr");
}

uint64_t id = next_id_.fetch_add(1);
allocated_tensors_[id] = ptr;

// Register the tensor with the derived class which can maintain its own
// scratch memory
// Let derived class register scratch/meta using size/dtype/device
const size_t size = static_cast<size_t>(t.numel());
const torch::Dtype dt = t.scalar_type();
const torch::Device dev = t.device();
derived()->register_tensor(id, size, dt, dev);

auto tensor = torch::from_blob(ptr, {static_cast<int64_t>(size)},
torch::dtype(dt).device(dev));
return std::make_tuple(tensor, id);
return id;
}

void free_tensor(uint64_t id) override {
if (allocated_tensors_.find(id) == allocated_tensors_.end()) {
throw std::runtime_error("Invalid tensor ID");
}
nvshmem_free(allocated_tensors_[id]);
void deregister_tensor(uint64_t id) override {
derived()->deregister_tensor(id);

allocated_tensors_.erase(id);
}

void dispatch_allreduce_preallocated(torch::Tensor& t, uint64_t id,
Expand Down Expand Up @@ -124,9 +112,6 @@ class CollBase : public IColl {
Derived* derived() { return static_cast<Derived*>(this); }

protected:
// Memory Pools for allocated tensors
std::unordered_map<uint64_t, void*> allocated_tensors_;

// Next ID for allocated tensors
std::atomic<uint64_t> next_id_;

22 changes: 4 additions & 18 deletions nvrar/csrc/include/nvshmem_comm.h
@@ -23,10 +23,6 @@
class NVSHMEMCommWrapper {
public:
NVSHMEMCommWrapper(int rank, int world_size, int device);
// Initialize using NVSHMEM unique id based attributes; unique_id is the raw
// bytes returned by nvshmemx_get_uniqueid
NVSHMEMCommWrapper(int rank, int world_size, int device,
const torch::Tensor& unique_id_bytes);
~NVSHMEMCommWrapper();

// Disable copy constructor and assignment operator
@@ -35,11 +31,10 @@

void destroy();

std::tuple<torch::Tensor, uint64_t> allocate_tensor(size_t size,
torch::Dtype dtype,
torch::Device device,
Protocol protocol);
void free_tensor(uint64_t id);
// Register an externally-allocated symmetric tensor (e.g., nvshmem4py)
// Returns an internal tensor id used by collectives
uint64_t register_tensor(torch::Tensor& tensor, Protocol protocol);
void deregister_tensor(uint64_t id);

// Collective operations
void allreduce_preallocated(torch::Tensor& tensor, uint64_t id,
@@ -50,15 +45,6 @@
void set_kernel_params(Protocol protocol, int num_blocks,
int threads_per_block, size_t chunk_size);

// Getter methods
int get_rank() const { return rank_; }
int get_world_size() const { return world_size_; }

int get_mype() const { return mype_; }
int get_npes() const { return npes_; }

static torch::Tensor get_unique_id_bytes();

private:
void initialize_coll(Protocol protocol);

24 changes: 15 additions & 9 deletions nvrar/csrc/src/ll8_coll.cu
@@ -303,16 +303,18 @@ void RecursiveLL8Coll::register_tensor(uint64_t id, size_t size,
nvshmem_barrier_all();

// Create signal tensors that hold the peer sequence numbers to wait on
uint64_t* seq_num_signal =
(uint64_t*)nvshmem_calloc(steps_inter_, sizeof(uint64_t));
if (!seq_num_signal) {
throw std::runtime_error("Failed to allocate signal memory");
uint64_t* seq_num_signal = nullptr;
// TODO: revisit; on a single node steps_inter_ is 0, so skip the signal buffer
if (steps_inter_ > 0) {
Author: This is just here so my tests would pass on 1 node. If it's 1 node and we don't have this check, the calloc fails because steps_inter_ is 0, so we would be allocating nothing.

seq_num_signal = (uint64_t*)nvshmem_calloc(steps_inter_, sizeof(uint64_t));
if (!seq_num_signal) {
throw std::runtime_error("Failed to allocate signal memory");
}
init_signal_kernel<<<1, steps_inter_>>>(seq_num_signal, steps_inter_);
cudaDeviceSynchronize();
nvshmem_barrier_all();
}

init_signal_kernel<<<1, steps_inter_>>>(seq_num_signal, steps_inter_);
cudaDeviceSynchronize();
nvshmem_barrier_all();

allocated_scratch_send_[id] = send_scratch;
allocated_scratch_recv_[id] = recv_scratch;
seq_nums_[id] = seq_num;
@@ -321,13 +323,17 @@
}

void RecursiveLL8Coll::deregister_tensor(uint64_t id) {
// TODO: Implement
// TODO: Added so that I can test on 1 node. Is this valuable?
Author: Same here.

if (allocated_scratch_send_.find(id) == allocated_scratch_send_.end()) {
throw std::runtime_error("Invalid tensor ID");
}
nvshmem_free(allocated_scratch_send_[id]);
nvshmem_free(allocated_scratch_recv_[id]);
nvshmem_free(seq_num_signals_[id]);
// Free only if allocated (steps_inter_ > 0)
if (seq_num_signals_[id]) {
nvshmem_free(seq_num_signals_[id]);
}
cudaFree(seq_nums_[id]);
allocated_scratch_send_.erase(id);
allocated_scratch_recv_.erase(id);