Replace the if/else mess in convert_tt_layer_to_vllm_kernel with a deterministic spec table mirroring mapper.py; the only remaining branch is dense vs sparse layer. out_buffers is now required (fail loud on a missing key), and experts quantize in a single grouped Triton call instead of a per-expert loop. Slot allocation moved into the NIXL broadcast init so the converter never allocates, and the expert_lead / konig / ImportError fallbacks are gone.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
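A minimal sketch of the spec-table shape (names and fields here are illustrative stand-ins, not the actual mapper.py contents): each destination maps to a declarative entry, and lookup fails loud on a missing key instead of falling through an if/else chain:

```python
# Hypothetical spec-table sketch: a flat dict keyed by destination name
# replaces branchy name matching. LayerSpec fields are assumptions.
from dataclasses import dataclass

@dataclass(frozen=True)
class LayerSpec:
    sources: tuple[str, ...]   # trainer-side tensor names
    dtype: str                 # destination dtype
    quantize: bool = False     # FP8 block-quantize vs plain cast

SPEC_TABLE = {
    "model.layers.{i}.input_layernorm.weight": LayerSpec(("norm.w",), "fp32"),
    "model.layers.{i}.mlp.gate_proj.weight": LayerSpec(("gate.w",), "fp8", quantize=True),
}

def convert(dst_name: str) -> LayerSpec:
    # Plain dict indexing: a missing key raises KeyError immediately
    # instead of silently skipping the tensor.
    return SPEC_TABLE[dst_name]
```

The table makes the full mapping auditable in one place; the only runtime branch left is whichever table (dense vs sparse) a layer selects.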
Weight broadcast over NIXL/UCX now completes on the 12-node prod setup (8 trainer x 8 GPU + 2 prefill + 2 decode). Key fixes stacked here:
- Per-rank NIC pin actually applied: drop the sbatch-level UCX_NET_DEVICES=all so trainer ranks keep their PIX-attached NIC. GPUs 4-7 have three PIX NICs, one of which (mlx5_8) is DOWN on every trainer node, and UCX=all was ending up on it for GPU 4 and hitting NIXL_ERR_REMOTE_DISCONNECT.
- Allocate slot dtype to match vLLM for layernorm affine params and the expert-routing bias, which vLLM stores in fp32.
- Gather rather than per-shard when a source tensor is below a 2 MiB threshold, cutting per-rank handle count ~60%.
- Chunked drain (flush_every=100) to bound UCX queue depth.
- Trainer dist.barrier() after rank 0 sees NIXL_READY so no rank RDMA-writes before all inference engines have acked /pause.
- Peer/tag context propagated through wait() for diagnostics.
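The per-rank NIC pin can be sketched roughly like this (the GPU-to-NIC list and the DOWN set are made-up illustrations; the real mapping comes from node topology discovery):

```python
# Illustrative per-rank NIC pinning: each local GPU gets one of its
# PIX-attached NICs for UCX_NET_DEVICES, skipping known-DOWN ports,
# instead of inheriting a job-wide UCX_NET_DEVICES=all that let UCX
# land on a dead device. Device names here are assumptions.
import os

GPU_TO_NICS = {
    4: ["mlx5_6:1", "mlx5_7:1", "mlx5_8:1"],  # mlx5_8 is DOWN on trainer nodes
}
DOWN_NICS = {"mlx5_8:1"}

def pin_ucx_rail(local_gpu: int) -> str:
    candidates = [n for n in GPU_TO_NICS[local_gpu] if n not in DOWN_NICS]
    os.environ["UCX_NET_DEVICES"] = candidates[0]
    return candidates[0]
```

Pinning a single healthy device per rank removes the failure mode where UCX's own device selection picked the DOWN port and the write surfaced as NIXL_ERR_REMOTE_DISCONNECT.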
Decode inference workers run with UCX_NET_DEVICES=mlx5_0:1 pre-set for the PD KV NixlConnector, so pin_ucx_rails setdefault was a no-op on them and every weight-transfer WRITE funneled through a single NIC per decode node. Switching to a plain assignment makes the weight-transfer agent bind its UCP worker to the per-GPU PIX NIC instead; the PD connector is already up with its own env snapshot so it keeps mlx5_0. Push throughput on the 12-node prod run jumps from ~4.8 GB/s to ~7.5 GB/s wire, net bandwidth from ~10 GB/s to ~20 GB/s, and the per-push SPG barrier drops from ~8 s to ~0.8 s. The 7-second straggler gap between local rank 0 and ranks 1-7 also disappears.
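The setdefault-vs-assignment distinction is just standard os.environ behavior, shown here with the device names from the description above:

```python
import os

# setdefault is a no-op when the variable is already set (the PD KV
# connector pre-sets UCX_NET_DEVICES=mlx5_0:1 on decode workers), so the
# weight-transfer agent must assign unconditionally. The already-running
# connector keeps its own env snapshot, so it is unaffected.
os.environ["UCX_NET_DEVICES"] = "mlx5_0:1"            # pre-set by PD connector
os.environ.setdefault("UCX_NET_DEVICES", "mlx5_4:1")  # no-op: value unchanged
assert os.environ["UCX_NET_DEVICES"] == "mlx5_0:1"

os.environ["UCX_NET_DEVICES"] = "mlx5_4:1"            # plain assignment wins
assert os.environ["UCX_NET_DEVICES"] == "mlx5_4:1"
```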
PyTorch expandable_segments:True hands tensors out of cuMemCreate + cuMemMap virtual ranges. ibv_reg_mr on such VA succeeds, but the mlx5 HCA's MMU walk at WRITE time fails with "Local protection" (syndrome 0x4) because nvidia_peermem cannot pin a VA that spans multiple cuMemCreate handles. UCX closes the EP and NIXL surfaces it as REMOTE_DISCONNECT. Fix: carve out the NIXL-registered slot buffers into a dedicated CUDAPluggableAllocator + MemPool that calls cudaMalloc / cudaFree directly. Everything else in the trainer keeps using expandable segments, preserving the fragmentation mitigation. Also preload libcudart with RTLD_GLOBAL because TileLang's stub runtime wins dlsym(RTLD_DEFAULT) otherwise and its self-check aborts the process the moment we enter the MemPool context. Wrapped in try/except since CDLL can fail on hosts without a real CUDA runtime. Wire bandwidth on the 12-node disagg run stays at 7.1-7.6 GB/s — same as the expandable_segments=False baseline.
Split the old GLM-specific _Spec into two model-agnostic dataclasses
in a new trainer/models/conversion_spec.py:
- QuantizationSpec(destination_dtype, scale_suffix="") owns the
transformation. Empty scale_suffix means plain copy_cast; non-empty
means FP8 block quantize (2D or 3D dispatched from src.ndim).
QuantizationSpec.apply(src, out, scale_out) is the single entry
point for every spec type, including bf16 / fp32 dtype casts.
- ConversionSpec(dst, sources, cat_dim, quantization) owns routing.
scale_name and per_source_scale_key moved from private helpers to
public methods.
Drops the ad-hoc vllm_fp32_srcs set, the duplicated 2D/3D quantize-fn
dispatch, and the hardcoded .weight/_weight_scale_inv suffix in
modeling_glm_moe_dsa.py. Spec table entries are explicit — each fp32
layernorm and FP8 projection inlines its QuantizationSpec.
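A stdlib-only sketch of the two dataclasses under the split described above (apply's real tensor work is simulated with dict stand-ins; the dispatch structure follows the description):

```python
# Sketch only: the real apply() casts tensors or FP8 block-quantizes via
# Triton; here src/out are dicts so the dispatch logic stays runnable.
from dataclasses import dataclass

@dataclass(frozen=True)
class QuantizationSpec:
    destination_dtype: str
    scale_suffix: str = ""  # "" -> plain copy_cast; non-empty -> FP8 block quantize

    def apply(self, src, out, scale_out=None):
        if not self.scale_suffix:
            out["data"] = src["data"]                      # stand-in for copy_cast
        elif src["ndim"] == 2:
            out["data"], scale_out["data"] = src["data"], "scales2d"
        else:                                              # 3D (stacked experts)
            out["data"], scale_out["data"] = src["data"], "scales3d"

@dataclass(frozen=True)
class ConversionSpec:
    dst: str
    sources: tuple
    cat_dim: int
    quantization: QuantizationSpec

    def scale_name(self) -> str:
        # public instead of a hardcoded private suffix helper
        return self.dst + self.quantization.scale_suffix
```

The point of the shape: one apply() entry point for every spec type, with the 2D/3D quantize choice driven by src.ndim rather than duplicated dispatch code.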
Moved the libcudart.so RTLD_GLOBAL preload to the top of
sparse_mla_{fwd,bwd}.py, before the import tilelang that was winning
the dlsym race and aborting the process. Previously it lived in
classic_cuda_pool.py which was imported too late.
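The preload itself is a few lines of ctypes; this hedged sketch assumes the soname libcudart.so and guards the CDLL call as described:

```python
# Load the real CUDA runtime with RTLD_GLOBAL before importing tilelang,
# so dlsym(RTLD_DEFAULT) resolves genuine cudart symbols rather than
# TileLang's stub runtime. CDLL can fail on hosts without a CUDA
# runtime, hence the try/except.
import ctypes

def preload_libcudart() -> bool:
    try:
        ctypes.CDLL("libcudart.so", mode=ctypes.RTLD_GLOBAL)
        return True
    except OSError:
        return False  # no real CUDA runtime on this host; nothing to win the race

preload_libcudart()
# import tilelang  # must come only after the preload
```

Placement is the whole fix: it has to run before the first `import tilelang`, which is why it lives at the top of sparse_mla_{fwd,bwd}.py rather than in a module imported later.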
Added docs/nixl-weight-broadcast.md covering the whole workflow
including the new class design and the platform gotchas.
12-node prod run stays at 7.0-7.6 GB/s wire / 18-22 GB/s net after
the refactor.
- Replace model.allocate_slots/non_expert_slot_layout/convert_layer_to_vllm_kernel with a flat ConversionSpec table the model exposes via conversion_specs(layer_idx).
- New model-agnostic ShardedSlot/GatheredSlot/ExpertSlot own their own buffers, layout payload, and write tables; TransportPlan owns the NIXL registrations, SPG rendezvous, and per-step push.
- Flat LayoutEntry wire format replaces the nested non_expert_layout dict.
- Fix: _ResolvedWrite remote_idx must be 0 — _resolve_remote preps one serialized chunk dlist at a time, so chunk selection already happened at prep time (caused NIXL_ERR_INVALID_PARAM at prod scale).
- Drop FP8 kernel-format NCCL weight transfer (quantize_in_weight_transfer): fully superseded by NIXL; removes config fields, validators, client/server routing, inference-side loader, and trainer-side preprocessor.
- Remove outdated NIXL/conversion tests — new ones to follow against the new API.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
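The flat wire format could look roughly like this (field names are assumptions for illustration; the point is one record per destination buffer instead of a nested layout dict):

```python
# Hypothetical LayoutEntry sketch: a flat list of records the receiver
# can decode in a single pass, replacing nested per-layer dicts.
from dataclasses import dataclass

@dataclass(frozen=True)
class LayoutEntry:
    name: str     # destination parameter name
    offset: int   # byte offset into the registered slot buffer
    nbytes: int
    dtype: str

def encode(entries):
    # plain tuples: trivially serializable with any codec
    return [(e.name, e.offset, e.nbytes, e.dtype) for e in entries]

def decode(payload):
    return [LayoutEntry(*row) for row in payload]
```

A flat record stream also makes the remote_idx fix above easy to reason about: each serialized chunk dlist is prepped independently, so every resolved write indexes its own chunk at 0.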
- Add docs/nixl-architecture.md: concise ownership/lifecycle/wire-format reference for TransportPlan + Slot + rendezvous.
- Drop tests/fixtures/build_tiny_glm_moe_dsa.py — only the removed NIXL integration tests used it.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Replaces the step-by-step walkthrough + gotcha log with a role-by-role contract: what trainer/inference/orchestrator promise each other, what flows on each channel (SPG control, NIXL data, filesystem markers), what's guaranteed after a push vs not, and what breaks the contract. It remains a companion to docs/nixl-architecture.md — the architecture doc covers class ownership and wire encoding; this doc covers system-level promises.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The contract-focused nixl-weight-broadcast.md covers the same surface without duplicating implementation-level class tables. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
can we make this a section somewhere? if we don't have yet, can we make a weight-transfer.md which details all of the transfer methods we support atm
Yeah agree, will probably add this in here
] = False
port: Annotated[int, Field(description="Rendezvous port (NCCL or NIXL).")] = 29501
timeout: Annotated[int, Field(description="Rendezvous timeout in seconds (NCCL or NIXL).")] = 1200
backends: Annotated[
its a little anti-patternish to have a shared config which only applies to one type of weight broadcast config. wondering if we should do smth like SharedFileSystemBroadcastConfig, SharedNCCLBroadcastConfig, etc
tbf, we alr do this rn with the port and timeout which are nccl-only -- just seeing this now
use_nccl_broadcast=config.weight_broadcast is not None and config.weight_broadcast.type == "nccl",
use_nixl_broadcast=config.weight_broadcast is not None and config.weight_broadcast.type == "nixl",
should we make this weight_broadcast_type: Literal["..."]?
self._multi_run_manager = get_multi_run_manager()
return self._multi_run_manager

def _notify_orchestrator(self) -> list[tuple[int, Path]]:
this is prob protocol agnostic, right? should we move this into the base WeightBroadcast?
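A sketch of what hoisting it into the base class could look like (class and method names follow the discussion; the marker-handshake body is a stand-in):

```python
# Hypothetical refactor: _notify_orchestrator is protocol-agnostic, so it
# lives on the base WeightBroadcast and both NCCL and NIXL inherit it.
from abc import ABC, abstractmethod
from pathlib import Path

class WeightBroadcast(ABC):
    def __init__(self, weight_dir: Path):
        self.weight_dir = weight_dir

    def _notify_orchestrator(self) -> list[tuple[int, Path]]:
        # filesystem-marker handshake, identical for every transport
        return [(0, self.weight_dir)]

    @abstractmethod
    def broadcast_weights(self) -> None: ...

class NIXLWeightBroadcast(WeightBroadcast):
    def broadcast_weights(self) -> None:
        self._notify_orchestrator()  # inherited, no per-protocol copy
```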
(weight_dir / NCCL_READY_MARKER).touch()
(weight_dir / NIXL_READY_MARKER).touch()
seems a bit off to write both flags, but can clean this up later
> seems a bit off to write both flags, but can clean this up later
Yeah we should make it smth like TRANSFER_READY and be transfer agnostic
HSDP:
- NIXLWeightBroadcast runs the full protocol only on replica 0 (dp_replicate rank 0). Non-primary replicas skip TransportPlan / agent / SPG entirely.
- SPG world size drops to (dp_shard × cp) + inference_ws; trainer rank in the SPG is the dp_shard_cp axis rank, not the global rank.
- Slots index by dp_shard_cp rank too (via _shard_rank_and_size) so ShardedSlot.remote_chunk_idx and GatheredSlot's round-robin fan-out align with the SPG membership.
- broadcast_weights adds a second dist.barrier() at the end so non-primary replicas exit in lockstep with the primary's push. Without it they'd race ahead into the next step while the primary is still draining.
- Config plumbs per-replica trainer_world_size via _infer_nixl_trainer_ws.

EP partition assertion:
- ExpertSlot.from_spec now verifies num_local × fsdp × ep == total_experts from the DTensor global shape — catches misconfigured meshes (e.g. EP axis not dividing cleanly into dp_shard_mod_ep) at slot construction instead of as a silent routing bug.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
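The replica-0 gating reduces to rank arithmetic; this sketch assumes a replicate-major mesh layout (an assumption for illustration, not necessarily the real mesh order):

```python
# Map a global rank to (dp_replicate, dp_shard_cp) coordinates under a
# replicate-major layout; only replica 0 runs the NIXL protocol.
def shard_rank_and_size(global_rank: int, dp_replicate: int, dp_shard_cp: int):
    assert global_rank < dp_replicate * dp_shard_cp
    replica, shard = divmod(global_rank, dp_shard_cp)
    return replica, shard, dp_shard_cp

def runs_nixl_protocol(global_rank: int, dp_replicate: int, dp_shard_cp: int) -> bool:
    replica, _, _ = shard_rank_and_size(global_rank, dp_replicate, dp_shard_cp)
    return replica == 0
```

The same shard coordinate is what the slots index by, which is why ShardedSlot chunk indices and the SPG membership stay aligned when replicas are excluded.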
The Triton _grouped_fp8_block_quantize used a 1e-4 scale floor; main's PyTorch quantize_to_fp8_blockwise used 1e-12. For blocks with amax below ~0.0448 (e.g. LoRA B-tensors, freshly-initialized heads) the new floor zeroed out tiny nonzero weights at dequant time. With this change the Triton kernel produces bit-exact output vs the old PyTorch kernel across normal / tiny / zero / large-magnitude blocks. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
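A pure-Python illustration of why the 1e-4 floor zeroes tiny blocks, using the e4m3 constants (max normal 448, smallest positive subnormal 2**-9); the actual Triton and PyTorch kernels are not reproduced here:

```python
# With scale = max(amax / 448, floor), a 1e-4 floor means any block with
# amax below 448 * 1e-4 = 0.0448 gets an oversized scale; values whose
# quantized magnitude falls below half the smallest fp8 subnormal round
# to zero and dequantize to zero.
FP8_MAX = 448.0        # e4m3 max normal
FP8_TINY = 2.0 ** -9   # e4m3 smallest positive subnormal

def dequantizes_to_zero(weight: float, floor: float) -> bool:
    amax = abs(weight)                      # one-element "block" for simplicity
    scale = max(amax / FP8_MAX, floor)
    return abs(weight) / scale < FP8_TINY / 2  # rounds to zero in fp8

assert dequantizes_to_zero(1e-8, floor=1e-4)       # 1e-4 floor kills it
assert not dequantizes_to_zero(1e-8, floor=1e-12)  # 1e-12 floor preserves it
```

With the 1e-12 floor the scale tracks amax for any realistically-sized block, so tiny nonzero weights (LoRA B-tensors, fresh heads) survive the round trip.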
Cursor Bugbot has reviewed your changes and found 1 potential issue.
Reviewed by Cursor Bugbot for commit f23d68b.
timeout: Annotated[int, Field(description="Rendezvous timeout in seconds (NCCL or NIXL).")] = 1200
backends: Annotated[
    list[str], Field(description="NIXL backends (only used when type='nixl').")
] = ["UCX"]
Removed config field without CHANGELOG update
Medium Severity
The quantize_in_weight_transfer field was removed from SharedWeightBroadcastConfig, trainer NCCLWeightBroadcastConfig, and orchestrator NCCLWeightBroadcastConfig. Since these configs use extra="forbid", any existing user config that sets this field will fail validation at load time. This is a breaking removal of a config field, but CHANGELOG.md has no corresponding entry.
Additional Locations (2)
Triggered by project rule: BugBot Instructions
RDMA based weight transfer for GLM5 FP8
Note
High Risk
Introduces a new RDMA-based weight update mechanism spanning trainer/orchestrator/inference plus custom GPU memory allocation and Triton kernels, which is complex and failure-prone in distributed environments. Misconfiguration or subtle sync/registration issues could stall training or corrupt inference weights.
Overview
Adds a new nixl weight broadcast mode that updates vLLM inference weights via NIXL/UCX RDMA with a two-round StatelessProcessGroup rendezvous, per-push barriers, and orchestrator-driven pause/resume guarded by STABLE/NIXL_READY markers.

On the trainer side this introduces TransportPlan/slot-based pre-registered buffers (including FP8 scale buffers) backed by a new classic cudaMalloc mempool, plus model-driven ConversionSpec tables and a Triton FP8 block-quantization implementation. On the inference side it adds a NIXLWeightUpdateWorker and an admin API to initialize NIXL transfer and participate in the rendezvous.

The PR also simplifies NCCL weight broadcast by removing the optional "quantize in transfer" path, updates configs/validators and SLURM templates to support nixl, and adds a system-contract doc for the new protocol.

Reviewed by Cursor Bugbot for commit f23d68b.