Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/prime_rl/configs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ class ModelConfig(BaseModelConfig):
class WeightBroadcastConfig(BaseConfig):
    """Configures weight broadcast settings."""

    # Selects the broadcast implementation; "nixl" uses the UCX/RDMA transfer
    # path, "nccl" a collective broadcast, "filesystem" on-disk checkpoints.
    type: Annotated[
        Literal["nccl", "filesystem", "nixl"], Field(description="The type of weight broadcast to use.")
    ] = "filesystem"


# Valid vLLM max_lora_rank values (from vllm/config/lora.py)
Expand Down
32 changes: 29 additions & 3 deletions src/prime_rl/configs/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,8 +804,34 @@ class NCCLWeightBroadcastConfig(BaseModel):
] = 1


class NIXLWeightBroadcastConfig(BaseModel):
    """Configures the NIXL weight transfer.

    FP8 kernel-format transfer is always used for NIXL.
    """

    # Discriminator value selecting this variant in the WeightBroadcastConfig union.
    type: Literal["nixl"] = "nixl"

    # Rendezvous endpoint shared by trainer and inference ranks.
    host: Annotated[str, Field(description="Rendezvous host used by StatelessProcessGroup.")] = "localhost"
    port: Annotated[int, Field(description="Rendezvous port used by StatelessProcessGroup.")] = 29502
    timeout: Annotated[int, Field(description="Rendezvous timeout in seconds.")] = 1200

    # World-size split: trainer ranks come first, inference ranks after.
    trainer_world_size: Annotated[
        int,
        Field(ge=1, description="Total number of trainer ranks (FSDP × EP)."),
    ] = 1

    inference_world_size: Annotated[
        int,
        Field(ge=1, description="Total number of inference GPUs across all servers."),
    ] = 1

    # Mutable default is safe here: pydantic copies field defaults per instance.
    backends: Annotated[list[str], Field(description="NIXL backends to initialize the agent with.")] = ["UCX"]


# Discriminated union over the `type` field: pydantic instantiates the concrete
# config class whose `type` literal matches the input.
WeightBroadcastConfig: TypeAlias = Annotated[
    FileSystemWeightBroadcastConfig | NCCLWeightBroadcastConfig | NIXLWeightBroadcastConfig,
    Field(discriminator="type"),
]


Expand Down Expand Up @@ -1097,9 +1123,9 @@ def validate_unique_filter_types(self):

@model_validator(mode="after")
def nccl_max_async_level(self):
if self.weight_broadcast.type == "nccl":
if self.weight_broadcast.type in ("nccl", "nixl"):
if not self.max_async_level == 1:
raise ValueError("max_async_level must be 1 for NCCL broadcast")
raise ValueError(f"max_async_level must be 1 for {self.weight_broadcast.type} broadcast")
return self

@model_validator(mode="after")
Expand Down
18 changes: 17 additions & 1 deletion src/prime_rl/configs/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,8 +714,24 @@ class NCCLWeightBroadcastConfig(BaseWeightBroadcastConfig):
] = False


class NIXLWeightBroadcastConfig(BaseWeightBroadcastConfig):
    """Configures the NIXL (UCX/RDMA) weight transfer.

    FP8 kernel-format transfer is always used for NIXL — the whole point of
    the pre-registered stable buffers is to avoid per-step FP8 allocation.
    """

    # Discriminator value selecting this variant in the WeightBroadcastConfig union.
    type: Literal["nixl"] = "nixl"
    # Rendezvous endpoint shared by trainer and inference ranks.
    host: Annotated[str, Field(description="Rendezvous host used by StatelessProcessGroup.")] = "localhost"
    port: Annotated[int, Field(description="Rendezvous port used by StatelessProcessGroup.")] = 29502
    timeout: Annotated[int, Field(description="Rendezvous timeout in seconds.")] = 1200
    inference_world_size: Annotated[int, Field(description="Total number of inference GPUs.")] = 1
    # Mutable default is safe here: pydantic copies field defaults per instance.
    backends: Annotated[list[str], Field(description="NIXL backends to initialize the agent with.")] = ["UCX"]


# Discriminated union over the `type` field: pydantic instantiates the concrete
# config class whose `type` literal matches the input.
WeightBroadcastConfig: TypeAlias = Annotated[
    FileSystemWeightBroadcastConfig | NCCLWeightBroadcastConfig | NIXLWeightBroadcastConfig,
    Field(discriminator="type"),
]


Expand Down
19 changes: 19 additions & 0 deletions src/prime_rl/inference/vllm/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def models(request: Request) -> OpenAIServingModels:
# Maps the weight-broadcast `type` config value to the fully-qualified vLLM
# worker extension class implementing that transfer path.
WORKER_EXTENSION_CLS = {
    "nccl": "prime_rl.inference.vllm.worker.nccl.NCCLWeightUpdateWorker",
    "filesystem": "prime_rl.inference.vllm.worker.filesystem.FileSystemWeightUpdateWorker",
    "nixl": "prime_rl.inference.vllm.worker.nixl.NIXLWeightUpdateWorker",
}


Expand Down Expand Up @@ -235,6 +236,24 @@ async def init_broadcaster(request: Request):
return {"status": "ok"}


@router.post("/init_nixl_transfer")
async def init_nixl_transfer(request: Request):
    """Fan the NIXL setup parameters out to every worker rank via collective RPC."""
    payload = await request.json()
    rpc_args = (
        payload["host"],
        payload["port"],
        payload["rank_offset"],
        payload["trainer_world_size"],
        payload["inference_world_size"],
        payload["timeout"],
        payload.get("backends", ["UCX"]),
    )
    await engine_client(request).collective_rpc("init_nixl_transfer", args=rpc_args)
    return {"status": "ok"}


@router.post(
"/v1/chat/completions/tokens",
dependencies=[Depends(validate_json_request)],
Expand Down
110 changes: 110 additions & 0 deletions src/prime_rl/inference/vllm/worker/nixl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""vLLM worker extension that receives weight updates over NIXL.

Registers every vLLM parameter + weight-scale buffer with NIXL and publishes
``(ptr, nbytes, device)`` per tensor. The trainer builds its own remote
descriptors at the right byte offsets — this side doesn't need to know the
per-source layout, so one ``all_gather_obj`` round is enough.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from torch.nn import Module
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger

from prime_rl.inference.vllm.worker.weight_transfer import build_expert_map, update_mla_absorbed_weights
from prime_rl.utils.nixl_transfer import NixlAgentWrapper, make_agent_name

if TYPE_CHECKING:
from vllm.v1.worker.gpu_worker import Worker # noqa: F401

Worker = Worker # type: ignore
else:
Worker = object # type: ignore

logger = init_logger("vllm.inference.vllm.worker_nixl")


class NIXLWeightUpdateWorker(Worker):
    """vLLM worker extension for in-place weight updates over NIXL."""

    def init_nixl_transfer(
        self,
        host: str,
        port: int,
        rank_offset: int,
        trainer_world_size: int,
        inference_world_size: int,
        timeout: int,
        backends: list[str],
    ) -> None:
        """Create the NIXL agent, register all weight tensors, and exchange
        agent metadata with every trainer rank.

        Args:
            host: Rendezvous host for the StatelessProcessGroup.
            port: Rendezvous port for the StatelessProcessGroup.
            rank_offset: Offset of this server's first GPU among all inference GPUs.
            trainer_world_size: Number of trainer ranks; they occupy global ranks
                ``[0, trainer_world_size)`` and inference ranks follow.
            inference_world_size: Total number of inference GPUs across servers.
            timeout: Store timeout in seconds for the rendezvous.
            backends: NIXL backends to initialize the agent with (e.g. ``["UCX"]``).
        """
        # Global rank layout: trainers first, then inference GPUs in server order.
        local_rank = self.device.index
        global_rank = trainer_world_size + rank_offset + local_rank
        full_world_size = trainer_world_size + inference_world_size

        logger.info(
            f"Initializing NIXL transfer: local_rank={local_rank} rank_offset={rank_offset} "
            f"global_rank={global_rank} trainer_ws={trainer_world_size} inference_ws={inference_world_size}"
        )

        # Unwrap the underlying nn.Module if vLLM wrapped the model in a runnable.
        model_runner = self.model_runner
        model = model_runner.model.runnable if hasattr(model_runner.model, "runnable") else model_runner.model
        assert isinstance(model, Module)
        self._model = model

        self._agent = NixlAgentWrapper(
            name=make_agent_name("inference", global_rank),
            local_rank=local_rank,
            backends=backends,
        )

        # name -> (data_ptr, nbytes, device index) for every registered tensor;
        # the trainer uses these to build remote descriptors at byte offsets.
        tensor_ptrs: dict[str, tuple[int, int, int]] = {}

        def _register(name: str, tensor: torch.Tensor) -> None:
            # `.contiguous()` on a non-contiguous tensor returns a *copy*;
            # registering the copy's pointer would silently divert trainer
            # writes away from the live model weights. Fail loudly instead.
            if not tensor.is_contiguous():
                raise RuntimeError(f"Cannot register non-contiguous tensor {name!r} for NIXL weight transfer")
            self._agent.register_tensor(tensor)
            tensor_ptrs[name] = (tensor.data_ptr(), tensor.numel() * tensor.element_size(), tensor.get_device())

        for name, param in model.named_parameters():
            _register(name, param.data)
        # Also publish FP8 per-block scale buffers (``*_weight_scale_inv``),
        # skipping anything already registered as a parameter.
        for name, buf in model.named_buffers():
            if name in tensor_ptrs or not name.endswith("_weight_scale_inv"):
                continue
            _register(name, buf)

        expert_map = {k: v.cpu().tolist() for k, v in build_expert_map(model).items()}

        self._spg = StatelessProcessGroup.create(
            host=host,
            port=port,
            rank=global_rank,
            world_size=full_world_size,
            store_timeout=timeout,
        )
        gathered = self._spg.all_gather_obj(
            {
                "role": "inference",
                "global_rank": global_rank,
                "agent_name": self._agent.name,
                "agent_metadata": self._agent.get_metadata(),
                "tensor_ptrs": tensor_ptrs,
                "expert_map": expert_map,
            }
        )
        # Assumes all_gather_obj returns entries ordered by global rank, so the
        # first `trainer_world_size` entries are exactly the trainer peers.
        for peer in gathered[:trainer_world_size]:
            self._agent.add_remote(peer["agent_metadata"])

        logger.info(
            f"NIXL transfer ready: registered {len(tensor_ptrs)} tensors, "
            f"added {trainer_world_size} trainer peers"
        )

    @torch.no_grad()
    def update_weights_from_path(self, weight_dir: str | None = None) -> None:
        """Finalize one weight push.

        The trainer writes new values directly into the registered buffers over
        NIXL (see module docstring), so this side only synchronizes on the
        barrier and then refreshes the derived MLA-absorbed weights.

        Args:
            weight_dir: Unused for NIXL; kept for interface parity with the
                filesystem weight-update worker.

        Raises:
            RuntimeError: If ``init_nixl_transfer`` was never called.
        """
        if not hasattr(self, "_spg"):
            raise RuntimeError("NIXL transfer not initialized — call /init_nixl_transfer first")
        self._spg.barrier()
        update_mla_absorbed_weights(self._model)
15 changes: 13 additions & 2 deletions src/prime_rl/orchestrator/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from prime_rl.trainer.model import setup_tokenizer
from prime_rl.utils.client import (
init_nccl_broadcast,
init_nixl_transfer,
setup_inference_pool,
)
from prime_rl.utils.config import cli
Expand Down Expand Up @@ -274,6 +275,16 @@ async def orchestrate(config: OrchestratorConfig):
inference_world_size=config.weight_broadcast.inference_world_size,
quantize_in_weight_transfer=config.weight_broadcast.quantize_in_weight_transfer,
)
elif config.weight_broadcast.type == "nixl":
await init_nixl_transfer(
inference_pool.admin_clients,
config.weight_broadcast.host,
config.weight_broadcast.port,
config.weight_broadcast.timeout,
trainer_world_size=config.weight_broadcast.trainer_world_size,
inference_world_size=config.weight_broadcast.inference_world_size,
backends=config.weight_broadcast.backends,
)
else:
logger.info("Skipping weight broadcast initialization (SFT distillation mode)")

Expand Down Expand Up @@ -302,8 +313,8 @@ async def orchestrate(config: OrchestratorConfig):
prev_ckpt_step = scheduler.ckpt_step - 1

if enable_policy_updates:
# In NCCL mode, skip existence check - weights are broadcasted, not stored on disk
check_exists = config.weight_broadcast.type != "nccl"
# In NCCL/NIXL mode, skip existence check - weights are transferred, not stored on disk
check_exists = config.weight_broadcast.type not in ("nccl", "nixl")
wait_timeout = config.ckpt.wait_for_weights_timeout if config.ckpt else None
weights_path = get_weight_dir(
config.output_dir, scheduler.ckpt_step, check_exists=check_exists, wait_timeout=wait_timeout
Expand Down
32 changes: 21 additions & 11 deletions src/prime_rl/trainer/models/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from typing import TYPE_CHECKING

from torch import Tensor
from transformers.modeling_utils import PreTrainedModel

if TYPE_CHECKING:
from prime_rl.trainer.parallel_dims import ParallelDims


class PreTrainedModelPrimeRL(PreTrainedModel):
"""
Expand Down Expand Up @@ -103,22 +108,27 @@ def convert_layer_to_prime(cls, state_dict: dict[str, Tensor], layer_idx: int) -
"""
raise NotImplementedError(f"convert_layer_to_prime is not implemented for {cls.__name__}")

@classmethod
def convert_layer_to_vllm_kernel(
cls,
state_dict: dict[str, Tensor],
self,
layer_idx: int,
quantize_fp8: bool = False,
) -> dict[str, Tensor]:
out_buffers: dict[str, Tensor],
) -> None:
"""Convert one layer of this model's state dict into vLLM FP8 kernel format.

Reads ``self.state_dict()``, resolves any DTensor shards to rank-local
tensors, and writes the converted values into the caller-provided
``out_buffers`` in place. Used by the NIXL broadcast path.
"""
Convert a single layer's state dict from PrimeRL format to vLLM kernel format.
raise NotImplementedError(f"convert_layer_to_vllm_kernel is not implemented for {self.__class__.__name__}")
Comment thread
cursor[bot] marked this conversation as resolved.
Outdated

Args:
state_dict: Layer weights in PrimeRL format.
layer_idx: Layer index to convert.
quantize_fp8: Whether to emit FP8 (e4m3) kernel weights with per-block scales.
def allocate_slots(self, parallel_dims: "ParallelDims") -> dict[int, dict[str, Tensor]]:
"""Allocate stable per-layer destination buffers for NIXL weight transfer.

The NIXL broadcast writes into these buffers in place every push; they
must be sized for the rank-local shard (EP shard for expert tensors,
full tensor for everything else). Implemented per model.
"""
raise NotImplementedError(f"convert_layer_to_vllm_kernel is not implemented for {cls.__name__}")
raise NotImplementedError(f"allocate_slots is not implemented for {self.__class__.__name__}")

def init_buffers_post_meta(self) -> None:
"""
Expand Down
Loading
Loading