Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
bcf1c9e
ep: add CXI/Slingshot (libfabric) transport backend behind USE_CXI
Jun 5, 2026
7eb9fd2
Fix init_dist to correctly set CUDA device and default device in mult…
fergusfinn Jun 6, 2026
3c3e5b9
Fix proxy initialization by passing data_ptr instead of raw torch ten…
fergusfinn Jun 6, 2026
c148a45
Let Buffer manage scratch allocation and initialize uccl internally i…
fergusfinn Jun 6, 2026
0ef8108
Revert changes to test_internode_simple.py
fergusfinn Jun 6, 2026
35ba3b6
Fix CXI low-latency internode path
fergusfinn Jun 6, 2026
6be7e09
Fix EP8 combine config for 4-GPU nodes
fergusfinn Jun 6, 2026
16de290
Fix host atomic buffer initialization
fergusfinn Jun 8, 2026
7dec293
Fix zero-token RDMA rank count clearing
fergusfinn Jun 9, 2026
c386222
Add zero-layout guard regression test
fergusfinn Jun 9, 2026
e0c2a92
ep/cxi: async write path with per-completion ring retirement
Jun 10, 2026
275e8cf
ep/cxi: gate control atomics on completion of their preceding writes
Jun 10, 2026
b9e6a6c
ep/bench: sustained dispatch loop for wire-utilization measurement
Jun 10, 2026
9b8dc12
ep/bench: fix dispatch_loop teardown ordering (context-destroyed abort)
Jun 10, 2026
e808810
ep/bench: drop cached-dispatch output ref before teardown in dispatch…
Jun 10, 2026
98304f6
ep/bench: scope measurement tensors so dispatch_loop tears down cleanly
Jun 10, 2026
8b45c6d
ep: account for the CXI barrier carve-out in the atomic-buffer sizing…
Jun 10, 2026
19f36fc
gitignore core dumps (containers run with the repo as workdir)
Jun 11, 2026
dd2a672
Merge pull request #4 from doublewordai/cxi-async-transport
fergusfinn Jun 11, 2026
373ce5d
ep: fix two review findings (barrier seq wrap under CXI, uninit member)
Jun 11, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,4 @@ thirdparty/gdrcopy/
.worktrees/
worktrees/
*.csv
*.egg-info/
*.egg-info/
/core
2 changes: 1 addition & 1 deletion ep/bench/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,7 @@ def get_combine_config(num_ranks: int) -> Config:
config_map = {
2: Config(Buffer.num_sms, 10, 256, 6, 128),
4: Config(Buffer.num_sms, 9, 256, 6, 128),
8: Config(Buffer.num_sms, 4, 256, 6, 128),
8: Config(Buffer.num_sms, 4, 256, 8, 128),
16: Config(Buffer.num_sms, 4, 288, 12, 512 if Buffer._is_efa() else 128),
24: Config(Buffer.num_sms, 1, 288, 8, 128),
32: Config(Buffer.num_sms, 1, 288, 8, 512 if Buffer._is_efa() else 128),
Expand Down
90 changes: 90 additions & 0 deletions ep/bench/dispatch_loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""Sustained dispatch loop for wire-utilization measurement.

Runs buffer.dispatch() at a fixed config in a tight loop for LOOP_DURATION_S
seconds (torchrun-style, one process per GPU). Bracket externally with CXI
telemetry snapshots to measure true NIC bytes/sec. LOOP_CACHED=1 reuses the
dispatch handle (skips the notify/count exchange per iteration).
"""
import gc, os, time
import torch
import torch.distributed as dist
from buffer import Buffer
from utils import init_dist_under_torchrun
from test_internode import compute_buffer_sizes
from uccl.ep import Config


def measure(buffer, group, rank, local_world):
    """Run buffer.dispatch() in a tight loop and print the offered RDMA rate.

    The workload is configured through environment variables (defaults in
    parentheses): LOOP_DURATION_S (60), LOOP_TOKENS (4096), LOOP_HIDDEN
    (7168), LOOP_TOPK (8), LOOP_EXPERTS (288), LOOP_NVL_CHUNK (32),
    LOOP_NVL_BUF (256), LOOP_RDMA_CHUNK (64), LOOP_RDMA_BUF (128) and
    LOOP_CACHED (0). With LOOP_CACHED=1 one full dispatch is issued up front
    and its handle is reused for every later call.

    Args:
        buffer: object providing get_dispatch_layout() and dispatch().
        group: process group passed to dist.barrier() before timing starts.
        rank: global rank; seeds the RNG and tags the printed result.
        local_world: number of ranks per node; rank // local_world selects
            the entry of num_tokens_per_rdma_rank left out of the byte count.

    Returns:
        None. Every rank prints one "[loop]" summary line; rank 0 also
        prints a start banner.
    """
    env = os.environ
    duration = float(env.get("LOOP_DURATION_S", "60"))
    num_tokens = int(env.get("LOOP_TOKENS", "4096"))
    hidden = int(env.get("LOOP_HIDDEN", "7168"))
    num_topk = int(env.get("LOOP_TOPK", "8"))
    num_experts = int(env.get("LOOP_EXPERTS", "288"))
    nvl_chunk = int(env.get("LOOP_NVL_CHUNK", "32"))
    nvl_buf = int(env.get("LOOP_NVL_BUF", "256"))
    rdma_chunk = int(env.get("LOOP_RDMA_CHUNK", "64"))
    rdma_buf = int(env.get("LOOP_RDMA_BUF", "128"))
    cached = env.get("LOOP_CACHED", "0") == "1"

    # Per-rank seed: each rank generates its own (reproducible) routing.
    torch.manual_seed(rank)
    x = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
    scores = torch.randn((num_tokens, num_experts), dtype=torch.float32,
                         device="cuda").abs() + 1.0
    topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True)[1]
    topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32,
                              device="cuda")
    (num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert,
     is_token_in_rank, _) = buffer.get_dispatch_layout(topk_idx, num_experts)
    config = Config(24, nvl_chunk, nvl_buf, rdma_chunk, rdma_buf)
    args = dict(x=x, num_tokens_per_rank=num_tokens_per_rank,
                num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
                is_token_in_rank=is_token_in_rank,
                topk_idx=topk_idx, topk_weights=topk_weights,
                num_tokens_per_expert=num_tokens_per_expert, config=config)
    if cached:
        # One full dispatch to obtain a handle (element 4 of the result);
        # later calls pass only x, the handle and the config.
        recv = buffer.dispatch(**args)
        args = dict(x=x, handle=recv[4], config=config)

    # Warm-up iterations, excluded from the timed window.
    for _ in range(5):
        buffer.dispatch(**args)
    torch.cuda.synchronize()
    dist.barrier(group)
    if rank == 0:
        print(f"[loop] start duration={duration}s tokens={num_tokens} "
              f"cached={cached}", flush=True)

    # perf_counter() is monotonic, so a wall-clock step (e.g. NTP) during the
    # run cannot shorten, lengthen or negate the measured window.
    # NOTE(review): every rank decides when to stop from its own clock, so
    # ranks can issue different numbers of dispatch() calls. If dispatch is
    # a collective across ranks, the extra call on one rank may stall at the
    # end of the run -- confirm, or agree on the stop condition across ranks.
    t0 = time.perf_counter()
    iters = 0
    while time.perf_counter() - t0 < duration:
        buffer.dispatch(**args)
        iters += 1
    # Include the drain of all queued GPU work in the elapsed time.
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - t0

    # Tokens counted toward RDMA traffic: all entries except the one at index
    # rank // local_world (presumably this rank's own node -- TODO confirm).
    rdma_tokens = int(num_tokens_per_rdma_rank.sum().item()) - int(
        num_tokens_per_rdma_rank[rank // local_world].item())
    # x is bfloat16, i.e. 2 bytes per element.
    bytes_per_iter = rdma_tokens * hidden * 2
    print(f"[loop] rank={rank} iters={iters} elapsed={elapsed:.2f}s "
          f"rdma_tokens={rdma_tokens} bytes/iter={bytes_per_iter} "
          f"offered_GBps={iters * bytes_per_iter / elapsed / 1e9:.2f}",
          flush=True)


def main():
    """Entry point for one torchrun-launched process.

    Reads LOCAL_RANK (required), LOCAL_WORLD_SIZE (default 4) and LOOP_HIDDEN
    (default 7168), initialises the process group, builds the Buffer, runs
    measure(), then tears everything down in a fixed order.
    """
    env = os.environ
    local_rank = int(env["LOCAL_RANK"])
    ranks_per_node = int(env.get("LOCAL_WORLD_SIZE", "4"))
    rank, world_size, group = init_dist_under_torchrun(local_rank,
                                                       ranks_per_node)

    hidden_size = int(env.get("LOOP_HIDDEN", "7168"))
    nvl_bytes, rdma_bytes = compute_buffer_sizes(24, hidden_size, world_size)
    buffer = Buffer(
        group,
        nvl_bytes,
        rdma_bytes,
        low_latency_mode=False,
        num_qps_per_rank=24,
        explicitly_destroy=True,
    )

    measure(buffer, group, rank, ranks_per_node)

    # The measurement tensors went out of scope when measure() returned.
    # Flush any deferred frees before the buffer tears down the CUDA context;
    # the order of the calls below is deliberate.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    dist.barrier(group)
    buffer.destroy()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
86 changes: 80 additions & 6 deletions ep/bench/test_internode.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def test_main(
args.num_experts,
)

assert num_experts % num_ranks == 0 and num_local_ranks == 8
assert num_experts % num_ranks == 0 and num_local_ranks in (4, 8)
if local_rank == 0:
print(
f"[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}",
Expand Down Expand Up @@ -394,6 +394,10 @@ def check_data(check_x, recv_gbl_rank_prefix_sum):
), f"{calc_diff(check_topk_weights, ref_topk_weights)}"

hash_value += hash_tensor(recv_x)
if getattr(args, "smoke_one", False):
if local_rank == 0:
print("[testing] smoke-one complete", flush=True)
return hash_value

# For later tuning
dispatch_bf16_rdma_send_bytes = num_rdma_token_sent * hidden * 2
Expand All @@ -409,6 +413,24 @@ def check_data(check_x, recv_gbl_rank_prefix_sum):
# Tune dispatch performance
best_dispatch_results = None
fp8_factor = (1 + 4 / 128) / 2
fixed_dispatch = (
args.fixed_dispatch_nvl_chunk is not None
or args.fixed_dispatch_rdma_chunk is not None
)
if fixed_dispatch:
if (
args.fixed_dispatch_nvl_chunk is None
or args.fixed_dispatch_rdma_chunk is None
):
raise ValueError(
"--fixed-dispatch-nvl-chunk and --fixed-dispatch-rdma-chunk must be set together"
)
dispatch_nvl_chunks = (args.fixed_dispatch_nvl_chunk,)
dispatch_rdma_chunks = (args.fixed_dispatch_rdma_chunk,)
else:
dispatch_nvl_chunks = range(4, 45, 4)
dispatch_rdma_chunks = range(4, 129, 8)

for current_x in (x_e4m3, x):
best_time, best_results = 1e10, None
rdma_send_bytes = (
Expand All @@ -421,8 +443,8 @@ def check_data(check_x, recv_gbl_rank_prefix_sum):
if isinstance(current_x, tuple)
else dispatch_bf16_nvl_recv_bytes
)
for nvl_chunk_size in range(4, 45, 4):
for rdma_chunk_size in range(4, 33, 4):
for nvl_chunk_size in dispatch_nvl_chunks:
for rdma_chunk_size in dispatch_rdma_chunks:
config = Config(
num_sms,
nvl_chunk_size,
Expand All @@ -431,9 +453,13 @@ def check_data(check_x, recv_gbl_rank_prefix_sum):
rdma_buffer_size,
)
tune_args = {"x": current_x, "handle": handle, "config": config}
os.environ["UCCL_CXI_PHASE"] = (
"dispatch_fp8" if isinstance(current_x, tuple) else "dispatch_bf16"
)
t, notify_t = bench_kineto(
lambda: buffer.dispatch(**tune_args), ("dispatch", "notify")
)
os.environ.pop("UCCL_CXI_PHASE", None)
if t == 0 or notify_t == 0:
continue
if t < best_time:
Expand Down Expand Up @@ -486,12 +512,29 @@ def check_data(check_x, recv_gbl_rank_prefix_sum):
"num_tokens_per_expert": num_tokens_per_expert,
"config": dispatch_config if dispatch_config is not None else config,
}
os.environ["UCCL_CXI_PHASE"] = "dispatch_final"
recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args)
os.environ.pop("UCCL_CXI_PHASE", None)

# Tune combine performance
fixed_combine = (
args.fixed_combine_nvl_chunk is not None
or args.fixed_combine_rdma_chunk is not None
)
if fixed_combine:
if args.fixed_combine_nvl_chunk is None or args.fixed_combine_rdma_chunk is None:
raise ValueError(
"--fixed-combine-nvl-chunk and --fixed-combine-rdma-chunk must be set together"
)
combine_nvl_chunks = (args.fixed_combine_nvl_chunk,)
combine_rdma_chunks = (args.fixed_combine_rdma_chunk,)
else:
combine_nvl_chunks = range(1, 8, 1)
combine_rdma_chunks = range(12 if num_nodes == 2 else 8, 33, 4)

best_time, best_results = 1e10, None
for nvl_chunk_size in range(1, 8, 1):
for rdma_chunk_size in range(12 if num_nodes == 2 else 8, 33, 4):
for nvl_chunk_size in combine_nvl_chunks:
for rdma_chunk_size in combine_rdma_chunks:
config = Config(
num_sms,
nvl_chunk_size,
Expand All @@ -500,9 +543,11 @@ def check_data(check_x, recv_gbl_rank_prefix_sum):
rdma_buffer_size,
)
tune_args = {"x": recv_x, "handle": handle, "config": config}
os.environ["UCCL_CXI_PHASE"] = "combine"
t, notify_t = bench_kineto(
lambda: buffer.combine(**tune_args), ("combine", "notify")
)
os.environ.pop("UCCL_CXI_PHASE", None)
if t == 0 or notify_t == 0:
continue
if local_rank == 0:
Expand Down Expand Up @@ -567,7 +612,7 @@ def test_loop(
explicitly_destroy=True,
)

assert num_local_ranks == 8 and num_ranks > 8
assert num_local_ranks in (4, 8) and num_ranks > num_local_ranks

for seed in range(int(1e9)):
if local_rank == 0:
Expand Down Expand Up @@ -659,6 +704,35 @@ def test_loop(
action="store_true",
help="whether to test compatibility with low-latency kernels",
)
parser.add_argument(
"--smoke-one",
action="store_true",
help="run only the first dispatch/combine correctness variant",
)
parser.add_argument(
"--fixed-dispatch-nvl-chunk",
type=int,
default=None,
help="benchmark only this dispatch NVL chunk size",
)
parser.add_argument(
"--fixed-dispatch-rdma-chunk",
type=int,
default=None,
help="benchmark only this dispatch RDMA chunk size",
)
parser.add_argument(
"--fixed-combine-nvl-chunk",
type=int,
default=None,
help="benchmark only this combine NVL chunk size",
)
parser.add_argument(
"--fixed-combine-rdma-chunk",
type=int,
default=None,
help="benchmark only this combine RDMA chunk size",
)
args = parser.parse_args()
world_size = int(os.environ["WORLD_SIZE"])
local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
Expand Down
41 changes: 8 additions & 33 deletions ep/bench/test_internode_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@
from utils import (
init_dist,
detect_ib_hca,
get_cpu_proxies_meta,
initialize_uccl,
destroy_uccl,
)


Expand All @@ -49,16 +46,12 @@ def test_simple_internode(rank: int, num_ranks: int, group: dist.ProcessGroup):
device_index
).multi_processor_count

scratch_nbytes = int(1e9) # 256 MB
scratch = torch.empty(
scratch_nbytes, dtype=torch.uint8, device=f"cuda:{device_index}"
)
proxies, workers = initialize_uccl(scratch, scratch_nbytes, rank, num_ranks, group)
scratch_nbytes = int(1e9)
buffer = None

try:
buffer = Buffer(
group=group,
rdma_buffer_ptr=scratch.data_ptr(),
num_nvl_bytes=0,
num_rdma_bytes=int(scratch_nbytes),
low_latency_mode=True,
Expand All @@ -71,20 +64,6 @@ def test_simple_internode(rank: int, num_ranks: int, group: dist.ProcessGroup):
if rank == 0:
print("[simple-test] ✓ Buffer created successfully", flush=True)

buffer.connect_atomic_buffer(proxies[0])

for proxy in proxies:
proxy.calculate_and_set_dispatch_recv_data_offset(
num_tokens, hidden, num_experts
)
proxy.set_atomic_buffer_ptr(proxies[0].get_atomic_buffer_ptr())

if rank == 0:
print(
"[simple-test] ✓ dispatch_recv_data_offset calculated and set by CPU proxy",
flush=True,
)

cumulative_local_expert_recv_stats = torch.zeros(
(num_experts // num_ranks,), dtype=torch.int, device="cuda"
)
Expand Down Expand Up @@ -129,7 +108,6 @@ def test_simple_internode(rank: int, num_ranks: int, group: dist.ProcessGroup):

time.sleep(1)
print("[simple-test] ✓ before destroy!", flush=True)

except Exception as e:
if rank == 0:
import traceback
Expand All @@ -138,16 +116,14 @@ def test_simple_internode(rank: int, num_ranks: int, group: dist.ProcessGroup):
traceback.print_exc()
raise

try:
buffer.destroy()
except Exception:
pass

dist.barrier()
print("[simple-test] ✓ Buffer destroyed", flush=True)
if buffer is not None:
try:
buffer.destroy()
except Exception:
pass

destroy_uccl(proxies, workers)
dist.barrier()
print("[simple-test] ✓ Buffer destroyed", flush=True)


def test_worker(local_rank: int, num_local_ranks: int):
Expand All @@ -156,7 +132,6 @@ def test_worker(local_rank: int, num_local_ranks: int):
try:
test_simple_internode(rank, num_ranks, group)
finally:
dist.barrier()
dist.destroy_process_group()


Expand Down
Loading