Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/prime_rl/configs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ class ModelConfig(BaseModelConfig):
class WeightBroadcastConfig(BaseConfig):
"""Configures weight broadcast settings."""

type: Annotated[Literal["nccl", "filesystem"], Field(description="The type of weight broadcast to use.")] = (
"filesystem"
)
type: Annotated[
Literal["nccl", "filesystem", "nixl"], Field(description="The type of weight broadcast to use.")
] = "filesystem"


# Valid vLLM max_lora_rank values (from vllm/config/lora.py)
Expand Down
32 changes: 29 additions & 3 deletions src/prime_rl/configs/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,8 +804,34 @@ class NCCLWeightBroadcastConfig(BaseModel):
] = 1


class NIXLWeightBroadcastConfig(BaseModel):
    """Configures the NIXL weight transfer.

    FP8 kernel-format transfer is always used for NIXL.

    This is the orchestrator-side view of the NIXL transfer: in addition to
    the rendezvous parameters it carries both world sizes so the orchestrator
    can coordinate trainer and inference ranks.
    """

    # Discriminator value for the WeightBroadcastConfig tagged union.
    type: Literal["nixl"] = "nixl"

    host: Annotated[str, Field(description="Rendezvous host used by StatelessProcessGroup.")] = "localhost"
    # NOTE: 29502 deliberately differs from the NCCL default (29501) so both
    # transports can rendezvous on the same machine — verify against callers.
    port: Annotated[int, Field(description="Rendezvous port used by StatelessProcessGroup.")] = 29502
    timeout: Annotated[int, Field(description="Rendezvous timeout in seconds.")] = 1200

    trainer_world_size: Annotated[
        int,
        Field(ge=1, description="Total number of trainer ranks (FSDP × EP)."),
    ] = 1

    inference_world_size: Annotated[
        int,
        Field(ge=1, description="Total number of inference GPUs across all servers."),
    ] = 1

    # Mutable default is safe here: pydantic deep-copies field defaults per instance.
    backends: Annotated[list[str], Field(description="NIXL backends to initialize the agent with.")] = ["UCX"]


WeightBroadcastConfig: TypeAlias = Annotated[
FileSystemWeightBroadcastConfig | NCCLWeightBroadcastConfig, Field(discriminator="type")
FileSystemWeightBroadcastConfig | NCCLWeightBroadcastConfig | NIXLWeightBroadcastConfig,
Field(discriminator="type"),
]


Expand Down Expand Up @@ -1097,9 +1123,9 @@ def validate_unique_filter_types(self):

@model_validator(mode="after")
def nccl_max_async_level(self):
if self.weight_broadcast.type == "nccl":
if self.weight_broadcast.type in ("nccl", "nixl"):
if not self.max_async_level == 1:
raise ValueError("max_async_level must be 1 for NCCL broadcast")
raise ValueError(f"max_async_level must be 1 for {self.weight_broadcast.type} broadcast")
return self

@model_validator(mode="after")
Expand Down
70 changes: 62 additions & 8 deletions src/prime_rl/configs/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from prime_rl.configs.orchestrator import (
NCCLWeightBroadcastConfig as OrchestratorNCCLWeightBroadcastConfig,
)
from prime_rl.configs.orchestrator import (
NIXLWeightBroadcastConfig as OrchestratorNIXLWeightBroadcastConfig,
)
from prime_rl.configs.orchestrator import (
OrchestratorConfig,
)
Expand All @@ -38,6 +41,9 @@
from prime_rl.configs.trainer import (
NCCLWeightBroadcastConfig as TrainerNCCLWeightBroadcastConfig,
)
from prime_rl.configs.trainer import (
NIXLWeightBroadcastConfig as TrainerNIXLWeightBroadcastConfig,
)
from prime_rl.utils.config import BaseConfig
from prime_rl.utils.logger import get_logger
from prime_rl.utils.validation import (
Expand Down Expand Up @@ -146,21 +152,25 @@ class SharedModelConfig(BaseConfig):
class SharedWeightBroadcastConfig(BaseConfig):
"""Configures shared weight broadcast settings."""

type: Annotated[Literal["nccl", "filesystem"], Field(description="The type of weight broadcast to use.")] = (
"filesystem"
)
type: Annotated[
Literal["nccl", "filesystem", "nixl"], Field(description="The type of weight broadcast to use.")
] = "filesystem"

port: Annotated[int, Field(description="The port to use for NCCL weight broadcast.")] = 29501
timeout: Annotated[int, Field(description="The timeout in seconds for NCCL weight broadcast.")] = 1200
port: Annotated[int, Field(description="Rendezvous port (NCCL or NIXL).")] = 29501
timeout: Annotated[int, Field(description="Rendezvous timeout in seconds (NCCL or NIXL).")] = 1200
quantize_in_weight_transfer: Annotated[
bool,
Field(
description=(
"Use kernel-format FP8 quantized NCCL transfer for weight updates. "
"When disabled, uses default HF checkpoint-format transfer."
"When disabled, uses default HF checkpoint-format transfer. "
"Only valid with type='nccl' (NIXL is always kernel-format)."
),
),
] = False
backends: Annotated[
list[str], Field(description="NIXL backends (only used when type='nixl').")
] = ["UCX"]


class BaseDeploymentConfig(BaseModel):
Expand Down Expand Up @@ -236,6 +246,13 @@ def teacher_inference_not_supported(self):
]


def _infer_trainer_world_size(deployment: "DeploymentConfig") -> int:
    """Return the total number of trainer ranks across all nodes.

    A single-node deployment contributes exactly its training-GPU count;
    any other deployment type contributes the per-node GPU count multiplied
    by the number of training nodes.
    """
    is_single_node = deployment.type == "single_node"
    return (
        deployment.num_train_gpus
        if is_single_node
        else deployment.num_train_nodes * deployment.gpus_per_node
    )


class RLConfig(BaseConfig):
"""Configures an RL training run."""

Expand Down Expand Up @@ -393,10 +410,11 @@ def validate_no_teacher_in_multinode(self):
@model_validator(mode="after")
def validate_enough_devices_for_nccl(self):
if self.deployment.type == "single_node":
if self.trainer.weight_broadcast.type == "nccl":
if self.trainer.weight_broadcast.type in ("nccl", "nixl"):
if self.deployment.num_train_gpus + self.deployment.num_infer_gpus < 2:
raise ValueError(
"NCCL weight broadcast requires at least 2 GPUs to build the broadcast process group."
f"{self.trainer.weight_broadcast.type.upper()} weight broadcast requires at least 2 GPUs to "
"build the broadcast process group."
)
return self

Expand Down Expand Up @@ -659,6 +677,24 @@ def auto_setup_weight_broadcast(self):
inference_world_size=inference_world_size,
quantize_in_weight_transfer=self.weight_broadcast.quantize_in_weight_transfer,
)
elif self.weight_broadcast.type == "nixl":
inference_world_size = self.inference.parallel.dp * self.inference.parallel.tp if self.inference else 1
trainer_world_size = _infer_trainer_world_size(self.deployment)
self.trainer.weight_broadcast = TrainerNIXLWeightBroadcastConfig(
type=self.weight_broadcast.type,
port=self.weight_broadcast.port,
timeout=self.weight_broadcast.timeout,
inference_world_size=inference_world_size,
backends=self.weight_broadcast.backends,
)
self.orchestrator.weight_broadcast = OrchestratorNIXLWeightBroadcastConfig(
type=self.weight_broadcast.type,
port=self.weight_broadcast.port,
timeout=self.weight_broadcast.timeout,
trainer_world_size=trainer_world_size,
inference_world_size=inference_world_size,
backends=self.weight_broadcast.backends,
)
elif self.weight_broadcast.type == "filesystem":
self.trainer.weight_broadcast = TrainerFileSystemWeightBroadcastConfig()
self.orchestrator.weight_broadcast = OrchestratorFileSystemWeightBroadcastConfig()
Expand Down Expand Up @@ -858,6 +894,18 @@ def auto_setup_deployment(self):
assert self.orchestrator.weight_broadcast.type == "nccl"
self.orchestrator.weight_broadcast.inference_world_size = total_infer_workers

if self.weight_broadcast is not None and self.weight_broadcast.type == "nixl":
api_server_count = self.inference.api_server_count if self.inference else 1
tp = self.inference.parallel.tp if self.inference else 1
total_infer_workers = self.deployment.total_infer_nodes * api_server_count * tp
trainer_world_size = _infer_trainer_world_size(self.deployment)
assert self.trainer.weight_broadcast.type == "nixl"
self.trainer.weight_broadcast.host = "0.0.0.0"
self.trainer.weight_broadcast.inference_world_size = total_infer_workers
assert self.orchestrator.weight_broadcast.type == "nixl"
self.orchestrator.weight_broadcast.inference_world_size = total_infer_workers
self.orchestrator.weight_broadcast.trainer_world_size = trainer_world_size

return self

@model_validator(mode="after")
Expand All @@ -883,6 +931,12 @@ def auto_setup_disaggregated_inference(self):
self.trainer.weight_broadcast.inference_world_size = total_infer_gpus
assert self.orchestrator.weight_broadcast.type == "nccl"
self.orchestrator.weight_broadcast.inference_world_size = total_infer_gpus
if self.weight_broadcast is not None and self.weight_broadcast.type == "nixl":
assert self.trainer.weight_broadcast.type == "nixl"
self.trainer.weight_broadcast.inference_world_size = total_infer_gpus
assert self.orchestrator.weight_broadcast.type == "nixl"
self.orchestrator.weight_broadcast.inference_world_size = total_infer_gpus
self.orchestrator.weight_broadcast.trainer_world_size = _infer_trainer_world_size(self.deployment)

return self

Expand Down
18 changes: 17 additions & 1 deletion src/prime_rl/configs/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,8 +714,24 @@ class NCCLWeightBroadcastConfig(BaseWeightBroadcastConfig):
] = False


class NIXLWeightBroadcastConfig(BaseWeightBroadcastConfig):
    """Configures the NIXL (UCX/RDMA) weight transfer.

    FP8 kernel-format transfer is always used for NIXL — the whole point of
    the pre-registered stable buffers is to avoid per-step FP8 allocation.

    This is the trainer-side view of the NIXL transfer: unlike the
    orchestrator-side config it does not carry a trainer world size, only the
    size of the inference side it pushes weights to.
    """

    # Discriminator value for the WeightBroadcastConfig tagged union.
    type: Literal["nixl"] = "nixl"
    host: Annotated[str, Field(description="Rendezvous host used by StatelessProcessGroup.")] = "localhost"
    # NOTE: 29502 deliberately differs from the NCCL default (29501) so both
    # transports can rendezvous on the same machine — verify against callers.
    port: Annotated[int, Field(description="Rendezvous port used by StatelessProcessGroup.")] = 29502
    timeout: Annotated[int, Field(description="Rendezvous timeout in seconds.")] = 1200
    inference_world_size: Annotated[int, Field(description="Total number of inference GPUs.")] = 1
    # Mutable default is safe here: pydantic deep-copies field defaults per instance.
    backends: Annotated[list[str], Field(description="NIXL backends to initialize the agent with.")] = ["UCX"]


WeightBroadcastConfig: TypeAlias = Annotated[
FileSystemWeightBroadcastConfig | NCCLWeightBroadcastConfig, Field(discriminator="type")
FileSystemWeightBroadcastConfig | NCCLWeightBroadcastConfig | NIXLWeightBroadcastConfig,
Field(discriminator="type"),
]


Expand Down
2 changes: 2 additions & 0 deletions src/prime_rl/entrypoints/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,7 @@ def write_slurm_script(config: RLConfig, config_dir: Path, script_path: Path) ->
kv_offload=infer_deploy.kv_cache_offload is not None,
kv_offload_cpu_bytes=int(infer_deploy.kv_cache_offload.cpu_bytes) if infer_deploy.kv_cache_offload else 0,
use_nccl_broadcast=config.weight_broadcast is not None and config.weight_broadcast.type == "nccl",
use_nixl_broadcast=config.weight_broadcast is not None and config.weight_broadcast.type == "nixl",
wandb_shared=config.wandb is not None and config.wandb.shared,
ranks_filter=",".join(map(str, config.trainer.log.ranks_filter)),
)
Expand All @@ -470,6 +471,7 @@ def write_slurm_script(config: RLConfig, config_dir: Path, script_path: Path) ->
inference_data_parallel_rpc_port=config.inference.data_parallel_rpc_port if config.inference else 29600,
dp_per_node=(config.deployment.gpus_per_node // config.inference.parallel.tp) if config.inference else 1,
use_nccl_broadcast=config.weight_broadcast is not None and config.weight_broadcast.type == "nccl",
use_nixl_broadcast=config.weight_broadcast is not None and config.weight_broadcast.type == "nixl",
wandb_shared=config.wandb is not None and config.wandb.shared,
ranks_filter=",".join(map(str, config.trainer.log.ranks_filter)),
)
Expand Down
19 changes: 19 additions & 0 deletions src/prime_rl/inference/vllm/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def models(request: Request) -> OpenAIServingModels:
# Maps each weight-broadcast type (the `type` discriminator of the broadcast
# config) to the dotted import path of the vLLM worker extension class that
# implements that transport's weight-update protocol.
WORKER_EXTENSION_CLS = {
    "nccl": "prime_rl.inference.vllm.worker.nccl.NCCLWeightUpdateWorker",
    "filesystem": "prime_rl.inference.vllm.worker.filesystem.FileSystemWeightUpdateWorker",
    "nixl": "prime_rl.inference.vllm.worker.nixl.NIXLWeightUpdateWorker",
}


Expand Down Expand Up @@ -235,6 +236,24 @@ async def init_broadcaster(request: Request):
return {"status": "ok"}


@router.post("/init_nixl_transfer")
async def init_nixl_transfer(request: Request):
    """Initialize the NIXL weight-transfer path on the inference workers.

    Expects a JSON body carrying the rendezvous parameters: ``host``,
    ``port``, ``rank_offset``, ``trainer_world_size``,
    ``inference_world_size``, ``timeout``, and an optional ``backends`` list
    (defaults to ``["UCX"]``). Missing required keys raise ``KeyError``
    during extraction below.
    """
    data = await request.json()
    # Forward the parameters to the workers via a collective RPC; the
    # positional order must match the worker-side init_nixl_transfer
    # signature (NIXLWeightUpdateWorker) — TODO confirm against the worker.
    await engine_client(request).collective_rpc(
        "init_nixl_transfer",
        args=(
            data["host"],
            data["port"],
            data["rank_offset"],
            data["trainer_world_size"],
            data["inference_world_size"],
            data["timeout"],
            data.get("backends", ["UCX"]),
        ),
    )
    return {"status": "ok"}


@router.post(
"/v1/chat/completions/tokens",
dependencies=[Depends(validate_json_request)],
Expand Down
Loading
Loading