184 changes: 184 additions & 0 deletions docs/nixl-weight-broadcast.md
**Member:** Can we make this a section somewhere? If we don't have one yet, can we make a weight-transfer.md which details all of the transfer methods we support at the moment?

**Collaborator (author):** Yeah, agree. Will probably add this in here.

# NIXL weight broadcast — system contract

What each role promises, what flows between them, and what the system
guarantees after a push.

## Roles

| Role | What it runs | What it owns |
|---|---|---|
| **Orchestrator** | `prime_rl.orchestrator.orchestrator` | Pause/resume of the inference pool, `STABLE`/`NIXL_READY` markers, train-step pacing. |
| **Trainer** | `NIXLWeightBroadcast` → `TransportPlan` | Source of truth for model weights. Decides when a broadcast happens, drives the transfer. |
| **Inference** | `NIXLWeightUpdateWorker` per vLLM worker | Destination buffers. Pauses forward pass during a broadcast, resumes only when the orchestrator allows. |

The transfer runs end-to-end over three channels:

* **SPG** (TCP) — rendezvous, barriers. `trainer_ws + inference_ws`
ranks, established once at trainer init.
* **NIXL / UCX / IB RDMA** — the data path. Trainer posts WRITEs into
pre-registered inference parameter buffers.
* **Filesystem markers** — orchestrator ↔ trainer signaling, one marker
  per direction (`STABLE` from the trainer, `NIXL_READY` from the
  orchestrator).
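The filesystem leg is simple enough to sketch. Below is a minimal touch/poll pair following the `broadcasts/step_N/` layout described later in this doc; the helper names are illustrative, not prime_rl's actual API.

```python
import time
from pathlib import Path


def touch_marker(run_dir: Path, step: int, name: str) -> None:
    """Signal the peer by creating an empty file (e.g. STABLE, NIXL_READY)."""
    marker_dir = run_dir / "broadcasts" / f"step_{step}"
    marker_dir.mkdir(parents=True, exist_ok=True)
    (marker_dir / name).touch()


def wait_marker(run_dir: Path, step: int, name: str, poll_s: float = 0.05) -> None:
    """Block until the peer has touched the marker."""
    marker = run_dir / "broadcasts" / f"step_{step}" / name
    while not marker.exists():
        time.sleep(poll_s)
```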

## Trainer ↔ inference contract

The trainer and inference agree on three things *before the first push*
and never renegotiate:

1. **The slot inventory.** Every trainer-side destination buffer has a
unique `slot_key`. The inference side publishes a descriptor list
under the same key. Expert slots use the destination param name;
non-expert slots use the source-tensor name.
2. **The layout of every non-expert destination.** Trainer ships one
`LayoutEntry(slot_key, inference_name, offset_rows, rows, num_chunks)`
per slot-buffer in SPG round 1. Inference narrows its vLLM tensor
per those coordinates and publishes one serialized xfer dlist per
chunk.
3. **The expert map.** Inference publishes
`{moe_prefix: [global_expert_id, …]}` so the trainer knows which
peers own which global experts. Trainer only writes a local expert
to peers that own it.

Once the write table is built, every broadcast reuses it.
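To make the layout handshake concrete, here is a sketch of the entry plus a hypothetical even-split chunking helper; the real `LayoutEntry` and its chunk policy live in prime_rl and may differ in detail.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LayoutEntry:
    """Field names mirror the contract above; the real class may differ."""
    slot_key: str
    inference_name: str
    offset_rows: int
    rows: int
    num_chunks: int


def chunk_row_ranges(entry: LayoutEntry) -> list[tuple[int, int]]:
    """Split the destination row window into num_chunks contiguous ranges.
    Even split is an assumption here, not the documented policy."""
    base, rem = divmod(entry.rows, entry.num_chunks)
    ranges, start = [], entry.offset_rows
    for i in range(entry.num_chunks):
        size = base + (1 if i < rem else 0)
        ranges.append((start, start + size))
        start += size
    return ranges
```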

### Per-push guarantees (what `push_once` provides)

After `push_once` returns **on every trainer rank** and
`update_weights_from_path` returns **on every inference rank**:

* Every inference-side parameter buffer that the trainer is responsible
for has been overwritten with the current step's weights (after
quantization + dtype cast as declared by the slot's `QuantizationSpec`).
* All RDMA WRITEs have been acknowledged at the remote NIC; no writes
are in flight.
* MLA absorbed weights (`W_UV`, `W_UK_T`) on inference have been
recomputed from the freshly written `kv_b_proj`.

### Per-push non-guarantees

* **No freshness beyond the current step.** If the trainer updates
weights again before the next barrier lands, inference may observe
a mixed snapshot. The orchestrator's pause/resume is what makes this
safe in practice.
* **No delta.** Every push ships the entire registered surface,
regardless of which params changed.
* **No ordering between slots.** Writes are posted in a fixed order but
drained in batches; an inference observer that isn't paused would
see torn writes.

## Orchestrator ↔ trainer ↔ inference contract

Per step, the orchestrator is the one authority that says "it's safe to
overwrite inference weights now" and "you can start serving again."

```
trainer rank 0 orchestrator inference (all ranks)
│ │ │
├── touch STABLE ─────────▶ │
│ ├── /pause ───────────────▶
│ │◀── ack all ─────────────┤
│ ├── touch NIXL_READY │
│◀── see NIXL_READY ──────┤ │
│ │
├───────── dist.barrier() across all trainer ranks ─│
│ │
├─────────── RDMA WRITEs (every rank) ─────────────▶│
│ │
├──────── spg.barrier() across trainer+inference ───│
│ │
│ │◀── /resume ─────────────┤
│ ├── resume ──────────────▶│
```

The contract:

* **Trainer promises:** no rank posts any RDMA WRITE before
`NIXL_READY` is observed. The `dist.barrier()` across all trainer
ranks enforces this — otherwise non-master ranks would race ahead.
* **Orchestrator promises:** once `NIXL_READY` is written, every
inference worker has paused; no forward pass is reading params.
* **Inference promises:** once `update_weights_from_path` enters its
SPG barrier, its params are quiescent and remain quiescent until
both the barrier releases and `/resume` returns.
* **Shared ack:** the final `spg.barrier()` at the end of `push_once`
is the single synchronization point that gates "weights are now in
place" across the 96-rank cluster.
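The gating above can be modeled as events. This is a toy single-process sketch using `threading.Event` in place of the real markers and barriers; it models the ordering, not the implementation.

```python
import threading

nixl_ready = threading.Event()   # stands in for the NIXL_READY marker
writes_done = threading.Event()  # stands in for the end-of-push spg.barrier()
log: list[str] = []


def orchestrator() -> None:
    log.append("pause inference")  # all forward passes quiesce, then ack
    nixl_ready.set()               # touch NIXL_READY
    writes_done.wait()             # wait out the end-of-push barrier
    log.append("resume inference")


def trainer() -> None:
    nixl_ready.wait()              # promise: no WRITE before NIXL_READY
    log.append("rdma writes")      # every rank posts its write table
    writes_done.set()              # all writes acked at the remote NIC


threads = [threading.Thread(target=orchestrator), threading.Thread(target=trainer)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Because the trainer cannot log before `nixl_ready` is set and the orchestrator cannot resume before `writes_done`, the log order is deterministic even though the threads are concurrent.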

## Registration invariants (set once, forever)

These are properties of the pre-registered buffers. Breaking them
causes `NIXL_ERR_INVALID_PARAM` at post time or an mlx5 Local Protection
error at WRITE-landing time — both are debugging sinkholes.

* **One MR per logical buffer on inference.** The full vLLM tensor is
registered once. Per-chunk xfer descriptors resolve to that MR's
rkey at write time. Registering overlapping per-chunk MRs trips
mlx5 rkey lookup.
* **Trainer slots live in the classic cudaMalloc pool.** Not in
PyTorch's VMM `expandable_segments` pool — `nvidia_peermem` refuses
cuMemMap-backed VA. Managed by `classic_cuda_alloc()` context.
* **NIC pinning is per-GPU.** Every trainer agent uses its GPU's
PIX-attached NIC via `pin_ucx_rail(local_rank)`. Without this, inference
decode's pre-set `UCX_NET_DEVICES=mlx5_0:1` (from vLLM's PD KV
connector) would serialize every weight write through one NIC per
decode node.
* **Chunk selection is prep-time, not post-time.** Each `remote_prep`
is built from exactly one serialized dlist entry. `post_write` uses
`remote_idx=0` always; the chunk is already encoded in the prep
itself.
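For illustration, a sketch of what NIC pinning amounts to. The GPU-to-NIC table below is made up; the real `pin_ucx_rail` derives each GPU's PIX-attached device from the PCIe topology.

```python
import os

# Hypothetical mapping for a 4-GPU node; the real implementation reads
# the PCIe topology to find each GPU's PIX-attached NIC.
_GPU_TO_NIC = {0: "mlx5_0:1", 1: "mlx5_1:1", 2: "mlx5_2:1", 3: "mlx5_3:1"}


def pin_ucx_rail(local_rank: int) -> str:
    """Pin UCX to this GPU's rail before any NIXL agent is created,
    overriding any UCX_NET_DEVICES inherited from the environment."""
    nic = _GPU_TO_NIC[local_rank % len(_GPU_TO_NIC)]
    os.environ["UCX_NET_DEVICES"] = nic
    return nic
```

The ordering matters: the env var must be set before the agent's UCX context is created, which is why creating `NixlAgentWrapper` first breaks the contract.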

## What flows on the wire

### SPG control plane (rendezvous, once)

Round 1: layout only — trainer ships `list[LayoutEntry]`, inference
ships `expert_map`. Agent metadata is deferred so round-2 metadata
covers every chunk MR.

Round 2: agent metadata + inference's `descriptors` (per-slot_key
lists of serialized chunk dlists) + `expert_map` again.
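Illustrative shapes of the two payloads; keys and values here are made up, since the real objects are serialized dataclasses and NIXL dlists.

```python
# Round 1: layout only — no agent metadata yet.
round_1 = {
    "trainer": {"layout": ["LayoutEntry(slot_key=..., ...)"]},  # list[LayoutEntry]
    "inference": {"expert_map": {"model.layers.0.mlp.experts": [0, 4, 8]}},
}

# Round 2: agent metadata now covers every chunk MR registered after round 1.
round_2 = {
    "inference": {
        "agent_metadata": b"<nixl agent blob>",
        "descriptors": {"some_slot_key": [b"<dlist chunk 0>", b"<dlist chunk 1>"]},
        "expert_map": round_1["inference"]["expert_map"],  # shipped again
    },
}
```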

### NIXL data plane (every push)

One RDMA WRITE per `(local_slot_chunk, inference_peer_chunk)` pair.
Write table size is fixed at rendezvous; per-push the only thing that
changes is the bytes.

### SPG control plane (every push)

Exactly one barrier at the end of `push_once`, joined by all trainer
and inference ranks.

### Filesystem (every push)

One `STABLE` touched by trainer rank 0, one `NIXL_READY` touched by
the orchestrator, under `broadcasts/step_N/` in the run's output dir.

## Who can break the contract

* **Changing a `ConversionSpec` between runs** (dtype, sources, cat_dim)
without rebuilding the write table on both sides — the slot inventory
and layout no longer match.
* **Allocating slots outside `classic_cuda_alloc()`.**
* **Creating the trainer's `NixlAgentWrapper` before `pin_ucx_rail`.**
* **Posting WRITEs before `NIXL_READY` / before the trainer-side
`dist.barrier()`** — races against live forward passes.
* **Skipping the end-of-push `spg.barrier()`.** Orchestrator will
`/resume` inference before some peers have acked their writes.
* **Registering the same inference tensor twice.** Overlapping MRs
are what `makeXferReq` refuses with LOCAL_PROTECTION.

## Scope boundary

Not part of the contract:

* How `ConversionSpec` is constructed from model code — that's the
model's business (`conversion_specs(layer_idx)` hook).
* Which UCX backends / transports are selected — `NixlAgentWrapper`
picks them based on env vars set by `pin_ucx_rail`.
* How FSDP / EP / CP meshes are built — `ParallelDims` is handed to
`TransportPlan`, the plan reads the mesh but does not shape it.
* How inference's `expert_map` is computed — `build_expert_map`
reads it off the vLLM MoE modules.
* Orchestrator pause/resume internals — the trainer-side code only
waits for the `NIXL_READY` marker.
1 change: 0 additions & 1 deletion examples/glm5_pd_disag/rl.toml
```diff
@@ -13,7 +13,6 @@ name = "zai-org/GLM-5"
 type = "nccl"
 port = 29502
 timeout = 12000
-quantize_in_weight_transfer = true

 [deployment]
 type = "multi_node"
```
6 changes: 3 additions & 3 deletions src/prime_rl/configs/inference.py
```diff
@@ -105,9 +105,9 @@ class ModelConfig(BaseModelConfig):
 class WeightBroadcastConfig(BaseConfig):
     """Configures weight broadcast settings."""

-    type: Annotated[Literal["nccl", "filesystem"], Field(description="The type of weight broadcast to use.")] = (
-        "filesystem"
-    )
+    type: Annotated[
+        Literal["nccl", "filesystem", "nixl"], Field(description="The type of weight broadcast to use.")
+    ] = "filesystem"


 # Valid vLLM max_lora_rank values (from vllm/config/lora.py)
```
36 changes: 29 additions & 7 deletions src/prime_rl/configs/orchestrator.py
```diff
@@ -790,10 +790,6 @@ class NCCLWeightBroadcastConfig(BaseModel):
     host: Annotated[str, Field(description="The host to use for the NCCL broadcast.")] = "localhost"
     port: Annotated[int, Field(description="The port to use for the NCCL broadcast.")] = 29501
     timeout: Annotated[int, Field(description="The timeout in seconds to use for the NCCL broadcast.")] = 1200
-    quantize_in_weight_transfer: Annotated[
-        bool,
-        Field(description="Use kernel-format FP8 quantized NCCL transfer for weight updates."),
-    ] = False

     inference_world_size: Annotated[
         int,
@@ -804,8 +800,34 @@
     ] = 1


+class NIXLWeightBroadcastConfig(BaseModel):
+    """Configures the NIXL weight transfer.
+
+    FP8 kernel-format transfer is always used for NIXL.
+    """
+
+    type: Literal["nixl"] = "nixl"
+
+    host: Annotated[str, Field(description="Rendezvous host used by StatelessProcessGroup.")] = "localhost"
+    port: Annotated[int, Field(description="Rendezvous port used by StatelessProcessGroup.")] = 29502
+    timeout: Annotated[int, Field(description="Rendezvous timeout in seconds.")] = 1200
+
+    trainer_world_size: Annotated[
+        int,
+        Field(ge=1, description="Total number of trainer ranks (FSDP × EP)."),
+    ] = 1
+
+    inference_world_size: Annotated[
+        int,
+        Field(ge=1, description="Total number of inference GPUs across all servers."),
+    ] = 1
+
+    backends: Annotated[list[str], Field(description="NIXL backends to initialize the agent with.")] = ["UCX"]
+
+
 WeightBroadcastConfig: TypeAlias = Annotated[
-    FileSystemWeightBroadcastConfig | NCCLWeightBroadcastConfig, Field(discriminator="type")
+    FileSystemWeightBroadcastConfig | NCCLWeightBroadcastConfig | NIXLWeightBroadcastConfig,
+    Field(discriminator="type"),
 ]
@@ -1097,9 +1119,9 @@ def validate_unique_filter_types(self):

     @model_validator(mode="after")
     def nccl_max_async_level(self):
-        if self.weight_broadcast.type == "nccl":
+        if self.weight_broadcast.type in ("nccl", "nixl"):
             if not self.max_async_level == 1:
-                raise ValueError("max_async_level must be 1 for NCCL broadcast")
+                raise ValueError(f"max_async_level must be 1 for {self.weight_broadcast.type} broadcast")
         return self

     @model_validator(mode="after")
```