Commits (47)
4a7a6dd
Initial
S1ro1 Apr 19, 2026
f270cac
Feat: Cleanup
S1ro1 Apr 19, 2026
cd3a565
Clean up GLM MoE DSA converter + NIXL broadcast
S1ro1 Apr 19, 2026
612429f
Feat: some cleanup
S1ro1 Apr 19, 2026
bec06a0
Feat: cleanup more
S1ro1 Apr 19, 2026
459f19f
wtf did claude cook
S1ro1 Apr 19, 2026
690dc4a
Feat: NIXL broadcast working end-to-end on GLM-5.1 (12-node disagg)
S1ro1 Apr 19, 2026
0d49320
Feat: hard-override UCX_NET_DEVICES in pin_ucx_rail
S1ro1 Apr 19, 2026
18b39fe
Feat: NIXL weight transfer now works with expandable_segments=True
S1ro1 Apr 20, 2026
5ea1051
Feat: ConversionSpec + QuantizationSpec, doc, fix tilelang preload
S1ro1 Apr 20, 2026
ea791f8
Feat: TransportPlan + Slot refactor, drop FP8 NCCL quantize path
S1ro1 Apr 20, 2026
90c4dc4
Docs: NIXL architecture contract + drop stale fixtures
S1ro1 Apr 20, 2026
e78fa10
Docs: rewrite nixl-weight-broadcast.md as a system contract
S1ro1 Apr 20, 2026
ed71964
Docs: drop nixl-architecture.md, superseded by contract rewrite
S1ro1 Apr 20, 2026
3a47826
Fix: typo
S1ro1 Apr 20, 2026
4369d21
Feat: HSDP support (primary-replica push) + EP partition assertion
S1ro1 Apr 20, 2026
f23d68b
Fix: FP8 scale floor back to 1e-12 to match pre-Triton parity
S1ro1 Apr 20, 2026
81be8e7
Doc: KL mismatch investigation scratchpad
S1ro1 Apr 21, 2026
4df14cc
Exp iter2: add end-to-end signature diagnostic for anchor slot
S1ro1 Apr 21, 2026
2740cd3
Exp iter3: expand SIG diagnostic to FP8 gather + expert anchors
S1ro1 Apr 21, 2026
a1bcc7d
Exp iter4: inference SIG lookup checks both param + buffer dicts
S1ro1 Apr 21, 2026
8f41149
Exp iter5 (doc): disable DeepGemm to test layout-mismatch hypothesis
S1ro1 Apr 21, 2026
9cd8541
Exp iter6: SIG now logs shape+stride on both sides
S1ro1 Apr 21, 2026
da5e072
Exp iter7: fused-region sum check + multiple expert anchors
S1ro1 Apr 21, 2026
60a78f5
Exp iter8: transport non-layer tensors (embed, norm, lm_head)
S1ro1 Apr 21, 2026
cd9ff66
Exp iter9: untracked-keys diagnostic for missing slots
S1ro1 Apr 21, 2026
37fc774
Exp iter10: cuda.synchronize on inference after SPG barrier
S1ro1 Apr 21, 2026
d341bd6
Exp iter11 (doc): enforce_eager=true on inference
S1ro1 Apr 21, 2026
0813d85
Exp iter12: verify N anchors (embed/norm/lm_head) transport
S1ro1 Apr 21, 2026
d6cca80
Exp iter13: precise ShardedSlot verification via head[:2420] sum
S1ro1 Apr 21, 2026
0af021b
Exp iter14 (nixl side): flush_every=1 (per-write drain)
S1ro1 Apr 21, 2026
71d24b0
Exp iter14 (doc): maximum conservatism — stack all knobs
S1ro1 Apr 21, 2026
68dcfb4
Investigation wrap-up: exhausted surface-level NIXL hypotheses
S1ro1 Apr 21, 2026
c494158
Exp iter15: pre-write SPG barrier + inference cuda.sync before it
S1ro1 Apr 21, 2026
3e53fa6
Exp iter16: byte-level trainer/inference dump + diff tool
S1ro1 Apr 21, 2026
658f3cc
Exp iter17: pause clear_cache=true — test KV cache staleness theory
S1ro1 Apr 21, 2026
9035914
Exp iter18: swap Triton FP8 quantize for main's PyTorch impl
S1ro1 Apr 21, 2026
6a6a23f
Exp iter19: abort in-flight requests on pause
S1ro1 Apr 21, 2026
b29bae3
Revert "Exp iter19: abort in-flight requests on pause"
S1ro1 Apr 21, 2026
94edaf7
Exp iter19: flush GPUDirect RDMA writes on inference
S1ro1 Apr 21, 2026
a2f81ab
Exp iter20: per-write drain with GPUDirect flush
S1ro1 Apr 21, 2026
b053f67
Revert "Exp iter20: per-write drain with GPUDirect flush"
S1ro1 Apr 21, 2026
121782b
Exp iter21: enable sync memops on NIXL buffers
S1ro1 Apr 21, 2026
6f4a685
Exp iter22-27 (squash): freeze_{experts,non_experts} + transfer_mode …
S1ro1 Apr 21, 2026
1c9fe0c
Doc: wrap-up — iter22-27 summary, iter26/27 W&B data, bug narrowed to…
S1ro1 Apr 21, 2026
94b6ad6
Doc: rule out inference non-determinism
S1ro1 Apr 21, 2026
94c14f4
Remove tools/inference_dashboard from tracking
S1ro1 Apr 21, 2026
184 changes: 184 additions & 0 deletions docs/nixl-weight-broadcast.md
**Member** commented:
can we make this a section somewhere? if we don't have yet, can we make a weight-transfer.md which details all of the transfer methods we support atm

**Collaborator (author)** replied:
Yeah agree, will probably add this in here

@@ -0,0 +1,184 @@
# NIXL weight broadcast — system contract

What each role promises, what flows between them, and what the system
guarantees after a push.

## Roles

| Role | What it runs | What it owns |
|---|---|---|
| **Orchestrator** | `prime_rl.orchestrator.orchestrator` | Pause/resume of the inference pool, `STABLE`/`NIXL_READY` markers, train-step pacing. |
| **Trainer** | `NIXLWeightBroadcast` → `TransportPlan` | Source of truth for model weights. Decides when a broadcast happens, drives the transfer. |
| **Inference** | `NIXLWeightUpdateWorker` per vLLM worker | Destination buffers. Pauses forward pass during a broadcast, resumes only when the orchestrator allows. |

The transfer happens end-to-end over:

* **SPG** (TCP) — rendezvous, barriers. `trainer_ws + inference_ws`
ranks, established once at trainer init.
* **NIXL / UCX / IB RDMA** — the data path. Trainer posts WRITEs into
pre-registered inference parameter buffers.
* **Filesystem markers** — orchestrator ↔ trainer signaling via
  write-once marker files (`STABLE`, `NIXL_READY`).

## Trainer ↔ inference contract

The trainer and inference agree on three things *before the first push*
and never renegotiate:

1. **The slot inventory.** Every trainer-side destination buffer has a
unique `slot_key`. The inference side publishes a descriptor list
under the same key. Expert slots use the destination param name;
non-expert slots use the source-tensor name.
2. **The layout of every non-expert destination.** Trainer ships one
`LayoutEntry(slot_key, inference_name, offset_rows, rows, num_chunks)`
per slot-buffer in SPG round 1. Inference narrows its vLLM tensor
per those coordinates and publishes one serialized xfer dlist per
chunk.
3. **The expert map.** Inference publishes
`{moe_prefix: [global_expert_id, …]}` so the trainer knows which
peers own which global experts. Trainer only writes a local expert
to peers that own it.

Once the write table is built, every broadcast reuses it.
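
The slot inventory can be pictured as a keyed index. `LayoutEntry`'s fields are the ones named above; `build_descriptor_index` is a hypothetical helper showing why `slot_key` uniqueness matters:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LayoutEntry:
    slot_key: str        # unique key both sides agree on
    inference_name: str  # vLLM parameter the slot writes into
    offset_rows: int     # row offset of the slot region in that tensor
    rows: int            # number of rows the slot covers
    num_chunks: int      # chunks the inference side splits the region into


def build_descriptor_index(entries: list[LayoutEntry]) -> dict[str, LayoutEntry]:
    """Key everything by slot_key; a duplicate key breaks the contract."""
    index: dict[str, LayoutEntry] = {}
    for entry in entries:
        if entry.slot_key in index:
            raise ValueError(f"duplicate slot_key {entry.slot_key!r}")
        index[entry.slot_key] = entry
    return index
```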

### Per-push guarantees (what `push_once` provides)

After `push_once` returns **on every trainer rank** and
`update_weights_from_path` returns **on every inference rank**:

* Every inference-side parameter buffer that the trainer is responsible
for has been overwritten with the current step's weights (after
quantization + dtype cast as declared by the slot's `QuantizationSpec`).
* All RDMA WRITEs have been acknowledged at the remote NIC; no writes
are in flight.
* MLA absorbed weights (`W_UV`, `W_UK_T`) on inference have been
recomputed from the freshly written `kv_b_proj`.

### Per-push non-guarantees

* **No freshness beyond the current step.** If the trainer updates
weights again before the next barrier lands, inference may observe
a mixed snapshot. The orchestrator's pause/resume is what makes this
safe in practice.
* **No delta.** Every push ships the entire registered surface,
regardless of which params changed.
* **No ordering between slots.** Writes are posted in a fixed order but
drained in batches; an inference observer that isn't paused would
see torn writes.
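
The scale-floor behavior behind the slot's `QuantizationSpec` can be sketched in pure Python; the real path runs a Triton FP8 E4M3 kernel, so this is only the shape of the idea:

```python
FP8_E4M3_MAX = 448.0  # max representable magnitude in float8_e4m3fn
SCALE_FLOOR = 1e-12   # floor keeps an all-zero tensor from producing scale = 0


def quantize(values: list[float]) -> tuple[list[float], float]:
    """Symmetric per-tensor quantization with a floored scale."""
    amax = max((abs(v) for v in values), default=0.0)
    scale = max(amax / FP8_E4M3_MAX, SCALE_FLOOR)
    quantized = [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v / scale)) for v in values]
    return quantized, scale


def dequantize(quantized: list[float], scale: float) -> list[float]:
    return [v * scale for v in quantized]
```

Without the floor, a zero tensor would yield `scale = 0` and a division by zero on the next push.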

## Orchestrator ↔ trainer ↔ inference contract

Per step, the orchestrator is the one authority that says "it's safe to
overwrite inference weights now" and "you can start serving again."

```
trainer rank 0 orchestrator inference (all ranks)
│ │ │
├── touch STABLE ─────────▶ │
│ ├── /pause ───────────────▶
│ │◀── ack all ─────────────┤
│ ├── touch NIXL_READY │
│◀── see NIXL_READY ──────┤ │
│ │
├───────── dist.barrier() across all trainer ranks ─│
│ │
├─────────── RDMA WRITEs (every rank) ─────────────▶│
│ │
├──────── spg.barrier() across trainer+inference ───│
│ │
│ │◀── /resume ─────────────┤
│ ├── resume ──────────────▶│
```

The contract:

* **Trainer promises:** no rank posts any RDMA WRITE before
`NIXL_READY` is observed. The `dist.barrier()` across all trainer
ranks enforces this — otherwise non-master ranks would race ahead.
* **Orchestrator promises:** once `NIXL_READY` is written, every
inference worker has paused; no forward pass is reading params.
* **Inference promises:** once `update_weights_from_path` enters its
SPG barrier, its params are quiescent and remain quiescent until
both the barrier releases and `/resume` returns.
* **Shared ack:** the final `spg.barrier()` at the end of `push_once`
is the single synchronization point that gates "weights are now in
place" across the 96-rank cluster.
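
The ordering above can be simulated with threads standing in for ranks; everything here is a sketch under stated assumptions, not the real code:

```python
import threading


def run_push(n_ranks: int, log: list) -> None:
    """Simulate one push: marker gate, trainer barrier, writes, shared barrier."""
    nixl_ready = threading.Event()                   # stands in for the NIXL_READY marker
    trainer_barrier = threading.Barrier(n_ranks)     # dist.barrier() across trainer ranks
    shared_barrier = threading.Barrier(n_ranks + 1)  # spg.barrier() incl. inference
    lock = threading.Lock()

    def trainer(rank: int) -> None:
        if rank == 0:
            nixl_ready.wait()            # only rank 0 watches the marker...
        trainer_barrier.wait()           # ...this barrier holds the other ranks back
        with lock:
            log.append(("write", rank))  # RDMA WRITEs happen here
        shared_barrier.wait()

    def inference() -> None:
        nixl_ready.set()                 # orchestrator has paused the pool
        shared_barrier.wait()            # joins the end-of-push barrier
        log.append(("resume",))          # serving resumes only after every write

    threads = [threading.Thread(target=trainer, args=(r,)) for r in range(n_ranks)]
    threads.append(threading.Thread(target=inference))
    for t in threads:
        t.start()
    for t in threads:
        t.join()
```

Every write lands before the shared barrier and `resume` only after it, so `resume` is always last, which is exactly the per-push guarantee.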

## Registration invariants (set once, forever)

These are properties of the pre-registered buffers. Breaking them
causes `NIXL_ERR_INVALID_PARAM` at post time or an mlx5 Local Protection
error at WRITE-landing time — both are debugging sinkholes.

* **One MR per logical buffer on inference.** The full vLLM tensor is
registered once. Per-chunk xfer descriptors resolve to that MR's
rkey at write time. Registering overlapping per-chunk MRs trips
mlx5 rkey lookup.
* **Trainer slots live in the classic cudaMalloc pool.** Not in
PyTorch's VMM `expandable_segments` pool — `nvidia_peermem` refuses
cuMemMap-backed VA. Managed by `classic_cuda_alloc()` context.
* **NIC pinning is per-GPU.** Every trainer agent uses its GPU's
PIX-attached NIC via `pin_ucx_rail(local_rank)`. Without this, inference
decode's pre-set `UCX_NET_DEVICES=mlx5_0:1` (from vLLM's PD KV
connector) would serialize every weight write through one NIC per
decode node.
* **Chunk selection is prep-time, not post-time.** Each `remote_prep`
is built from exactly one serialized dlist entry. `post_write` uses
`remote_idx=0` always; the chunk is already encoded in the prep
itself.
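
A sketch of the NIC-pinning hard override; the rail map and device names are hypothetical, the real mapping comes from the node's PCIe topology:

```python
import os

# Hypothetical per-node map of local GPU rank -> PIX-attached NIC rail.
RAIL_MAP = {0: "mlx5_0:1", 1: "mlx5_1:1", 2: "mlx5_4:1", 3: "mlx5_5:1"}


def pin_ucx_rail(local_rank: int) -> str:
    """Hard-override UCX_NET_DEVICES before the NIXL agent is created."""
    device = RAIL_MAP[local_rank % len(RAIL_MAP)]
    # vLLM's PD KV connector may have pre-set a single NIC here; respecting
    # it would serialize every weight write through that one rail.
    os.environ["UCX_NET_DEVICES"] = device
    return device
```

Ordering matters: this must run before `NixlAgentWrapper` is constructed, since UCX reads the env var at agent init.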

## What flows on the wire

### SPG control plane (rendezvous, once)

Round 1: layout only — trainer ships `list[LayoutEntry]`, inference
ships `expert_map`. Agent metadata is deferred so round-2 metadata
covers every chunk MR.

Round 2: agent metadata + inference's `descriptors` (per-slot_key
lists of serialized chunk dlists) + `expert_map` again.
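
Illustrative shapes for the two rounds' payloads; field names beyond `LayoutEntry` fields, `descriptors`, `expert_map`, and the agent metadata blob are assumptions, and the byte values are placeholders:

```python
# Round 1: layout only, no agent metadata yet.
round1_trainer = {
    "layout": [  # list[LayoutEntry], one per slot buffer
        {"slot_key": "model.layers.0.mlp.gate_proj", "inference_name": "...",
         "offset_rows": 0, "rows": 4096, "num_chunks": 2},
    ],
}
round1_inference = {
    "expert_map": {"model.layers.3.mlp.experts": [0, 1, 8, 9]},
}

# Round 2: agent metadata is deferred to here so it covers every chunk MR.
round2_inference = {
    "agent_metadata": b"...",  # serialized NIXL agent blob
    "descriptors": {           # per-slot_key list of serialized chunk dlists
        "model.layers.0.mlp.gate_proj": [b"...", b"..."],
    },
    "expert_map": {"model.layers.3.mlp.experts": [0, 1, 8, 9]},
}
```

The invariant worth checking is that each slot's `num_chunks` matches the number of serialized dlists published for that `slot_key`.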

### NIXL data plane (every push)

One RDMA WRITE per `(local_slot_chunk, inference_peer_chunk)` pair.
Write table size is fixed at rendezvous; per-push the only thing that
changes is the bytes.

### SPG control plane (every push)

Exactly one barrier at the end of `push_once`, joined by all trainer
and inference ranks.

### Filesystem (every push)

One `STABLE` touched by trainer rank 0, one `NIXL_READY` touched by
the orchestrator, under `broadcasts/step_N/` in the run's output dir.

## Who can break the contract

* **Changing a `ConversionSpec` between runs** (dtype, sources, cat_dim)
without rebuilding the write table on both sides — the slot inventory
and layout no longer match.
* **Allocating slots outside `classic_cuda_alloc()`.**
* **Creating the trainer's `NixlAgentWrapper` before `pin_ucx_rail`.**
* **Posting WRITEs before `NIXL_READY` / before the trainer-side
`dist.barrier()`** — races against live forward passes.
* **Skipping the end-of-push `spg.barrier()`.** Orchestrator will
`/resume` inference before some peers have acked their writes.
* **Registering the same inference tensor twice.** Overlapping MRs are
  what `makeXferReq` rejects with LOCAL_PROTECTION.

## Scope boundary

Not part of the contract:

* How `ConversionSpec` is constructed from model code — that's the
model's business (`conversion_specs(layer_idx)` hook).
* Which UCX backends / transports are selected — `NixlAgentWrapper`
picks them based on env vars set by `pin_ucx_rail`.
* How FSDP / EP / CP meshes are built — `ParallelDims` is handed to
`TransportPlan`, the plan reads the mesh but does not shape it.
* How inference's `expert_map` is computed — `build_expert_map`
reads it off the vLLM MoE modules.
* Orchestrator pause/resume internals — the trainer-side code only
waits for the `NIXL_READY` marker.
1 change: 0 additions & 1 deletion examples/glm5_pd_disag/rl.toml
```diff
@@ -13,7 +13,6 @@ name = "zai-org/GLM-5"
 type = "nccl"
 port = 29502
 timeout = 12000
-quantize_in_weight_transfer = true
 
 [deployment]
 type = "multi_node"
```
6 changes: 3 additions & 3 deletions src/prime_rl/configs/inference.py
```diff
@@ -105,9 +105,9 @@ class ModelConfig(BaseModelConfig):
 class WeightBroadcastConfig(BaseConfig):
     """Configures weight broadcast settings."""
 
-    type: Annotated[Literal["nccl", "filesystem"], Field(description="The type of weight broadcast to use.")] = (
-        "filesystem"
-    )
+    type: Annotated[
+        Literal["nccl", "filesystem", "nixl"], Field(description="The type of weight broadcast to use.")
+    ] = "filesystem"
 
 
 # Valid vLLM max_lora_rank values (from vllm/config/lora.py)
```
36 changes: 29 additions & 7 deletions src/prime_rl/configs/orchestrator.py
```diff
@@ -790,10 +790,6 @@ class NCCLWeightBroadcastConfig(BaseModel):
     host: Annotated[str, Field(description="The host to use for the NCCL broadcast.")] = "localhost"
     port: Annotated[int, Field(description="The port to use for the NCCL broadcast.")] = 29501
     timeout: Annotated[int, Field(description="The timeout in seconds to use for the NCCL broadcast.")] = 1200
-    quantize_in_weight_transfer: Annotated[
-        bool,
-        Field(description="Use kernel-format FP8 quantized NCCL transfer for weight updates."),
-    ] = False
 
     inference_world_size: Annotated[
         int,
@@ -804,8 +800,34 @@
     ] = 1
 
 
+class NIXLWeightBroadcastConfig(BaseModel):
+    """Configures the NIXL weight transfer.
+
+    FP8 kernel-format transfer is always used for NIXL.
+    """
+
+    type: Literal["nixl"] = "nixl"
+
+    host: Annotated[str, Field(description="Rendezvous host used by StatelessProcessGroup.")] = "localhost"
+    port: Annotated[int, Field(description="Rendezvous port used by StatelessProcessGroup.")] = 29502
+    timeout: Annotated[int, Field(description="Rendezvous timeout in seconds.")] = 1200
+
+    trainer_world_size: Annotated[
+        int,
+        Field(ge=1, description="Total number of trainer ranks (FSDP × EP)."),
+    ] = 1
+
+    inference_world_size: Annotated[
+        int,
+        Field(ge=1, description="Total number of inference GPUs across all servers."),
+    ] = 1
+
+    backends: Annotated[list[str], Field(description="NIXL backends to initialize the agent with.")] = ["UCX"]
+
+
 WeightBroadcastConfig: TypeAlias = Annotated[
-    FileSystemWeightBroadcastConfig | NCCLWeightBroadcastConfig, Field(discriminator="type")
+    FileSystemWeightBroadcastConfig | NCCLWeightBroadcastConfig | NIXLWeightBroadcastConfig,
+    Field(discriminator="type"),
 ]
```
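
A hedged example of selecting the NIXL path from an orchestrator TOML; the section name is assumed, the fields mirror `NIXLWeightBroadcastConfig`:

```toml
# Hypothetical config fragment; adjust section name and sizes to your run.
[weight_broadcast]
type = "nixl"
host = "trainer-0"
port = 29502
timeout = 1200
trainer_world_size = 32
inference_world_size = 64
backends = ["UCX"]
```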


```diff
@@ -1097,9 +1119,9 @@ def validate_unique_filter_types(self):
 
     @model_validator(mode="after")
     def nccl_max_async_level(self):
-        if self.weight_broadcast.type == "nccl":
+        if self.weight_broadcast.type in ("nccl", "nixl"):
             if not self.max_async_level == 1:
-                raise ValueError("max_async_level must be 1 for NCCL broadcast")
+                raise ValueError(f"max_async_level must be 1 for {self.weight_broadcast.type} broadcast")
         return self
 
     @model_validator(mode="after")
```