Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/device/ep/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ buffer.disconnect_ranks(ranks)

## Key APIs

- `Buffer(rank_id, ...)`: Initialize the NIXL communication buffer
- `Buffer(rank_id, ..., nvl_group_size=8)`: Initialize the NIXL communication buffer
- `update_memory_buffers(num_ranks, num_experts_per_rank, num_rdma_bytes, num_nvl_bytes=0)`: Prepare buffers for up to `num_ranks` ranks and `num_experts_per_rank` experts
- `connect_ranks(remote_ranks, activate=True)`: Establish NIXL connections to new peers (can be called multiple times); in low-latency mode, use `activate=False` to keep new peers masked until explicitly unmasked.
- `disconnect_ranks(remote_ranks)`: Clean up connections to departing peers
Expand Down
18 changes: 10 additions & 8 deletions examples/device/ep/csrc/config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,15 @@ struct Config {
EP_HOST_ASSERT(num_max_rdma_chunked_send_tokens <= num_max_rdma_chunked_recv_tokens / 2);
}

size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks) const {
size_t get_nvl_buffer_size_hint(size_t hidden_bytes, int num_ranks, int nvl_group_size = NUM_MAX_NVL_PEERS) const {
// Below are some assumptions
// TODO: add assertions
constexpr int kNumMaxTopK = 128;
constexpr int kNumMaxScales = 128;
EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0);
EP_HOST_ASSERT(num_ranks <= NUM_MAX_NVL_PEERS or num_sms % 2 == 0);
const auto num_rdma_ranks = std::max(num_ranks / NUM_MAX_NVL_PEERS, 1);
EP_HOST_ASSERT(nvl_group_size > 0 and nvl_group_size <= NUM_MAX_NVL_PEERS and NUM_MAX_NVL_PEERS % nvl_group_size == 0);
EP_HOST_ASSERT(num_ranks < nvl_group_size or num_ranks % nvl_group_size == 0);
EP_HOST_ASSERT(num_ranks <= nvl_group_size or num_sms % 2 == 0);
const auto num_rdma_ranks = std::max(num_ranks / nvl_group_size, 1);
const auto num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS);
const int num_channels = num_sms / 2;

Expand All @@ -86,18 +87,19 @@ struct Config {
return num_bytes;
}

size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks) const {
size_t get_rdma_buffer_size_hint(int64_t hidden_bytes, int num_ranks, int nvl_group_size = NUM_MAX_NVL_PEERS) const {
// Legacy mode
if (num_ranks <= NUM_MAX_NVL_PEERS)
EP_HOST_ASSERT(nvl_group_size > 0 and nvl_group_size <= NUM_MAX_NVL_PEERS and NUM_MAX_NVL_PEERS % nvl_group_size == 0);
if (num_ranks <= nvl_group_size)
return 0;

// Below are some assumptions
// TODO: add assertions
constexpr int kNumMaxTopK = 128;
constexpr int kNumMaxScales = 128;
EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0);
EP_HOST_ASSERT(num_ranks % nvl_group_size == 0);
EP_HOST_ASSERT(num_sms % 2 == 0);
const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS;
const int num_rdma_ranks = num_ranks / nvl_group_size;
const int num_channels = num_sms / 2;

size_t num_bytes = 0;
Expand Down
29 changes: 25 additions & 4 deletions examples/device/ep/csrc/nixl_ep.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,17 @@ void Buffer::update_memory_buffers(int num_ranks, int num_experts_per_rank, int6
}
}

Buffer::Buffer(int rank, bool explicitly_destroy, bool low_latency_mode, int timeout_ms):
Buffer::Buffer(int rank, bool explicitly_destroy, bool low_latency_mode, int timeout_ms, int nvl_group_size):
low_latency_mode(low_latency_mode),
timeout_ms([timeout_ms] {
EP_HOST_ASSERT(timeout_ms >= 0);
return static_cast<uint64_t>(timeout_ms);
}()),
rank(rank),
nvl_group_size([nvl_group_size] {
EP_HOST_ASSERT(nvl_group_size > 0 and nvl_group_size <= NUM_MAX_NVL_PEERS and NUM_MAX_NVL_PEERS % nvl_group_size == 0);
return nvl_group_size;
}()),
Comment thread
xtyao66 marked this conversation as resolved.
explicitly_destroy(explicitly_destroy),
comm_stream(at::cuda::getStreamFromPool(true)) {}

Expand Down Expand Up @@ -253,6 +257,14 @@ int Buffer::get_rdma_rank() const {
return rdma_rank;
}

int Buffer::get_nvl_rank() const {
return nvl_rank;
}

int Buffer::get_nvl_group_size() const {
return nvl_group_size;
}

int Buffer::get_root_rdma_rank(bool global) const {
return global ? nvl_rank : 0;
}
Expand Down Expand Up @@ -1500,15 +1512,22 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("num_sms") = 20,
py::arg("num_max_nvl_chunked_send_tokens") = 6, py::arg("num_max_nvl_chunked_recv_tokens") = 256,
py::arg("num_max_rdma_chunked_send_tokens") = 6, py::arg("num_max_rdma_chunked_recv_tokens") = 256)
.def("get_nvl_buffer_size_hint", &nixl_ep::Config::get_nvl_buffer_size_hint)
.def("get_rdma_buffer_size_hint", &nixl_ep::Config::get_rdma_buffer_size_hint);
.def("get_nvl_buffer_size_hint", &nixl_ep::Config::get_nvl_buffer_size_hint,
pybind11::arg("hidden_bytes"), pybind11::arg("num_ranks"), pybind11::arg("nvl_group_size") = NUM_MAX_NVL_PEERS)
.def("get_rdma_buffer_size_hint", &nixl_ep::Config::get_rdma_buffer_size_hint,
pybind11::arg("hidden_bytes"), pybind11::arg("num_ranks"), pybind11::arg("nvl_group_size") = NUM_MAX_NVL_PEERS);

pybind11::class_<nixl_ep::EventHandle>(m, "EventHandle")
.def(pybind11::init<>())
.def("current_stream_wait", &nixl_ep::EventHandle::current_stream_wait);

pybind11::class_<nixl_ep::Buffer>(m, "Buffer")
.def(pybind11::init<int, bool, bool, int>())
.def(pybind11::init<int, bool, bool, int, int>(),
pybind11::arg("rank"),
pybind11::arg("explicitly_destroy"),
pybind11::arg("low_latency_mode"),
pybind11::arg("timeout_ms"),
pybind11::arg("nvl_group_size") = NUM_MAX_NVL_PEERS)
.def("update_memory_buffers", &nixl_ep::Buffer::update_memory_buffers)
.def("barrier", &nixl_ep::Buffer::barrier)
.def("connect_ranks", [](nixl_ep::Buffer &buffer, const std::vector<int>& remote_ranks, const std::optional<std::vector<pybind11::bytes>>& remote_mds, const std::vector<std::optional<pybind11::bytearray>> &all_gathered_handles, bool activate) {
Expand All @@ -1518,6 +1537,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("is_available", &nixl_ep::Buffer::is_available)
.def("get_num_rdma_ranks", &nixl_ep::Buffer::get_num_rdma_ranks)
.def("get_rdma_rank", &nixl_ep::Buffer::get_rdma_rank)
.def("get_nvl_rank", &nixl_ep::Buffer::get_nvl_rank)
.def("get_nvl_group_size", &nixl_ep::Buffer::get_nvl_group_size)
.def("get_root_rdma_rank", &nixl_ep::Buffer::get_root_rdma_rank)
.def("get_local_device_id", &nixl_ep::Buffer::get_local_device_id)
.def("get_local_ipc_handle", &nixl_ep::Buffer::get_local_ipc_handle)
Expand Down
7 changes: 6 additions & 1 deletion examples/device/ep/csrc/nixl_ep.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ struct Buffer {
int num_device_sms;
uint64_t timeout_cycles = 0;
int rank, rdma_rank, nvl_rank;
int nvl_group_size = NUM_MAX_NVL_PEERS;
int max_num_ranks;
std::vector<int> remote_ranks; /* global ranks */
// Host-side active rank state over max_num_ranks. This can differ from
Expand Down Expand Up @@ -180,7 +181,7 @@ struct Buffer {
void _ipc_handles_sync(const std::vector<std::optional<pybind11::bytearray>> &all_gathered_handles);

public:
Buffer(int rank, bool explicitly_destroy, bool low_latency_mode, int timeout_ms);
Buffer(int rank, bool explicitly_destroy, bool low_latency_mode, int timeout_ms, int nvl_group_size = NUM_MAX_NVL_PEERS);

void update_memory_buffers(int num_ranks, int num_experts_per_rank, int64_t num_rdma_bytes, int64_t num_nvl_bytes = 0);

Expand All @@ -200,6 +201,10 @@ struct Buffer {

int get_rdma_rank() const;

int get_nvl_rank() const;

int get_nvl_group_size() const;

int get_root_rdma_rank(bool global) const;

int get_local_device_id() const;
Expand Down
13 changes: 12 additions & 1 deletion examples/device/ep/nixl_ep/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(
comm: Optional["mpi4py.MPI.Comm"] = None,
tcp_store_group: Optional[dist.TCPStore] = None,
timeout_ms: int = DEFAULT_TIMEOUT_MS,
nvl_group_size: int = 8,
) -> None:
"""
Initialize the nixl communication buffer.
Expand All @@ -81,11 +82,15 @@ def __init__(
In low-latency paths, a timeout marks the rank invalid and masks it out.
In high-throughput paths, a timeout is fatal and traps.
Default: 30000 ms.
nvl_group_size: number of ranks in one CUDA-IPC/NVLink-local group.
Defaults to 8 for existing deployments.
"""
assert 0 < nvl_group_size <= 8 and 8 % nvl_group_size == 0
self.rank = rank
self.group_size = 0 # Will be updated by `update_memory_buffers`
self.low_latency_mode = low_latency_mode
self.timeout_ms = timeout_ms
self.nvl_group_size = nvl_group_size

self.explicitly_destroy = explicitly_destroy
self.group = group
Expand All @@ -97,7 +102,7 @@ def __init__(
os.environ["UCX_TLS"] = "^cuda_ipc"

self.runtime = nixl_ep_cpp.Buffer(
self.rank, explicitly_destroy, low_latency_mode, timeout_ms
self.rank, explicitly_destroy, low_latency_mode, timeout_ms, nvl_group_size
)

def destroy(self):
Expand All @@ -113,6 +118,12 @@ def destroy(self):
def is_sm90_compiled():
return nixl_ep_cpp.is_sm90_compiled()

def get_nvl_rank(self) -> int:
return self.runtime.get_nvl_rank()

def get_nvl_group_size(self) -> int:
return self.runtime.get_nvl_group_size()

@staticmethod
def set_num_sms(new_num_sms: int) -> None:
"""
Expand Down
Loading