Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
4a7a6dd
Initial
S1ro1 Apr 19, 2026
f270cac
Feat: Cleanup
S1ro1 Apr 19, 2026
cd3a565
Clean up GLM MoE DSA converter + NIXL broadcast
S1ro1 Apr 19, 2026
612429f
Feat: some cleanup
S1ro1 Apr 19, 2026
bec06a0
Feat: cleanup more
S1ro1 Apr 19, 2026
459f19f
wtf did claude cook
S1ro1 Apr 19, 2026
690dc4a
Feat: NIXL broadcast working end-to-end on GLM-5.1 (12-node disagg)
S1ro1 Apr 19, 2026
0d49320
Feat: hard-override UCX_NET_DEVICES in pin_ucx_rail
S1ro1 Apr 19, 2026
18b39fe
Feat: NIXL weight transfer now works with expandable_segments=True
S1ro1 Apr 20, 2026
5ea1051
Feat: ConversionSpec + QuantizationSpec, doc, fix tilelang preload
S1ro1 Apr 20, 2026
ea791f8
Feat: TransportPlan + Slot refactor, drop FP8 NCCL quantize path
S1ro1 Apr 20, 2026
90c4dc4
Docs: NIXL architecture contract + drop stale fixtures
S1ro1 Apr 20, 2026
e78fa10
Docs: rewrite nixl-weight-broadcast.md as a system contract
S1ro1 Apr 20, 2026
ed71964
Docs: drop nixl-architecture.md, superseded by contract rewrite
S1ro1 Apr 20, 2026
3a47826
Fix: typo
S1ro1 Apr 20, 2026
4369d21
Feat: HSDP support (primary-replica push) + EP partition assertion
S1ro1 Apr 20, 2026
f23d68b
Fix: FP8 scale floor back to 1e-12 to match pre-Triton parity
S1ro1 Apr 20, 2026
81be8e7
Doc: KL mismatch investigation scratchpad
S1ro1 Apr 21, 2026
4df14cc
Exp iter2: add end-to-end signature diagnostic for anchor slot
S1ro1 Apr 21, 2026
2740cd3
Exp iter3: expand SIG diagnostic to FP8 gather + expert anchors
S1ro1 Apr 21, 2026
a1bcc7d
Exp iter4: inference SIG lookup checks both param + buffer dicts
S1ro1 Apr 21, 2026
8f41149
Exp iter5 (doc): disable DeepGemm to test layout-mismatch hypothesis
S1ro1 Apr 21, 2026
9cd8541
Exp iter6: SIG now logs shape+stride on both sides
S1ro1 Apr 21, 2026
da5e072
Exp iter7: fused-region sum check + multiple expert anchors
S1ro1 Apr 21, 2026
60a78f5
Exp iter8: transport non-layer tensors (embed, norm, lm_head)
S1ro1 Apr 21, 2026
cd9ff66
Exp iter9: untracked-keys diagnostic for missing slots
S1ro1 Apr 21, 2026
37fc774
Exp iter10: cuda.synchronize on inference after SPG barrier
S1ro1 Apr 21, 2026
d341bd6
Exp iter11 (doc): enforce_eager=true on inference
S1ro1 Apr 21, 2026
0813d85
Exp iter12: verify N anchors (embed/norm/lm_head) transport
S1ro1 Apr 21, 2026
d6cca80
Exp iter13: precise ShardedSlot verification via head[:2420] sum
S1ro1 Apr 21, 2026
0af021b
Exp iter14 (nixl side): flush_every=1 (per-write drain)
S1ro1 Apr 21, 2026
71d24b0
Exp iter14 (doc): maximum conservatism — stack all knobs
S1ro1 Apr 21, 2026
68dcfb4
Investigation wrap-up: exhausted surface-level NIXL hypotheses
S1ro1 Apr 21, 2026
c494158
Exp iter15: pre-write SPG barrier + inference cuda.sync before it
S1ro1 Apr 21, 2026
3e53fa6
Exp iter16: byte-level trainer/inference dump + diff tool
S1ro1 Apr 21, 2026
658f3cc
Exp iter17: pause clear_cache=true — test KV cache staleness theory
S1ro1 Apr 21, 2026
9035914
Exp iter18: swap Triton FP8 quantize for main's PyTorch impl
S1ro1 Apr 21, 2026
6a6a23f
Exp iter19: abort in-flight requests on pause
S1ro1 Apr 21, 2026
b29bae3
Revert "Exp iter19: abort in-flight requests on pause"
S1ro1 Apr 21, 2026
94edaf7
Exp iter19: flush GPUDirect RDMA writes on inference
S1ro1 Apr 21, 2026
a2f81ab
Exp iter20: per-write drain with GPUDirect flush
S1ro1 Apr 21, 2026
b053f67
Revert "Exp iter20: per-write drain with GPUDirect flush"
S1ro1 Apr 21, 2026
121782b
Exp iter21: enable sync memops on NIXL buffers
S1ro1 Apr 21, 2026
6f4a685
Exp iter22-27 (squash): freeze_{experts,non_experts} + transfer_mode …
S1ro1 Apr 21, 2026
1c9fe0c
Doc: wrap-up — iter22-27 summary, iter26/27 W&B data, bug narrowed to…
S1ro1 Apr 21, 2026
94b6ad6
Doc: rule out inference non-determinism
S1ro1 Apr 21, 2026
94c14f4
Remove tools/inference_dashboard from tracking
S1ro1 Apr 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/prime_rl/configs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ class ModelConfig(BaseModelConfig):
class WeightBroadcastConfig(BaseConfig):
    """Configures weight broadcast settings."""

    # "nixl" selects the UCX/RDMA zero-copy transfer path; "nccl" and
    # "filesystem" keep their previous semantics.
    type: Annotated[
        Literal["nccl", "filesystem", "nixl"], Field(description="The type of weight broadcast to use.")
    ] = "filesystem"


# Valid vLLM max_lora_rank values (from vllm/config/lora.py)
Expand Down
32 changes: 29 additions & 3 deletions src/prime_rl/configs/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,8 +804,34 @@ class NCCLWeightBroadcastConfig(BaseModel):
] = 1


class NIXLWeightBroadcastConfig(BaseModel):
    """Configures the NIXL weight transfer.

    FP8 kernel-format transfer is always used for NIXL.
    """

    # Discriminator value for the WeightBroadcastConfig tagged union.
    type: Literal["nixl"] = "nixl"

    # Rendezvous endpoint shared by all trainer and inference ranks.
    host: Annotated[str, Field(description="Rendezvous host used by StatelessProcessGroup.")] = "localhost"
    port: Annotated[int, Field(description="Rendezvous port used by StatelessProcessGroup.")] = 29502
    timeout: Annotated[int, Field(description="Rendezvous timeout in seconds.")] = 1200

    # Global ranks are laid out trainers-first, then inference GPUs (the
    # inference worker computes trainer_world_size + rank_offset + local_rank).
    trainer_world_size: Annotated[
        int,
        Field(ge=1, description="Total number of trainer ranks (FSDP × EP)."),
    ] = 1

    inference_world_size: Annotated[
        int,
        Field(ge=1, description="Total number of inference GPUs across all servers."),
    ] = 1

    backends: Annotated[list[str], Field(description="NIXL backends to initialize the agent with.")] = ["UCX"]


# Tagged union over all weight-broadcast backends, discriminated on `type`.
WeightBroadcastConfig: TypeAlias = Annotated[
    FileSystemWeightBroadcastConfig | NCCLWeightBroadcastConfig | NIXLWeightBroadcastConfig,
    Field(discriminator="type"),
]


Expand Down Expand Up @@ -1097,9 +1123,9 @@ def validate_unique_filter_types(self):

@model_validator(mode="after")
def nccl_max_async_level(self):
    """Reject async pipelining for direct-memory broadcast backends.

    NCCL and NIXL both overwrite inference weights in place, so the
    orchestrator must stay fully synchronous (max_async_level == 1).
    """
    if self.weight_broadcast.type in ("nccl", "nixl"):
        # `!=` instead of `not ... == 1` for readability; same behavior.
        if self.max_async_level != 1:
            raise ValueError(f"max_async_level must be 1 for {self.weight_broadcast.type} broadcast")
    return self

@model_validator(mode="after")
Expand Down
18 changes: 17 additions & 1 deletion src/prime_rl/configs/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,8 +714,24 @@ class NCCLWeightBroadcastConfig(BaseWeightBroadcastConfig):
] = False


class NIXLWeightBroadcastConfig(BaseWeightBroadcastConfig):
    """Configures the NIXL (UCX/RDMA) weight transfer.

    FP8 kernel-format transfer is always used for NIXL — the whole point of
    the pre-registered stable buffers is to avoid per-step FP8 allocation.
    """

    # Discriminator value for the WeightBroadcastConfig tagged union.
    type: Literal["nixl"] = "nixl"
    # Rendezvous endpoint; must match the orchestrator-side NIXL config.
    host: Annotated[str, Field(description="Rendezvous host used by StatelessProcessGroup.")] = "localhost"
    port: Annotated[int, Field(description="Rendezvous port used by StatelessProcessGroup.")] = 29502
    timeout: Annotated[int, Field(description="Rendezvous timeout in seconds.")] = 1200
    # Trainer world size is not a field here — presumably derived from the
    # trainer's own process group; TODO confirm against the broadcast setup.
    inference_world_size: Annotated[int, Field(description="Total number of inference GPUs.")] = 1
    backends: Annotated[list[str], Field(description="NIXL backends to initialize the agent with.")] = ["UCX"]
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed config field without CHANGELOG entry

Medium Severity

The quantize_in_weight_transfer field is removed from NCCLWeightBroadcastConfig (trainer), NCCLWeightBroadcastConfig (orchestrator), and SharedWeightBroadcastConfig. Existing user configs (e.g., the example in examples/glm5_pd_disag/rl.toml previously had this field set) that include quantize_in_weight_transfer will fail Pydantic validation at startup. This is a breaking config removal without a corresponding CHANGELOG.md entry.

Additional Locations (2)
Fix in Cursor Fix in Web

Triggered by project rule: BugBot Instructions

Reviewed by Cursor Bugbot for commit 94b6ad6. Configure here.



# Tagged union over all weight-broadcast backends, discriminated on `type`.
WeightBroadcastConfig: TypeAlias = Annotated[
    FileSystemWeightBroadcastConfig | NCCLWeightBroadcastConfig | NIXLWeightBroadcastConfig,
    Field(discriminator="type"),
]


Expand Down
19 changes: 19 additions & 0 deletions src/prime_rl/inference/vllm/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def models(request: Request) -> OpenAIServingModels:
# Maps a weight_broadcast `type` value to the dotted path of the vLLM worker
# extension class implementing that in-place weight-update mechanism.
WORKER_EXTENSION_CLS = {
    "nccl": "prime_rl.inference.vllm.worker.nccl.NCCLWeightUpdateWorker",
    "filesystem": "prime_rl.inference.vllm.worker.filesystem.FileSystemWeightUpdateWorker",
    "nixl": "prime_rl.inference.vllm.worker.nixl.NIXLWeightUpdateWorker",
}


Expand Down Expand Up @@ -235,6 +236,24 @@ async def init_broadcaster(request: Request):
return {"status": "ok"}


@router.post("/init_nixl_transfer")
async def init_nixl_transfer(request: Request):
    """Forward NIXL rendezvous parameters to every worker via collective RPC."""
    payload = await request.json()
    # Positional argument order must match NIXLWeightUpdateWorker.init_nixl_transfer.
    rpc_args = (
        payload["host"],
        payload["port"],
        payload["rank_offset"],
        payload["trainer_world_size"],
        payload["inference_world_size"],
        payload["timeout"],
        payload.get("backends", ["UCX"]),
    )
    await engine_client(request).collective_rpc("init_nixl_transfer", args=rpc_args)
    return {"status": "ok"}


@router.post(
"/v1/chat/completions/tokens",
dependencies=[Depends(validate_json_request)],
Expand Down
136 changes: 136 additions & 0 deletions src/prime_rl/inference/vllm/worker/nixl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
"""vLLM worker extension that receives weight updates over NIXL.

Counterpart to :mod:`prime_rl.trainer.rl.broadcast.nixl`. The inference side
registers parameter memory directly with NIXL (zero-copy RDMA target),
publishes its expert-ownership map per FusedMoE module, and sits on a
single process-group barrier per sync while the trainer posts writes.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import torch
from torch.nn import Module
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger

from prime_rl.inference.vllm.worker.weight_transfer import build_expert_map, update_mla_absorbed_weights
from prime_rl.utils.nixl_transfer import NixlAgentWrapper, make_agent_name

if TYPE_CHECKING:
from vllm.v1.worker.gpu_worker import Worker # noqa: F401

Worker = Worker # type: ignore
else:
Worker = object # type: ignore

logger = init_logger("vllm.inference.vllm.worker_nixl")


def _iter_transfer_targets(model: Module):
"""Yield (name, tensor) for every parameter + weight-scale buffer we want to
receive from the trainer.

vLLM stores FP8 scales as buffers (``w13_weight_scale_inv`` / ``w2_weight_scale_inv``),
not parameters, so ``named_parameters()`` alone is insufficient.
"""
seen: set[str] = set()
for name, param in model.named_parameters():
seen.add(name)
yield name, param.data
for name, buf in model.named_buffers():
if name in seen:
continue
# Only ship weight scales — other buffers (rotary embeddings, caches) are not
# synchronized from the trainer.
if name.endswith("_weight_scale_inv"):
yield name, buf


class NIXLWeightUpdateWorker(Worker):
    """vLLM worker extension for in-place weight updates over NIXL.

    The trainer drives all data movement: it posts RDMA writes directly into
    the parameter memory registered here. This worker performs the one-time
    rendezvous/registration and, per sync, waits on a barrier and recomputes
    derived (MLA absorbed) weights.
    """

    def init_nixl_transfer(
        self,
        host: str,
        port: int,
        rank_offset: int,
        trainer_world_size: int,
        inference_world_size: int,
        timeout: int,
        backends: list[str],
    ) -> None:
        """Register local parameter memory and rendezvous with the trainer.

        Args:
            host: Rendezvous host for the StatelessProcessGroup.
            port: Rendezvous port for the StatelessProcessGroup.
            rank_offset: First global inference-rank slot assigned to this
                server; combined with the local device index below.
            trainer_world_size: Number of trainer ranks; trainers occupy
                global ranks ``[0, trainer_world_size)``.
            inference_world_size: Total inference GPUs across all servers.
            timeout: Store timeout (seconds) for the rendezvous.
            backends: NIXL backends to initialize the agent with (e.g. ``["UCX"]``).
        """
        local_rank = self.device.index
        # Global rank layout: trainers first, then inference GPUs.
        global_rank = trainer_world_size + rank_offset + local_rank
        full_world_size = trainer_world_size + inference_world_size

        logger.info(
            f"Initializing NIXL transfer: local_rank={local_rank} rank_offset={rank_offset} "
            f"global_rank={global_rank} trainer_ws={trainer_world_size} inference_ws={inference_world_size}"
        )

        model_runner = self.model_runner
        # Unwrap the compiled wrapper, if any, to reach the underlying nn.Module.
        model = model_runner.model.runnable if hasattr(model_runner.model, "runnable") else model_runner.model
        assert isinstance(model, Module)
        self._model = model

        self._agent = NixlAgentWrapper(
            name=make_agent_name("inference", global_rank),
            local_rank=local_rank,
            backends=backends,
        )

        # Register every receivable tensor and record its serialized descriptor so the
        # trainer can deserialize it on the other side.
        descriptors: dict[str, bytes] = {}
        for name, tensor in _iter_transfer_targets(model):
            # Bug fix: registering `tensor.contiguous()` silently registers a
            # *copy* whenever the tensor is non-contiguous — trainer RDMA writes
            # would then land in an unreferenced scratch buffer and the live
            # parameter would never change. Zero-copy RDMA must target the real
            # storage, so demand contiguity instead of papering over it.
            if not tensor.is_contiguous():
                raise RuntimeError(f"Cannot register non-contiguous tensor {name!r} as a NIXL RDMA target")
            desc = self._agent.register_tensor(tensor)
            descriptors[name] = self._agent.serialize_descs(desc)

        # Expert ownership per FusedMoE module. Re-use the existing helper — each entry
        # is a tensor of global expert indices that this worker holds (sorted by local slot).
        expert_map = {k: v.cpu().tolist() for k, v in build_expert_map(model).items()}

        self._spg = StatelessProcessGroup.create(
            host=host,
            port=port,
            rank=global_rank,
            world_size=full_world_size,
            store_timeout=timeout,
        )

        # Publish this rank's agent metadata, tensor descriptors, and expert
        # map, and collect the same from every other rank in the group.
        my_info = {
            "role": "inference",
            "global_rank": global_rank,
            "agent_name": self._agent.name,
            "agent_metadata": self._agent.get_metadata(),
            "descriptors": descriptors,
            "expert_map": expert_map,
        }
        all_info: list[dict[str, Any]] = self._spg.all_gather_obj(my_info)

        # Add every trainer agent so future WRITEs from them can land here.
        for peer in all_info[:trainer_world_size]:
            self._agent.add_remote(peer["agent_metadata"])

        logger.info(
            f"NIXL transfer ready: registered {len(descriptors)} tensors, "
            f"added {trainer_world_size} trainer peers"
        )

    @torch.no_grad()
    def update_weights_from_path(self, weight_dir: str | None = None) -> None:
        """Receive one round of NIXL writes and post-process the model.

        The actual data movement is driven entirely by the trainer: writes land
        directly in the already-registered parameter memory. We only need to
        wait on the end-of-sync barrier and recompute MLA absorbed weights.

        Args:
            weight_dir: Unused; kept for interface parity with the filesystem
                weight-update worker.

        Raises:
            RuntimeError: If ``init_nixl_transfer`` has not been called yet.
        """
        if not hasattr(self, "_spg"):
            raise RuntimeError("NIXL transfer not initialized — call /init_nixl_transfer first")
        logger.debug("Waiting for NIXL end-of-sync barrier")
        self._spg.barrier()
        logger.debug("NIXL writes complete, running postprocess")
        update_mla_absorbed_weights(self._model)
15 changes: 13 additions & 2 deletions src/prime_rl/orchestrator/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from prime_rl.trainer.model import setup_tokenizer
from prime_rl.utils.client import (
init_nccl_broadcast,
init_nixl_transfer,
setup_inference_pool,
)
from prime_rl.utils.config import cli
Expand Down Expand Up @@ -274,6 +275,16 @@ async def orchestrate(config: OrchestratorConfig):
inference_world_size=config.weight_broadcast.inference_world_size,
quantize_in_weight_transfer=config.weight_broadcast.quantize_in_weight_transfer,
)
elif config.weight_broadcast.type == "nixl":
await init_nixl_transfer(
inference_pool.admin_clients,
config.weight_broadcast.host,
config.weight_broadcast.port,
config.weight_broadcast.timeout,
trainer_world_size=config.weight_broadcast.trainer_world_size,
inference_world_size=config.weight_broadcast.inference_world_size,
backends=config.weight_broadcast.backends,
)
else:
logger.info("Skipping weight broadcast initialization (SFT distillation mode)")

Expand Down Expand Up @@ -302,8 +313,8 @@ async def orchestrate(config: OrchestratorConfig):
prev_ckpt_step = scheduler.ckpt_step - 1

if enable_policy_updates:
# In NCCL mode, skip existence check - weights are broadcasted, not stored on disk
check_exists = config.weight_broadcast.type != "nccl"
# In NCCL/NIXL mode, skip existence check - weights are transferred, not stored on disk
check_exists = config.weight_broadcast.type not in ("nccl", "nixl")
wait_timeout = config.ckpt.wait_for_weights_timeout if config.ckpt else None
weights_path = get_weight_dir(
config.output_dir, scheduler.ckpt_step, check_exists=check_exists, wait_timeout=wait_timeout
Expand Down
32 changes: 21 additions & 11 deletions src/prime_rl/trainer/models/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from typing import TYPE_CHECKING

from torch import Tensor
from transformers.modeling_utils import PreTrainedModel

if TYPE_CHECKING:
from prime_rl.trainer.parallel_dims import ParallelDims


class PreTrainedModelPrimeRL(PreTrainedModel):
"""
Expand Down Expand Up @@ -103,22 +108,27 @@ def convert_layer_to_prime(cls, state_dict: dict[str, Tensor], layer_idx: int) -
"""
raise NotImplementedError(f"convert_layer_to_prime is not implemented for {cls.__name__}")

@classmethod
def convert_layer_to_vllm_kernel(
cls,
state_dict: dict[str, Tensor],
self,
layer_idx: int,
quantize_fp8: bool = False,
) -> dict[str, Tensor]:
out_buffers: dict[str, Tensor],
) -> None:
"""Convert one layer of this model's state dict into vLLM FP8 kernel format.

Reads ``self.state_dict()``, resolves any DTensor shards to rank-local
tensors, and writes the converted values into the caller-provided
``out_buffers`` in place. Used by the NIXL broadcast path.
"""
Convert a single layer's state dict from PrimeRL format to vLLM kernel format.
raise NotImplementedError(f"convert_layer_to_vllm_kernel is not implemented for {self.__class__.__name__}")
Comment thread
cursor[bot] marked this conversation as resolved.
Outdated

Args:
state_dict: Layer weights in PrimeRL format.
layer_idx: Layer index to convert.
quantize_fp8: Whether to emit FP8 (e4m3) kernel weights with per-block scales.
def allocate_slots(self, parallel_dims: "ParallelDims") -> dict[int, dict[str, Tensor]]:
"""Allocate stable per-layer destination buffers for NIXL weight transfer.

The NIXL broadcast writes into these buffers in place every push; they
must be sized for the rank-local shard (EP shard for expert tensors,
full tensor for everything else). Implemented per model.
"""
raise NotImplementedError(f"convert_layer_to_vllm_kernel is not implemented for {cls.__name__}")
raise NotImplementedError(f"allocate_slots is not implemented for {self.__class__.__name__}")

def init_buffers_post_meta(self) -> None:
"""
Expand Down
Loading
Loading