Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/prime_rl/configs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ class ModelConfig(BaseModelConfig):
class WeightBroadcastConfig(BaseConfig):
    """Configures weight broadcast settings."""

    # Transport used to deliver updated weights to inference workers.
    # "filesystem" stages checkpoints on disk; "nccl"/"nixl" transfer in-memory.
    type: Annotated[
        Literal["nccl", "filesystem", "nixl"], Field(description="The type of weight broadcast to use.")
    ] = "filesystem"


# Valid vLLM max_lora_rank values (from vllm/config/lora.py)
Expand Down
32 changes: 29 additions & 3 deletions src/prime_rl/configs/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,8 +804,34 @@ class NCCLWeightBroadcastConfig(BaseModel):
] = 1


class NIXLWeightBroadcastConfig(BaseModel):
    """Configures the NIXL weight transfer.

    FP8 kernel-format transfer is always used for NIXL.
    """

    type: Literal["nixl"] = "nixl"

    # Rendezvous endpoint shared by trainer and inference ranks.
    host: Annotated[str, Field(description="Rendezvous host used by StatelessProcessGroup.")] = "localhost"
    port: Annotated[int, Field(description="Rendezvous port used by StatelessProcessGroup.")] = 29502
    timeout: Annotated[int, Field(description="Rendezvous timeout in seconds.")] = 1200

    # World sizes used to lay out the combined trainer+inference global rank space.
    trainer_world_size: Annotated[int, Field(ge=1, description="Total number of trainer ranks (FSDP × EP).")] = 1
    inference_world_size: Annotated[
        int, Field(ge=1, description="Total number of inference GPUs across all servers.")
    ] = 1

    backends: Annotated[list[str], Field(description="NIXL backends to initialize the agent with.")] = ["UCX"]


# Discriminated union over the `type` field: pydantic picks the concrete config
# class based on the literal value ("filesystem" | "nccl" | "nixl").
WeightBroadcastConfig: TypeAlias = Annotated[
    FileSystemWeightBroadcastConfig | NCCLWeightBroadcastConfig | NIXLWeightBroadcastConfig,
    Field(discriminator="type"),
]


Expand Down Expand Up @@ -1097,9 +1123,9 @@ def validate_unique_filter_types(self):

@model_validator(mode="after")
def nccl_max_async_level(self):
    """Enforce synchronous stepping for in-memory weight transports.

    NCCL and NIXL move weights directly between live processes, so the
    orchestrator must stay in lockstep with the trainer (max_async_level == 1).

    Raises:
        ValueError: if an in-memory transport is paired with max_async_level != 1.
    """
    if self.weight_broadcast.type in ("nccl", "nixl"):
        if self.max_async_level != 1:
            raise ValueError(f"max_async_level must be 1 for {self.weight_broadcast.type} broadcast")
    return self

@model_validator(mode="after")
Expand Down
18 changes: 17 additions & 1 deletion src/prime_rl/configs/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,8 +714,24 @@ class NCCLWeightBroadcastConfig(BaseWeightBroadcastConfig):
] = False


class NIXLWeightBroadcastConfig(BaseWeightBroadcastConfig):
    """Configures the NIXL (UCX/RDMA) weight transfer.

    FP8 kernel-format transfer is always used for NIXL — the whole point of
    the pre-registered stable buffers is to avoid per-step FP8 allocation.
    """

    type: Literal["nixl"] = "nixl"

    # Rendezvous endpoint for the StatelessProcessGroup handshake.
    host: Annotated[
        str,
        Field(description="Rendezvous host used by StatelessProcessGroup."),
    ] = "localhost"
    port: Annotated[
        int,
        Field(description="Rendezvous port used by StatelessProcessGroup."),
    ] = 29502
    timeout: Annotated[
        int,
        Field(description="Rendezvous timeout in seconds."),
    ] = 1200
    inference_world_size: Annotated[
        int,
        Field(description="Total number of inference GPUs."),
    ] = 1
    backends: Annotated[
        list[str],
        Field(description="NIXL backends to initialize the agent with."),
    ] = ["UCX"]


# Discriminated union over the `type` field: pydantic picks the concrete config
# class based on the literal value ("filesystem" | "nccl" | "nixl").
WeightBroadcastConfig: TypeAlias = Annotated[
    FileSystemWeightBroadcastConfig | NCCLWeightBroadcastConfig | NIXLWeightBroadcastConfig,
    Field(discriminator="type"),
]


Expand Down
19 changes: 19 additions & 0 deletions src/prime_rl/inference/vllm/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def models(request: Request) -> OpenAIServingModels:
# Maps the configured weight-broadcast `type` to the fully-qualified worker
# extension class that implements that transfer mechanism in vLLM.
WORKER_EXTENSION_CLS = {
    "nccl": "prime_rl.inference.vllm.worker.nccl.NCCLWeightUpdateWorker",
    "filesystem": "prime_rl.inference.vllm.worker.filesystem.FileSystemWeightUpdateWorker",
    "nixl": "prime_rl.inference.vllm.worker.nixl.NIXLWeightUpdateWorker",
}


Expand Down Expand Up @@ -235,6 +236,24 @@ async def init_broadcaster(request: Request):
return {"status": "ok"}


@router.post("/init_nixl_transfer")
async def init_nixl_transfer(request: Request):
    """Fan an `init_nixl_transfer` RPC out to every worker of the engine."""
    payload = await request.json()
    rpc_args = (
        payload["host"],
        payload["port"],
        payload["rank_offset"],
        payload["trainer_world_size"],
        payload["inference_world_size"],
        payload["timeout"],
        # Backends are optional in the request; default mirrors the config default.
        payload.get("backends", ["UCX"]),
    )
    await engine_client(request).collective_rpc("init_nixl_transfer", args=rpc_args)
    return {"status": "ok"}


@router.post(
"/v1/chat/completions/tokens",
dependencies=[Depends(validate_json_request)],
Expand Down
136 changes: 136 additions & 0 deletions src/prime_rl/inference/vllm/worker/nixl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
"""vLLM worker extension that receives weight updates over NIXL.

Counterpart to :mod:`prime_rl.trainer.rl.broadcast.nixl`. The inference side
registers parameter memory directly with NIXL (zero-copy RDMA target),
publishes its expert-ownership map per FusedMoE module, and sits on a
single process-group barrier per sync while the trainer posts writes.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import torch
from torch.nn import Module
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger

from prime_rl.inference.vllm.worker.weight_transfer import build_expert_map, update_mla_absorbed_weights
from prime_rl.utils.nixl_transfer import NixlAgentWrapper, make_agent_name

if TYPE_CHECKING:
from vllm.v1.worker.gpu_worker import Worker # noqa: F401

Worker = Worker # type: ignore
else:
Worker = object # type: ignore

logger = init_logger("vllm.inference.vllm.worker_nixl")


def _iter_transfer_targets(model: Module):
"""Yield (name, tensor) for every parameter + weight-scale buffer we want to
receive from the trainer.

vLLM stores FP8 scales as buffers (``w13_weight_scale_inv`` / ``w2_weight_scale_inv``),
not parameters, so ``named_parameters()`` alone is insufficient.
"""
seen: set[str] = set()
for name, param in model.named_parameters():
seen.add(name)
yield name, param.data
for name, buf in model.named_buffers():
if name in seen:
continue
# Only ship weight scales — other buffers (rotary embeddings, caches) are not
# synchronized from the trainer.
if name.endswith("_weight_scale_inv"):
yield name, buf


class NIXLWeightUpdateWorker(Worker):
    """vLLM worker extension for in-place weight updates over NIXL.

    The trainer RDMA-writes directly into this worker's registered parameter
    memory; the worker only participates in rendezvous/barriers and runs the
    post-sync reprocessing (MLA absorbed weights).
    """

    def init_nixl_transfer(
        self,
        host: str,
        port: int,
        rank_offset: int,
        trainer_world_size: int,
        inference_world_size: int,
        timeout: int,
        backends: list[str],
    ) -> None:
        """Register local parameter memory and rendezvous with the trainer.

        Args:
            host: Rendezvous host for the StatelessProcessGroup.
            port: Rendezvous port for the StatelessProcessGroup.
            rank_offset: Offset of this server's first GPU within the global
                inference rank space.
            trainer_world_size: Number of trainer ranks; trainers occupy global
                ranks [0, trainer_world_size).
            inference_world_size: Total inference GPUs across all servers.
            timeout: Store timeout in seconds for the rendezvous.
            backends: NIXL backends (e.g. ["UCX"]) to initialize the agent with.

        Raises:
            RuntimeError: if a transfer-target tensor is non-contiguous and
                therefore cannot be registered as a zero-copy RDMA target.
        """
        # NOTE(review): assumes self.device carries an explicit index (cuda:N) —
        # confirm this holds for single-GPU servers where index may be None.
        local_rank = self.device.index
        # Trainer ranks come first in the combined rank space; inference follows.
        global_rank = trainer_world_size + rank_offset + local_rank
        full_world_size = trainer_world_size + inference_world_size

        logger.info(
            f"Initializing NIXL transfer: local_rank={local_rank} rank_offset={rank_offset} "
            f"global_rank={global_rank} trainer_ws={trainer_world_size} inference_ws={inference_world_size}"
        )

        model_runner = self.model_runner
        # Unwrap compiled/wrapped models that expose the real module as `.runnable`.
        model = model_runner.model.runnable if hasattr(model_runner.model, "runnable") else model_runner.model
        assert isinstance(model, Module)
        self._model = model

        self._agent = NixlAgentWrapper(
            name=make_agent_name("inference", global_rank),
            local_rank=local_rank,
            backends=backends,
        )

        # Register every receivable tensor and record its serialized descriptor so the
        # trainer can deserialize it on the other side.
        descriptors: dict[str, bytes] = {}
        for name, tensor in _iter_transfer_targets(model):
            # BUGFIX: previously this registered `tensor.contiguous()`. For a
            # non-contiguous tensor that call returns a *detached copy*, so the
            # copy's memory — not the model's — would be registered, and trainer
            # writes would silently never reach the model. Zero-copy RDMA must
            # register the live storage, so refuse non-contiguous tensors.
            if not tensor.is_contiguous():
                raise RuntimeError(f"Cannot register non-contiguous tensor {name!r} as a NIXL transfer target")
            desc = self._agent.register_tensor(tensor)
            descriptors[name] = self._agent.serialize_descs(desc)

        # Expert ownership per FusedMoE module. Re-use the existing helper — each entry
        # is a tensor of global expert indices that this worker holds (sorted by local slot).
        expert_map = {k: v.cpu().tolist() for k, v in build_expert_map(model).items()}

        self._spg = StatelessProcessGroup.create(
            host=host,
            port=port,
            rank=global_rank,
            world_size=full_world_size,
            store_timeout=timeout,
        )

        my_info = {
            "role": "inference",
            "global_rank": global_rank,
            "agent_name": self._agent.name,
            "agent_metadata": self._agent.get_metadata(),
            "descriptors": descriptors,
            "expert_map": expert_map,
        }
        all_info: list[dict[str, Any]] = self._spg.all_gather_obj(my_info)

        # The gather is ordered by global rank, so the first `trainer_world_size`
        # entries are trainer agents; add each so their future WRITEs can land here.
        for peer in all_info[:trainer_world_size]:
            self._agent.add_remote(peer["agent_metadata"])

        logger.info(
            f"NIXL transfer ready: registered {len(descriptors)} tensors, "
            f"added {trainer_world_size} trainer peers"
        )

    @torch.no_grad()
    def update_weights_from_path(self, weight_dir: str | None = None) -> None:
        """Receive one round of NIXL writes and re-run model postprocessing.

        The actual data movement is driven entirely by the trainer: writes land
        directly in the already-registered parameter memory. We only need to
        wait on the end-of-sync barrier and recompute MLA absorbed weights.
        `weight_dir` is ignored — it exists for interface parity with the
        filesystem transfer worker.

        Raises:
            RuntimeError: if `init_nixl_transfer` has not been called yet.
        """
        if not hasattr(self, "_spg"):
            raise RuntimeError("NIXL transfer not initialized — call /init_nixl_transfer first")
        logger.debug("Waiting for NIXL end-of-sync barrier")
        self._spg.barrier()
        logger.debug("NIXL writes complete, running postprocess")
        update_mla_absorbed_weights(self._model)
15 changes: 13 additions & 2 deletions src/prime_rl/orchestrator/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from prime_rl.trainer.model import setup_tokenizer
from prime_rl.utils.client import (
init_nccl_broadcast,
init_nixl_transfer,
setup_inference_pool,
)
from prime_rl.utils.config import cli
Expand Down Expand Up @@ -274,6 +275,16 @@ async def orchestrate(config: OrchestratorConfig):
inference_world_size=config.weight_broadcast.inference_world_size,
quantize_in_weight_transfer=config.weight_broadcast.quantize_in_weight_transfer,
)
elif config.weight_broadcast.type == "nixl":
await init_nixl_transfer(
inference_pool.admin_clients,
config.weight_broadcast.host,
config.weight_broadcast.port,
config.weight_broadcast.timeout,
trainer_world_size=config.weight_broadcast.trainer_world_size,
inference_world_size=config.weight_broadcast.inference_world_size,
backends=config.weight_broadcast.backends,
)
else:
logger.info("Skipping weight broadcast initialization (SFT distillation mode)")

Expand Down Expand Up @@ -302,8 +313,8 @@ async def orchestrate(config: OrchestratorConfig):
prev_ckpt_step = scheduler.ckpt_step - 1

if enable_policy_updates:
# In NCCL mode, skip existence check - weights are broadcasted, not stored on disk
check_exists = config.weight_broadcast.type != "nccl"
# In NCCL/NIXL mode, skip existence check - weights are transferred, not stored on disk
check_exists = config.weight_broadcast.type not in ("nccl", "nixl")
wait_timeout = config.ckpt.wait_for_weights_timeout if config.ckpt else None
weights_path = get_weight_dir(
config.output_dir, scheduler.ckpt_step, check_exists=check_exists, wait_timeout=wait_timeout
Expand Down
Loading
Loading