
Nixl weight transfer #2326

Open · S1ro1 wants to merge 9 commits into main from nixl-weight-transfer

Conversation

S1ro1 (Collaborator) commented Apr 19, 2026

RDMA-based weight transfer for GLM5 FP8


Note

High Risk
Introduces complex distributed RDMA-based weight transfer, new rendezvous/synchronization paths, and a custom CUDA allocator/extension build, which can fail in subtle GPU/driver/cluster configurations and affects the core weight-update mechanism.

Overview
Adds a new nixl weight broadcast mode that transfers weights via NIXL/UCX RDMA directly into inference GPU memory, alongside existing filesystem/nccl options.

This wires a full trainer↔orchestrator↔vLLM protocol: new config variants and SLURM flags, a new inference admin route (/init_nixl_transfer) and vLLM worker extension, plus a trainer-side NIXLWeightBroadcast that pre-registers stable buffers, FP8-quantizes into them in-place, and posts chunked RDMA writes with expert/non-expert routing and synchronization barriers.

To support zero-copy FP8 updates, it replaces the prior Python FP8 quantizer with Triton kernels that write into caller-provided output buffers, extends the GLM MoE DSA model to allocate/register NIXL slot buffers (using a classic cudaMalloc mempool to avoid VMM issues), and adds new unit/integration tests and fixtures covering smoke, single-rank, and multi-rank transfers.

Reviewed by Cursor Bugbot for commit 18b39fe. Bugbot is set up for automated code reviews on this repo.

Comment thread: src/prime_rl/inference/vllm/worker/nixl.py (outdated)
Comment thread: src/prime_rl/trainer/models/glm_moe_dsa/modeling_glm_moe_dsa.py (outdated)
S1ro1 and others added 2 commits April 19, 2026 19:00
Replace the if/else mess in convert_tt_layer_to_vllm_kernel with a
deterministic spec table mirroring mapper.py; the only branch left is dense
vs sparse layer. out_buffers is now required (fail loud on a missing key),
and experts quantize in a single grouped Triton call instead of a per-expert
loop. Slot allocation moved into the NIXL broadcast init so the converter
never allocates, and the expert_lead / konig / ImportError fallbacks
are gone.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Comment thread: src/prime_rl/trainer/models/glm_moe_dsa/modeling_glm_moe_dsa.py (outdated)
Comment thread: src/prime_rl/trainer/models/glm_moe_dsa/modeling_glm_moe_dsa.py (outdated)
Comment thread: src/prime_rl/trainer/rl/broadcast/nixl.py (outdated)
```python
def convert_layer_to_vllm_kernel(self, layer_idx, out_buffers):
    """Convert a single layer's state dict from PrimeRL format to vLLM kernel format.

    Writes into ``out_buffers`` in place. Used by the NIXL broadcast path.
    """
    raise NotImplementedError(f"convert_layer_to_vllm_kernel is not implemented for {self.__class__.__name__}")
```

NCCL FP8 path broken by signature change

High Severity

convert_layer_to_vllm_kernel changed from a classmethod with signature (cls, state_dict, layer_idx, quantize_fp8) to an instance method with signature (self, layer_idx, out_buffers). The call in preprocess_layer_quantized still uses the old signature: model.convert_layer_to_vllm_kernel(layer_state_dict, layer_idx, quantize_fp8=True). This will pass layer_state_dict (a dict) as layer_idx (an int) and layer_idx (an int) as out_buffers (a dict), crashing at runtime when NCCL FP8 quantized weight transfer is used.
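The failure mode is plain positional-argument rebinding. A minimal sketch (class and argument names are illustrative stand-ins, not the real PrimeRL code):

```python
# Hedged sketch of the signature mismatch described above.
class Model:
    # New instance-method signature: expects an int and a dict.
    def convert_layer_to_vllm_kernel(self, layer_idx: int, out_buffers: dict) -> dict:
        assert isinstance(layer_idx, int)
        return out_buffers

model = Model()
layer_state_dict = {"w": [1.0]}

# Old call site still uses (state_dict, layer_idx, quantize_fp8=True):
try:
    model.convert_layer_to_vllm_kernel(layer_state_dict, 3, quantize_fp8=True)
    crashed = False
except TypeError:  # unexpected keyword argument 'quantize_fp8'
    crashed = True
assert crashed
```

Here the stale keyword argument trips a TypeError before the arguments are even mis-bound; without it, the dict would silently land in `layer_idx`.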


Reviewed by Cursor Bugbot for commit bec06a0.

Comment thread: src/prime_rl/trainer/rl/broadcast/nixl.py

```python
    [(ptr + offset_bytes + my_rank * slot_bytes, slot_bytes, dev)]
)
remote_prep = self._agent.prep_remote(peer["agent_name"], remote_descs)
self._writes.append((local_prep, 0, remote_prep, 0))
```

Per-shard offset uses global rank instead of FSDP rank

Medium Severity

The per_shard remote write offset is computed as my_rank * slot_bytes where my_rank = self.world.rank (the global rank). For per-shard non-expert tensors, the offset into the inference-side tensor should correspond to the FSDP shard index, not the global rank. With dp_replicate > 1, the global rank diverges from the FSDP shard position, causing writes to land at incorrect offsets and leaving parts of the inference tensor unwritten.
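The divergence can be sketched with a few lines; the replicate-major rank layout here is an assumption for illustration, not taken from the PR:

```python
# Hedged sketch of global rank vs FSDP shard index once dp_replicate > 1.
def fsdp_shard_index(global_rank: int, fsdp_total: int) -> int:
    # Ranks in different replica groups map to the same shard position.
    return global_rank % fsdp_total

fsdp_total = 4   # dp_shard * cp shards per replica
world_size = 8   # dp_replicate = 2

buggy = list(range(world_size))                              # my_rank as offset index
fixed = [fsdp_shard_index(r, fsdp_total) for r in range(world_size)]
assert fixed == [0, 1, 2, 3, 0, 1, 2, 3]
assert max(buggy) >= fsdp_total   # ranks 4-7 index past the shard range
```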


Reviewed by Cursor Bugbot for commit 459f19f.

```python
def _register(name: str, tensor: torch.Tensor) -> None:
    contig = tensor.contiguous()
    self._agent.register_tensor(contig)
    tensor_ptrs[name] = (contig.data_ptr(), contig.numel() * contig.element_size(), contig.get_device())
```

Registered contiguous copy may become dangling memory

Medium Severity

In _register, contig = tensor.contiguous() may create a new tensor if the original is non-contiguous. This copy is registered with NIXL and its pointer is published, but no persistent reference to contig is kept. After _register returns, the copy can be garbage-collected, making the RDMA-registered pointer dangling. Trainer writes would go into freed memory instead of the actual model parameter.
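The lifetime bug is reproducible without CUDA. In this hedged, pure-Python sketch, `Buffer` stands in for the tensor, and publishing only `id(copy)` mirrors publishing a raw device pointer with no strong reference:

```python
import gc
import weakref

class Buffer:
    def __init__(self, data):
        self.data = list(data)

registered_ptrs = {}   # name -> "pointer" only (no strong reference!)
keepalive = {}         # the fix: hold the contiguous copy here

def register(name, buf, keep):
    copy = Buffer(buf.data)            # like tensor.contiguous() allocating
    registered_ptrs[name] = id(copy)   # pointer published to peers
    if keep:
        keepalive[name] = copy
    return weakref.ref(copy)

dangling = register("a", Buffer([1, 2]), keep=False)
pinned = register("b", Buffer([3, 4]), keep=True)
gc.collect()
assert dangling() is None      # registered object already freed
assert pinned() is not None    # keeping a reference pins it
```

The same pattern applies to the real fix: store `contig` in a dict owned by the broadcast object for as long as the NIXL registration is live.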


Reviewed by Cursor Bugbot for commit 459f19f.

S1ro1 added 2 commits April 20, 2026 05:07
Weight broadcast over NIXL/UCX now completes on the 12-node prod setup
(8 trainer x 8 GPU + 2 prefill + 2 decode). Key fixes stacked here:

- Per-rank NIC pin actually applied: drop the sbatch-level
  UCX_NET_DEVICES=all so trainer ranks keep their PIX-attached NIC.
  GPUs 4-7 have three PIX NICs, one of which (mlx5_8) is DOWN on every
  trainer node, and UCX_NET_DEVICES=all was ending up on it for GPU 4,
  hitting NIXL_ERR_REMOTE_DISCONNECT.
- Allocate slot dtype to match vLLM for layernorm affine params and
  the expert-routing bias, which vLLM stores in fp32.
- Gather rather than per-shard when a source tensor is below a 2 MiB
  threshold, cutting per-rank handle count ~60%.
- Chunked drain (flush_every=100) to bound UCX queue depth.
- Trainer dist.barrier() after rank 0 sees NIXL_READY so no rank
  RDMA-writes before all inference engines have acked /pause.
- Peer/tag context propagated through wait() for diagnostics.
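The chunked-drain item above can be sketched in a few lines; `flush_every` matches the commit message, while the `post`/`drain` callbacks are illustrative stand-ins for the NIXL calls:

```python
def push_writes(writes, post, drain, flush_every=100):
    """Post writes, draining whenever flush_every are in flight,
    to bound the transport queue depth (sketch, not the real path)."""
    inflight, flushes = 0, 0
    for w in writes:
        post(w)
        inflight += 1
        if inflight >= flush_every:
            drain()                      # block until posted writes complete
            inflight, flushes = 0, flushes + 1
    if inflight:                         # drain the remainder
        drain()
        flushes += 1
    return flushes

posted = []
n_flushes = push_writes(range(250), posted.append, lambda: None, flush_every=100)
assert len(posted) == 250 and n_flushes == 3
```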
Decode inference workers run with UCX_NET_DEVICES=mlx5_0:1 pre-set for
the PD KV NixlConnector, so pin_ucx_rails setdefault was a no-op on
them and every weight-transfer WRITE funneled through a single NIC
per decode node. Switching to a plain assignment makes the
weight-transfer agent bind its UCP worker to the per-GPU PIX NIC
instead; the PD connector is already up with its own env snapshot so
it keeps mlx5_0.

Push throughput on the 12-node prod run jumps from ~4.8 GB/s to
~7.5 GB/s wire, net bandwidth from ~10 GB/s to ~20 GB/s, and the
per-push SPG barrier drops from ~8 s to ~0.8 s. The 7-second
straggler gap between local rank 0 and ranks 1-7 also disappears.
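The setdefault-vs-assignment distinction above is just dict semantics; the device names come from the commit message, and the dict stands in for the process environment:

```python
env = {"UCX_NET_DEVICES": "mlx5_0:1"}         # pre-set for the PD KV connector

env.setdefault("UCX_NET_DEVICES", "mlx5_4:1")  # no-op: key already exists
assert env["UCX_NET_DEVICES"] == "mlx5_0:1"    # still funneled to one NIC

env["UCX_NET_DEVICES"] = "mlx5_4:1"            # plain assignment wins
assert env["UCX_NET_DEVICES"] == "mlx5_4:1"    # per-GPU PIX NIC applied
```

The PD connector keeps mlx5_0 because it snapshotted the environment before the weight-transfer agent reassigns the variable.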

```python
self.logger.info(
    f"[nixl rank={self.world.rank}] push "
    f"bytes={self._bytes_per_push / 1e6:.2f}MB handles={len(handles)} "
```

Logging always reports zero handles after drain clears list

Low Severity

The push_once log message references len(handles), but handles is always empty at that point because _drain calls handles.clear(). The log will always show handles=0 regardless of how many writes were posted. The intent was likely to log len(self._writes).
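A minimal reproduction of the stale-length log (names mirror the bug report, not the actual module):

```python
handles = ["h0", "h1", "h2"]
writes = list(handles)   # stands in for self._writes

def drain(h):
    h.clear()            # completions consumed, list emptied

drain(handles)
logged = f"handles={len(handles)}"   # what the code prints: always 0
intended = f"handles={len(writes)}"  # likely intent: len(self._writes)
assert logged == "handles=0"
assert intended == "handles=3"
```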

Additional Locations (1)

Reviewed by Cursor Bugbot for commit 0d49320.

```python
remote_prep = _get_remote_prep(peer["agent_name"], name, my_rank, peer["descriptors"][name])
self._writes.append(
    (local_prep, 0, remote_prep, 0, peer["agent_name"], f"per_shard:{name}")
)
```

Per-shard NIXL writes assume fsdp_total equals trainer world size

Medium Severity

For per_shard non-expert slots, the inference side creates trainer_world_size chunks while the trainer's slot is sized by fsdp_total (dp_shard * cp). The trainer indexes into inference chunks using its global world.rank. When dp_replicate > 1 or pp > 1, trainer_world_size > fsdp_total, causing the trainer's slot (larger) to be written into a smaller inference chunk, producing an RDMA size mismatch error.
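The size mismatch is simple arithmetic. A hedged sketch (parallelism factors are example values; the formulas follow the definitions in the bug text):

```python
def sizes(tensor_bytes, dp_replicate, dp_shard, cp, pp):
    fsdp_total = dp_shard * cp                  # trainer slot granularity
    world = dp_replicate * dp_shard * cp * pp   # inference chunk granularity
    return tensor_bytes // fsdp_total, tensor_bytes // world

slot, chunk = sizes(1024, dp_replicate=2, dp_shard=4, cp=1, pp=1)
assert slot == 256 and chunk == 128
assert slot > chunk   # writing `slot` bytes into a `chunk`-byte target fails
```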

Additional Locations (1)

Reviewed by Cursor Bugbot for commit 0d49320.

PyTorch expandable_segments:True hands tensors out of cuMemCreate +
cuMemMap virtual ranges. ibv_reg_mr on such VA succeeds, but the mlx5
HCA's MMU walk at WRITE time fails with "Local protection" (syndrome
0x4) because nvidia_peermem cannot pin a VA that spans multiple
cuMemCreate handles. UCX closes the EP and NIXL surfaces it as
REMOTE_DISCONNECT.

Fix: carve out the NIXL-registered slot buffers into a dedicated
CUDAPluggableAllocator + MemPool that calls cudaMalloc / cudaFree
directly. Everything else in the trainer keeps using expandable
segments, preserving the fragmentation mitigation.

Also preload libcudart with RTLD_GLOBAL because TileLang's stub
runtime wins dlsym(RTLD_DEFAULT) otherwise and its self-check aborts
the process the moment we enter the MemPool context. Wrapped in
try/except since CDLL can fail on hosts without a real CUDA runtime.

Wire_bw on the 12-node disagg run stays at 7.1-7.6 GB/s — same as
the expandable_segments=False baseline.
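The guarded preload described above can be sketched with ctypes; the soname is an assumption (versioned sonames vary by CUDA install), and the try/except mirrors the commit's CDLL fallback:

```python
import ctypes

def preload_cudart(soname: str = "libcudart.so.12") -> bool:
    """Load libcudart with RTLD_GLOBAL so later dlsym(RTLD_DEFAULT)
    lookups resolve to the real runtime instead of a stub; return
    False on hosts without a CUDA runtime (sketch, not the PR code)."""
    try:
        ctypes.CDLL(soname, mode=ctypes.RTLD_GLOBAL)
        return True
    except OSError:
        return False

loaded = preload_cudart()   # False is fine on CUDA-less hosts
assert isinstance(loaded, bool)
```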

cursor[bot] left a comment


Cursor Bugbot has reviewed your changes and found 1 potential issue.

There are 6 total unresolved issues (including 5 from previous reviews).


Reviewed by Cursor Bugbot for commit 18b39fe.

```python
_drain(len(self._writes))

t_posted = time.perf_counter()
t_waited = t_posted
```

Wait timing variables are identical, metrics always zero

Low Severity

t_posted and t_waited are assigned the same value on consecutive lines, making dt_wait always zero. The gbps_wire and gbps_net metrics derived from dt_wire and dt_post + dt_wait become misleading since the wait phase (previously a separate step before the drain refactor) is no longer captured separately.
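The zeroed metric follows directly from the two assignments:

```python
import time

# Reproduction sketch: assigning t_waited from t_posted on the next
# line makes dt_wait identically zero, so no wait phase is timed.
t_posted = time.perf_counter()
t_waited = t_posted
dt_wait = t_waited - t_posted
assert dt_wait == 0.0
```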


Reviewed by Cursor Bugbot for commit 18b39fe.
