
Nixl weight transfer #2326

Open

S1ro1 wants to merge 17 commits into main from nixl-weight-transfer

Conversation

@S1ro1 (Collaborator) commented Apr 19, 2026

RDMA-based weight transfer for GLM5 FP8


Note

High Risk
Introduces a new RDMA-based weight update mechanism spanning trainer/orchestrator/inference plus custom GPU memory allocation and Triton kernels, which is complex and failure-prone in distributed environments. Misconfiguration or subtle sync/registration issues could stall training or corrupt inference weights.

Overview
Adds a new nixl weight broadcast mode that updates vLLM inference weights via NIXL/UCX RDMA with a two-round StatelessProcessGroup rendezvous, per-push barriers, and orchestrator-driven pause/resume guarded by STABLE/NIXL_READY markers.

On the trainer side this introduces TransportPlan/slot-based pre-registered buffers (including FP8 scale buffers) backed by a new classic cudaMalloc mempool, plus model-driven ConversionSpec tables and a Triton FP8 block-quantization implementation. On the inference side it adds a NIXLWeightUpdateWorker and admin API to initialize NIXL transfer and participate in the rendezvous.

The PR also simplifies NCCL weight broadcast by removing the optional “quantize in transfer” path, updates configs/validators and SLURM templates to support nixl, and adds a system-contract doc for the new protocol.
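For orientation, selecting the new mode in config might look like the fragment below. The `[weight_broadcast]` table name is inferred from `config.weight_broadcast` in the diff, and the defaults come from the config fields quoted later in this thread; treat this as a sketch rather than the repo's exact schema.

```toml
# Hypothetical fragment; table name and layout are assumptions.
[weight_broadcast]
type = "nixl"        # vs "nccl" or the filesystem-based broadcast
port = 29501         # rendezvous port (NCCL or NIXL)
timeout = 1200       # rendezvous timeout in seconds (NCCL or NIXL)
backends = ["UCX"]   # NIXL backends (only used when type = "nixl")
```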

Reviewed by Cursor Bugbot for commit f23d68b.

S1ro1 and others added 2 commits April 19, 2026 19:00
Replace the if/else mess in convert_tt_layer_to_vllm_kernel with a
deterministic spec table mirroring mapper.py; only branch left is dense
vs sparse layer. out_buffers is now required (fail-loud on missing key),
experts quantize in a single grouped-triton call instead of a per-expert
loop. Slot allocation moved into the NIXL broadcast init so the
converter never allocates, and expert_lead / konig / ImportError fallbacks
are gone.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
S1ro1 added 2 commits April 20, 2026 05:07
Weight broadcast over NIXL/UCX now completes on the 12-node prod setup
(8 trainer x 8 GPU + 2 prefill + 2 decode). Key fixes stacked here:

- Per-rank NIC pin actually applied: drop the sbatch-level
  UCX_NET_DEVICES=all so trainer ranks keep their PIX-attached NIC.
  GPUs 4-7 have three PIX NICs, one of which (mlx5_8) is DOWN on every
  trainer node, and UCX_NET_DEVICES=all was ending up on it for GPU 4,
  hitting NIXL_ERR_REMOTE_DISCONNECT.
- Allocate slot dtype to match vLLM for layernorm affine params and
  the expert-routing bias, which vLLM stores in fp32.
- Gather rather than per-shard when a source tensor is below a 2 MiB
  threshold, cutting per-rank handle count ~60%.
- Chunked drain (flush_every=100) to bound UCX queue depth.
- Trainer dist.barrier() after rank 0 sees NIXL_READY so no rank
  RDMA-writes before all inference engines have acked /pause.
- Peer/tag context propagated through wait() for diagnostics.
Decode inference workers run with UCX_NET_DEVICES=mlx5_0:1 pre-set for
the PD KV NixlConnector, so pin_ucx_rails setdefault was a no-op on
them and every weight-transfer WRITE funneled through a single NIC
per decode node. Switching to a plain assignment makes the
weight-transfer agent bind its UCP worker to the per-GPU PIX NIC
instead; the PD connector is already up with its own env snapshot so
it keeps mlx5_0.

Push throughput on the 12-node prod run jumps from ~4.8 GB/s to
~7.5 GB/s wire, net bandwidth from ~10 GB/s to ~20 GB/s, and the
per-push SPG barrier drops from ~8 s to ~0.8 s. The 7-second
straggler gap between local rank 0 and ranks 1-7 also disappears.
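The setdefault-vs-assignment fix above can be sketched in a few lines. The helper name `pin_ucx_rails` comes from the commit text, but the GPU-to-NIC table and the signature here are illustrative, not the repo's.

```python
import os

# Hypothetical per-GPU PIX-attached NIC mapping for one node.
PIX_NIC_BY_GPU = {0: "mlx5_0:1", 1: "mlx5_1:1", 4: "mlx5_4:1", 5: "mlx5_5:1"}

def pin_ucx_rails(local_gpu: int) -> str:
    nic = PIX_NIC_BY_GPU[local_gpu]
    # os.environ.setdefault(...) was a no-op on decode nodes where the PD KV
    # NixlConnector had already exported UCX_NET_DEVICES=mlx5_0:1; a plain
    # assignment makes the weight-transfer agent bind its own per-GPU NIC.
    os.environ["UCX_NET_DEVICES"] = nic
    return nic
```

The already-running PD connector keeps its own snapshot of the environment, so overriding the variable after it is up only affects the weight-transfer agent's UCP worker, not the KV path.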
PyTorch expandable_segments:True hands tensors out of cuMemCreate +
cuMemMap virtual ranges. ibv_reg_mr on such VA succeeds, but the mlx5
HCA's MMU walk at WRITE time fails with "Local protection" (syndrome
0x4) because nvidia_peermem cannot pin a VA that spans multiple
cuMemCreate handles. UCX closes the EP and NIXL surfaces it as
REMOTE_DISCONNECT.

Fix: carve out the NIXL-registered slot buffers into a dedicated
CUDAPluggableAllocator + MemPool that calls cudaMalloc / cudaFree
directly. Everything else in the trainer keeps using expandable
segments, preserving the fragmentation mitigation.

Also preload libcudart with RTLD_GLOBAL because TileLang's stub
runtime wins dlsym(RTLD_DEFAULT) otherwise and its self-check aborts
the process the moment we enter the MemPool context. Wrapped in
try/except since CDLL can fail on hosts without a real CUDA runtime.

Wire bandwidth on the 12-node disagg run stays at 7.1-7.6 GB/s, same as
the expandable_segments=False baseline.
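The libcudart preload described above can be sketched with stdlib ctypes. The helper name is an illustration, and the mempool wiring (a CUDAPluggableAllocator-backed pool calling cudaMalloc/cudaFree) is elided; only the RTLD_GLOBAL preload with its try/except guard is shown.

```python
import ctypes

def preload_libcudart() -> bool:
    """Load the real CUDA runtime with RTLD_GLOBAL so its symbols win the
    dlsym(RTLD_DEFAULT) race against TileLang's stub runtime."""
    try:
        ctypes.CDLL("libcudart.so", mode=ctypes.RTLD_GLOBAL)
        return True
    except OSError:
        # Hosts without a real CUDA runtime: nothing to preload, and the
        # NIXL path is unavailable there anyway.
        return False
```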
S1ro1 and others added 5 commits April 20, 2026 19:48
Split the old GLM-specific _Spec into two model-agnostic dataclasses
in a new trainer/models/conversion_spec.py:

- QuantizationSpec(destination_dtype, scale_suffix="") owns the
  transformation. Empty scale_suffix means plain copy_cast; non-empty
  means FP8 block quantize (2D or 3D dispatched from src.ndim).
  QuantizationSpec.apply(src, out, scale_out) is the single entry
  point for every spec type, including bf16 / fp32 dtype casts.
- ConversionSpec(dst, sources, cat_dim, quantization) owns routing.
  scale_name and per_source_scale_key moved from private helpers to
  public methods.

Drops the ad-hoc vllm_fp32_srcs set, the duplicated 2D/3D quantize-fn
dispatch, and the hardcoded .weight/_weight_scale_inv suffix in
modeling_glm_moe_dsa.py. Spec table entries are explicit — each fp32
layernorm and FP8 projection inlines its QuantizationSpec.
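A minimal sketch of the two dataclasses, using the names and fields from the commit message. Dtypes are plain strings here to keep the sketch self-contained (the real code uses torch dtypes), the tensor-level apply() is elided, and the scale-name derivation is illustrative.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class QuantizationSpec:
    destination_dtype: str  # e.g. "bf16", "fp32", "fp8_e4m3"
    scale_suffix: str = ""  # "" => plain copy_cast; non-empty => FP8 block quantize

    @property
    def is_fp8_block(self) -> bool:
        # 2D vs 3D quantize is dispatched from src.ndim inside apply()
        return bool(self.scale_suffix)

@dataclass(frozen=True)
class ConversionSpec:
    dst: str                  # vLLM-side parameter name
    sources: tuple[str, ...]  # trainer tensors concatenated along cat_dim
    cat_dim: int
    quantization: QuantizationSpec

    def scale_name(self) -> str:
        # Only meaningful for FP8 block-quantized entries; the exact
        # derivation in the repo may differ from this simple append.
        return self.dst + self.quantization.scale_suffix
```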

Moved the libcudart.so RTLD_GLOBAL preload to the top of
sparse_mla_{fwd,bwd}.py, before the import tilelang that was winning
the dlsym race and aborting the process. Previously it lived in
classic_cuda_pool.py which was imported too late.

Added docs/nixl-weight-broadcast.md covering the whole workflow
including the new class design and the platform gotchas.

12-node prod run stays at 7.0-7.6 GB/s wire / 18-22 GB/s net after
the refactor.
- Replace model.allocate_slots/non_expert_slot_layout/convert_layer_to_vllm_kernel with a flat ConversionSpec table the model exposes via conversion_specs(layer_idx).
- New model-agnostic ShardedSlot/GatheredSlot/ExpertSlot own their own buffers, layout payload, and write tables; TransportPlan owns the NIXL registrations, SPG rendezvous, and per-step push.
- Flat LayoutEntry wire format replaces the nested non_expert_layout dict.
- Fix: _ResolvedWrite remote_idx must be 0 — _resolve_remote preps one serialized chunk dlist at a time, so chunk selection already happened at prep time (caused NIXL_ERR_INVALID_PARAM at prod scale).
- Drop FP8 kernel-format NCCL weight transfer (quantize_in_weight_transfer): fully superseded by NIXL; removes config fields, validators, client/server routing, inference-side loader, and trainer-side preprocessor.
- Remove outdated NIXL/conversion tests — new ones to follow against the new API.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Add docs/nixl-architecture.md: concise ownership/lifecycle/wire-format reference for TransportPlan + Slot + rendezvous.
- Drop tests/fixtures/build_tiny_glm_moe_dsa.py — only the removed NIXL integration tests used it.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Replaces the step-by-step walkthrough + gotcha log with a role-by-role
contract: what trainer/inference/orchestrator promise each other, what
flows on each channel (SPG control, NIXL data, filesystem markers),
what's guaranteed after a push vs not, and what breaks the contract.

Keeps a companion to docs/nixl-architecture.md — the architecture doc
covers class ownership and wire encoding; this doc covers system-level
promises.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The contract-focused nixl-weight-broadcast.md covers the same surface
without duplicating implementation-level class tables.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Member commented:
can we make this a section somewhere? if we don't have yet, can we make a weight-transfer.md which details all of the transfer methods we support atm

@S1ro1 (Collaborator, author) commented:
Yeah agree, will probably add this in here

] = False
port: Annotated[int, Field(description="Rendezvous port (NCCL or NIXL).")] = 29501
timeout: Annotated[int, Field(description="Rendezvous timeout in seconds (NCCL or NIXL).")] = 1200
backends: Annotated[
Member commented:
It's a little anti-patternish to have a shared config which only applies to one type of weight broadcast config. Wondering if we should do something like SharedFileSystemBroadcastConfig, SharedNCCLBroadcastConfig, etc.

Member commented:
To be fair, we already do this with the port and timeout, which are NCCL-only -- just seeing this now.

Comment on lines 473 to +474
use_nccl_broadcast=config.weight_broadcast is not None and config.weight_broadcast.type == "nccl",
use_nixl_broadcast=config.weight_broadcast is not None and config.weight_broadcast.type == "nixl",
Member commented:
should we make this weight_broadcast_type: Literal["..."]?
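A sketch of the suggested single typed field, using plain typing.Literal. The "filesystem" variant name is a guess based on SharedFileSystemBroadcastConfig mentioned elsewhere in this thread; in the repo this would live on the pydantic config rather than a free function.

```python
from typing import Literal, Optional

WeightBroadcastType = Literal["filesystem", "nccl", "nixl"]

def broadcast_flags(weight_broadcast_type: Optional[WeightBroadcastType]) -> dict:
    # Collapses the two `is not None and type == ...` booleans from the
    # quoted lines into a single typed field.
    return {
        "use_nccl_broadcast": weight_broadcast_type == "nccl",
        "use_nixl_broadcast": weight_broadcast_type == "nixl",
    }
```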

self._multi_run_manager = get_multi_run_manager()
return self._multi_run_manager

def _notify_orchestrator(self) -> list[tuple[int, Path]]:
Member commented:
this is prob protocol agnostic, right? should we move this into the base WeightBroadcast?

Comment on lines +304 to +305
(weight_dir / NCCL_READY_MARKER).touch()
(weight_dir / NIXL_READY_MARKER).touch()
Member commented:
seems a bit off to write both flags, but can clean this up later

@S1ro1 (Collaborator, author) commented:
> seems a bit off to write both flags, but can clean this up later

Yeah, we should make it something like TRANSFER_READY and be transfer-agnostic

S1ro1 and others added 2 commits April 21, 2026 01:55
HSDP:
- NIXLWeightBroadcast runs the full protocol only on replica 0 (dp_replicate
  rank 0). Non-primary replicas skip TransportPlan / agent / SPG entirely.
- SPG world size drops to (dp_shard × cp) + inference_ws; trainer rank in
  the SPG is the dp_shard_cp axis rank, not the global rank.
- Slots index by dp_shard_cp rank too (via _shard_rank_and_size) so
  ShardedSlot.remote_chunk_idx and GatheredSlot's round-robin fan-out
  align with the SPG membership.
- broadcast_weights adds a second dist.barrier() at the end so non-primary
  replicas exit in lockstep with the primary's push. Without it they'd
  race ahead into the next step while the primary is still draining.
- Config plumbs per-replica trainer_world_size via _infer_nixl_trainer_ws.

EP partition assertion:
- ExpertSlot.from_spec now verifies num_local × fsdp × ep == total_experts
  from the DTensor global shape — catches misconfigured meshes (e.g. EP
  axis not dividing cleanly into dp_shard_mod_ep) at slot construction
  instead of as a silent routing bug.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
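The membership and partition rules above reduce to simple arithmetic; function and parameter names here are assumptions for illustration.

```python
def spg_world_size(dp_shard: int, cp: int, inference_ws: int) -> int:
    # Only dp_replicate rank 0 joins the SPG, so the replicate axis drops out.
    return dp_shard * cp + inference_ws

def trainer_spg_rank(dp_shard_rank: int, cp_rank: int, cp: int) -> int:
    # Trainer rank inside the SPG is the dp_shard_cp axis rank, not global rank.
    return dp_shard_rank * cp + cp_rank

def check_ep_partition(num_local: int, fsdp: int, ep: int, total_experts: int) -> None:
    # ExpertSlot.from_spec's assertion: local experts times the FSDP and EP
    # factors must reproduce the DTensor global expert count.
    assert num_local * fsdp * ep == total_experts, "misconfigured mesh"
```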
The Triton _grouped_fp8_block_quantize used a 1e-4 scale floor; main's
PyTorch quantize_to_fp8_blockwise used 1e-12. For blocks with amax
below ~0.0448 (e.g. LoRA B-tensors, freshly-initialized heads) the new
floor zeroed out tiny nonzero weights at dequant time.

With this change the Triton kernel produces bit-exact output vs the old
PyTorch kernel across normal / tiny / zero / large-magnitude blocks.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
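The effect of the floor can be reproduced with plain arithmetic. The e4m3 constants are the standard ones for that format, the block values are made up, and note that 448 x 1e-4 = 0.0448, the amax threshold quoted above.

```python
FP8_E4M3_MAX = 448.0                 # largest finite e4m3 magnitude
FP8_E4M3_MIN_SUBNORMAL = 2.0 ** -9   # smallest positive e4m3 value (~0.00195)

def block_scale(amax: float, floor: float) -> float:
    # Per-block scale chosen so that amax maps onto FP8_E4M3_MAX.
    return max(amax / FP8_E4M3_MAX, floor)

amax = 1e-8  # e.g. a block from a freshly-initialized head
buggy = block_scale(amax, floor=1e-4)   # the Triton kernel's old floor
fixed = block_scale(amax, floor=1e-12)  # main's PyTorch floor

# With the 1e-4 floor the scaled weight underflows e4m3 (it is below half
# the smallest subnormal), so round-to-nearest produces exactly 0 and
# dequantization returns 0 for a nonzero weight:
assert amax / buggy < FP8_E4M3_MIN_SUBNORMAL / 2
# With the 1e-12 floor the scale tracks amax and the weight survives:
assert abs(amax / fixed - FP8_E4M3_MAX) < 1e-6
```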
@cursor cursor bot left a comment:

Cursor Bugbot has reviewed your changes and found 1 potential issue.



timeout: Annotated[int, Field(description="Rendezvous timeout in seconds (NCCL or NIXL).")] = 1200
backends: Annotated[
list[str], Field(description="NIXL backends (only used when type='nixl').")
] = ["UCX"]

Removed config field without CHANGELOG update

Medium Severity

The quantize_in_weight_transfer field was removed from SharedWeightBroadcastConfig, trainer NCCLWeightBroadcastConfig, and orchestrator NCCLWeightBroadcastConfig. Since these configs use extra="forbid", any existing user config that sets this field will fail validation at load time. This is a breaking removal of a config field, but CHANGELOG.md has no corresponding entry.

Additional Locations (2)
