Changes from 4 commits
1 change: 0 additions & 1 deletion configs/ci/integration/rl_multi_run/trainer.toml
@@ -1,6 +1,5 @@
# Trainer config for multi-run RL integration test
max_concurrent_runs = 2
max_async_level = 5

[model]
name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT"
1 change: 0 additions & 1 deletion configs/debug/orch.toml
@@ -1,5 +1,4 @@
max_steps = 5
max_async_level = 5
batch_size = 16

[sampling]
1 change: 0 additions & 1 deletion configs/debug/rl/train.toml
@@ -1,5 +1,4 @@
max_steps = 5
max_async_level = 5

[data.fake]
batch_size = 2
8 changes: 5 additions & 3 deletions docs/async.md
@@ -1,6 +1,8 @@
# Asynchronous Training

PRIME-RL implements asynchronous off-policy training, instead of the traditional synchronous on-policy training. This means that we allow inference to generate rollouts from a stale policy up to $k$ (in the code we call this `max_async_level`) steps ahead of the trainer. With `k=1` and trainer and inference step timings being equal, this allows to run without any idle time on either the trainer or inference. By default, we set `k=2` to allow overlap with a weight broadcast over the Internet, which is needed for decentralized training.
PRIME-RL implements asynchronous off-policy training instead of the traditional synchronous on-policy training. This means that we allow inference to generate rollouts from a stale policy while the trainer progresses. The orchestrator always serves rollouts from the latest available policy, and rollouts whose off-policy gap exceeds `max_off_policy_steps` are dropped.

For debugging, you can set `on_policy = true` to force fully synchronous on-policy RL. In this mode, the orchestrator blocks until the trainer's checkpoint for the current step is ready. This is significantly slower and is intended only for testing.

![Two-Step Off-Policy Training](assets/two-step-off-policy.png)

@@ -34,6 +36,6 @@ where $\mu$ refers to the policy that generated the rollout, $\pi$ refers to the
PRIME-RL uses a global training step $n=1,2,3,\dots$ that is used to tag artifacts:

- **Trainer**: Produces policy $\pi_n$ with weights $\theta_n$ from rollouts $(x_n, y_n)$
- **Inference**: Produces rollouts $(x_n, y_n)$ from policy $\pi_{max(0, n-k)}$
- **Inference**: Produces rollouts $(x_n, y_n)$ from the latest available policy $\pi_m$ with $m \le n$

Here, $k$ is the `max_async_level` parameter, which defaults to 2. Note that we use 0-indexed steps to cleanly indicate that at each step, the divergence off-policy gap is at most $k$ steps.
Rollouts whose off-policy gap $n - m$ exceeds `max_off_policy_steps` are dropped.
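The policy-selection and drop rules described above can be sketched in plain Python. This is a hedged illustration — the function names and signatures below are invented for clarity, not the actual PRIME-RL API:

```python
# Illustrative sketch of the async rollout rules: the orchestrator samples from
# the newest available checkpoint m <= n, and drops rollouts whose off-policy
# gap n - m exceeds max_off_policy_steps. Names are hypothetical.

def latest_available_policy_step(trainer_step: int, available_ckpt_steps: list[int]) -> int:
    """Return the newest checkpoint step m with m <= n (0 if none exists yet)."""
    eligible = [s for s in available_ckpt_steps if s <= trainer_step]
    return max(eligible, default=0)


def should_drop_rollout(trainer_step: int, policy_step: int, max_off_policy_steps: int) -> bool:
    """A rollout is dropped when its off-policy gap n - m exceeds the limit."""
    return trainer_step - policy_step > max_off_policy_steps
```

With `max_off_policy_steps = 2`, a rollout generated from policy step 2 while the trainer is at step 5 (gap 3) would be dropped, while one from step 4 (gap 1) is kept.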
22 changes: 3 additions & 19 deletions src/prime_rl/configs/orchestrator.py
@@ -1001,25 +1001,17 @@ class OrchestratorConfig(BaseConfig):
),
] = 8

max_async_level: Annotated[
int,
Field(
ge=0,
description="Maximum number of steps the inference can be ahead of training. If 0, will degenerate to synchronous on-policy RL. If >=1, training and inference will be overlapped.",
),
] = 1

strict_async_level: Annotated[
on_policy: Annotated[
bool,
Field(
description="Whether to strictly enforce the max async level. If True, will always ensure that the policy used for generating rollouts is exactly `max_async_level` steps ahead of training. If False, any policy that is at most `max_async_level` steps ahead of training is allowed, i.e. we always use the latest available policy.",
description="Debug-only flag to force fully synchronous on-policy RL. If True, the orchestrator waits for the trainer to produce a checkpoint at the current step before generating rollouts. This is significantly slower than async training and is intended only for testing/debugging on-policy behavior.",
),
] = False

bench: Annotated[
bool,
Field(
description="Whether to run in benchmark mode. It will automatically set the maximum number of steps to run to 5, max async level to ~infinity and disable W&B.",
description="Whether to run in benchmark mode. It will automatically set the maximum number of steps to run to 5 and disable W&B.",
),
] = False

@@ -1086,13 +1078,6 @@ def validate_unique_filter_types(self):
raise ValueError(f"Duplicate filter types: {types}. Each filter type may only appear once.")
return self

@model_validator(mode="after")
def nccl_max_async_level(self):
if self.weight_broadcast.type == "nccl":
if not self.max_async_level == 1:
raise ValueError("max_async_level must be 1 for NCCL broadcast")
return self

@model_validator(mode="after")
def resolve_batching(self):
has_rollout_batch = self.batch_size is not None
@@ -1136,7 +1121,6 @@ def resolve_batching(self):
def auto_setup_bench(self):
if self.bench:
self.max_steps = 4 # Run for 1 warmup step + 3 evaluation steps
self.max_async_level = int(1e9) # Never wait for RL weight checkpoints
Bench mode lost "never wait" async override

Medium Severity

The auto_setup_bench method previously set self.max_async_level = int(1e9) with the explicit comment "Never wait for RL weight checkpoints." This line was removed without any replacement. Bench mode now uses the hardcoded async level of 1, meaning the orchestrator will block waiting for trainer checkpoints at step 2+. This defeats the purpose of benchmark mode, which is to measure orchestrator throughput without trainer bottlenecks.


Reviewed by Cursor Bugbot for commit 7c7817b.

oh this is valid i think. if we hardcode async level 1 then the benchmark mode will stop at the async barrier of 1. maybe we can circumvent this by setting enable_policy_updates=False


# Disable evaluation
self.eval = None
20 changes: 10 additions & 10 deletions src/prime_rl/configs/rl.py
@@ -42,9 +42,9 @@
from prime_rl.utils.logger import get_logger
from prime_rl.utils.validation import (
validate_shared_ckpt_config,
validate_shared_max_async_level,
validate_shared_max_steps,
validate_shared_model_name,
validate_shared_on_policy,
validate_shared_output_dir,
validate_shared_tokenizer,
validate_shared_wandb_config,
@@ -330,10 +330,10 @@ class RLConfig(BaseConfig):
),
] = None

max_async_level: Annotated[
int | None,
on_policy: Annotated[
i think we can remove this from the shared config? the trainer doesn't need to know about the async level anymore. thus far, it needed to to know at which point it can start cleaning broadcast checkpoints. since now async level is implicitly <=1, we are fine with cleaning dirs >=2 steps away

bool | None,
Field(
description="The async level to use. If None, will fallback to the async level specified on submodule configs."
description="Debug-only flag to force fully synchronous on-policy RL. If None, falls back to the value on submodule configs. Significantly slower than async training."
),
] = None

@@ -608,13 +608,13 @@ def auto_setup_max_steps(self):
return self

@model_validator(mode="after")
def auto_setup_async_level(self):
"""Auto-setup shared async level for trainer and orchestrator."""
if self.max_async_level is not None:
self.trainer.max_async_level = self.max_async_level
self.orchestrator.max_async_level = self.max_async_level
def auto_setup_on_policy(self):
can remove this as well

"""Auto-setup shared on_policy flag for trainer and orchestrator."""
if self.on_policy is not None:
self.trainer.on_policy = self.on_policy
self.orchestrator.on_policy = self.on_policy

validate_shared_max_async_level(self.trainer, self.orchestrator)
validate_shared_on_policy(self.trainer, self.orchestrator)

return self

15 changes: 4 additions & 11 deletions src/prime_rl/configs/trainer.py
@@ -784,13 +784,12 @@ class TrainerConfig(BaseConfig):
),
] = None

max_async_level: Annotated[
int,
on_policy: Annotated[
can remove

bool,
Field(
ge=0,
description="Maximum number of steps that inference can be ahead of training. Determines how 'off-policy' the inference engines can be. Higher values yield better throughput through async execution, but may yield lower performance. If 0, will be fully synchronous.",
description="Debug-only flag to force fully synchronous on-policy RL. When True, the trainer broadcasts weights every step (including the final one) and the orchestrator blocks until the matching checkpoint is available. Significantly slower than async training.",
),
] = 1
] = False

enable_router_replay: Annotated[
bool,
@@ -908,12 +907,6 @@ def validate_lora_adapter_saving(self):
)
return self

@model_validator(mode="after")
def validate_weight_broadcast_type(self):
if self.weight_broadcast.type == "nccl" and self.max_async_level != 1:
raise ValueError("NCCL weight broadcast only works with async level 1")
return self

@model_validator(mode="after")
def validate_opt_and_fsdp_offload(self):
if self.optim.type == "muon" and self.model.fsdp_cpu_offload:
5 changes: 2 additions & 3 deletions src/prime_rl/orchestrator/orchestrator.py
@@ -234,9 +234,8 @@ async def orchestrate(config: OrchestratorConfig):
buffer=buffer,
inference_pool=inference_pool,
max_inflight_rollouts=config.max_inflight_rollouts,
max_async_level=config.max_async_level,
max_off_policy_steps=config.max_off_policy_steps,
strict_async_level=config.strict_async_level,
on_policy=config.on_policy,
tasks_per_minute=config.tasks_per_minute,
enable_policy_updates=enable_policy_updates,
lora_name=config.model.lora.name if config.model.lora else None,
@@ -485,7 +484,7 @@ async def orchestrate(config: OrchestratorConfig):

# VLM: build image cache in a thread so it doesn't block the event loop.
# This lets the scheduler continue servicing inflight rollout requests
# and — with max_async_level >= 2 — overlap with the next batch's inference.
# and overlap with the next batch's inference.
if is_vlm:
vlm_cache = await asyncio.to_thread(build_vlm_image_cache, train_rollouts, processor)
mm_token_type_ids_mapping = {}
32 changes: 18 additions & 14 deletions src/prime_rl/orchestrator/scheduler.py
@@ -62,9 +62,8 @@ def __init__(
buffer: Buffer,
config: OrchestratorConfig,
max_inflight_rollouts: int,
max_async_level: int,
max_off_policy_steps: int,
strict_async_level: bool,
on_policy: bool,
tasks_per_minute: int | None,
enable_policy_updates: bool = True,
lora_name: str | None = None,
@@ -81,9 +80,8 @@ def __init__(
self.token_batch_size = config.token_batch_size
self.rollouts_per_example = config.rollouts_per_example
self.max_inflight_rollouts = max_inflight_rollouts
self.max_async_level = max_async_level
self.max_off_policy_steps = max_off_policy_steps
self.strict_async_level = strict_async_level
self.on_policy = on_policy
self.enable_policy_updates = enable_policy_updates
self.lora_name = lora_name
self.model_name = self.config.model.name
@@ -274,17 +272,23 @@ async def update_policy_loop(self):

def _compute_next_ckpt_step(self) -> int:
latest_ckpt_step = get_latest_ckpt_step(get_broadcast_dir(self.config.output_dir)) or 0
async_away_ckpt_step = max(self.step - self.max_async_level, 0)
if self.strict_async_level:
return async_away_ckpt_step
return max(async_away_ckpt_step, latest_ckpt_step)
if self.on_policy:
we still need the async_away_ckpt_step but hardcoded with max_async_level=1. otherwise the orchestrator will race away from the trainer

return self.step
# Bound the orchestrator/trainer gap to max_off_policy_steps so the orchestrator can't
# race ahead indefinitely when the trainer is slow to broadcast. When latest_ckpt_step
# lags below this bound, we return the bound to trigger a wait path.
min_required_ckpt_step = max(self.step - self.max_off_policy_steps, 0)
return max(min_required_ckpt_step, latest_ckpt_step)

async def _apply_policy_update(self, next_ckpt_step: int) -> None:
async_away_ckpt_step = max(self.step - self.max_async_level, 0)
if next_ckpt_step == async_away_ckpt_step:
latest_ckpt_step = get_latest_ckpt_step(get_broadcast_dir(self.config.output_dir)) or 0
if next_ckpt_step > latest_ckpt_step:
reason = (
"on-policy mode, on_policy=true" if self.on_policy else f">{self.max_off_policy_steps} step(s) ahead"
)
self.logger.info(
f"Orchestrator paused: waiting for trainer process to complete checkpoint {next_ckpt_step} "
f"(>{self.max_async_level} step(s) ahead). Training is progressing normally."
f"Orchestrator paused: waiting for trainer to complete checkpoint {next_ckpt_step} "
f"({reason}). Training is progressing normally."
)
self.checkpoint_ready.clear()
wait_for_ckpt_start_time = time.perf_counter()
@@ -379,8 +383,8 @@ async def generate_batch(self, step: int) -> list[vf.RolloutOutput]:
if self.update_policy_task is not None:
await safe_cancel(self.update_policy_task)

# Manually check the async barrier before starting the step, then re-create the update policy loop
# This ensures that we respect max_async_level, while still listening for policy updates mid-step
# Manually check the async barrier before starting the step, then re-create the update policy loop.
# In on_policy mode this also blocks until the matching on-policy checkpoint is ready.
await self.maybe_update_policy()
self.update_policy_task = asyncio.create_task(self.update_policy_loop())
else:
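The new `_compute_next_ckpt_step` logic can be distilled into a standalone sketch. The helper below is illustrative (the real method reads `latest_ckpt_step` from the broadcast directory rather than taking it as an argument):

```python
def compute_next_ckpt_step(
    step: int,
    latest_ckpt_step: int,
    max_off_policy_steps: int,
    on_policy: bool,
) -> int:
    """Mirror of the diff's logic: on-policy mode waits for the current step's
    checkpoint; otherwise the orchestrator may run ahead of the trainer, but
    only up to max_off_policy_steps. Returning a step above latest_ckpt_step
    triggers the wait path."""
    if on_policy:
        return step
    min_required_ckpt_step = max(step - max_off_policy_steps, 0)
    return max(min_required_ckpt_step, latest_ckpt_step)
```

For example, at step 9 with `max_off_policy_steps = 4` and only checkpoint 3 available, the bound 9 - 4 = 5 is returned, so the orchestrator pauses until the trainer catches up to checkpoint 5.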
3 changes: 1 addition & 2 deletions src/prime_rl/trainer/rl/broadcast/filesystem.py
@@ -109,11 +109,10 @@ def _notify_orchestrator(self, save_dir: Path):
stable_file = save_dir / "STABLE"
stable_file.touch()

def maybe_clean(self, max_async_level: int, interval_to_keep: int | None):
def maybe_clean(self, interval_to_keep: int | None):
for idx in self.multi_run_manager.used_idxs:
maybe_clean(
get_broadcast_dir(self.multi_run_manager.get_run_dir(idx)),
self.multi_run_manager.progress[idx].step,
max_async_level,
interval_to_keep,
)
11 changes: 7 additions & 4 deletions src/prime_rl/trainer/rl/train.py
@@ -236,20 +236,23 @@ def load_run_checkpoint(_optimizer, idx: int) -> None:
is_last_step = config.max_steps is not None and progress.step == config.max_steps

# Broadcast weights at every step, (except step 0, because no need to broadcast the base model)
# Also, with NCCL broadcast, we do not broadcast weights the last async level step as the orchestrator is already finished and will not initialize the receive on the inference; for filesystem broadcast, we do "broadcast" until the final step to allow to resume from the broadcast directory
# Also, with NCCL broadcast, we do not broadcast weights on the final step, as the orchestrator is already finished and will not initialize the receive on the inference side; for filesystem broadcast, we do "broadcast" until the final step to allow resuming from the broadcast directory
if weight_broadcast is None:
broadcast_weights_time = 0
else:
last_async_level_steps = config.max_steps and progress.step >= config.max_steps - config.max_async_level
if progress.step > 0 and (not last_async_level_steps or config.weight_broadcast.type == "filesystem"):
# Skip the trailing broadcast that the orchestrator would never consume:
# async mode stays one step behind the trainer, on_policy is lockstep.
tail_skip = 0 if config.on_policy else 1
is_tail_step = config.max_steps and progress.step >= config.max_steps - tail_skip
if progress.step > 0 and (not is_tail_step or config.weight_broadcast.type == "filesystem"):
broadcast_weights_start_time = time.perf_counter()
weight_broadcast.broadcast_weights(model, step=progress.step)
broadcast_weights_time = time.perf_counter() - broadcast_weights_start_time
# Clean up old broadcast directories (unless at ckpt interval if using filesystem weight broadcast)
ckpt_interval = config.ckpt and config.ckpt.interval
interval_to_keep = ckpt_interval if config.weight_broadcast.type == "filesystem" else None
if config.weight_broadcast.type == "filesystem":
weight_broadcast.maybe_clean(config.max_async_level, interval_to_keep)
weight_broadcast.maybe_clean(interval_to_keep)
else:
broadcast_weights_time = 0
# Usually the broadcast will set this. If broadcast is skipped, we need to reset this here.
5 changes: 3 additions & 2 deletions src/prime_rl/trainer/utils.py
@@ -434,9 +434,10 @@ def step(self):
self.step_num += 1


def maybe_clean(path: Path, step: int, async_level: int, interval_to_keep: int | None) -> None:
def maybe_clean(path: Path, step: int, interval_to_keep: int | None) -> None:
logger = get_logger()
step = max(step - (async_level + 1), 0) # Consider deleting async_level + 1 steps ago
# Keep the two most recent broadcasts (current step and previous), delete everything older.
step = max(step - 2, 0)
candidate_path_to_delete = get_step_path(path, step)
keep = bool(interval_to_keep and step % interval_to_keep == 0)
logger.debug(f"Considering deleting path {candidate_path_to_delete}")
6 changes: 3 additions & 3 deletions src/prime_rl/utils/validation.py
@@ -94,13 +94,13 @@ def validate_shared_max_steps(
)


def validate_shared_max_async_level(
def validate_shared_on_policy(
can remove

trainer: TrainerConfig,
orchestrator: OrchestratorConfig,
) -> None:
if trainer.max_async_level != orchestrator.max_async_level:
if trainer.on_policy != orchestrator.on_policy:
raise ValueError(
f"Trainer max async level ({trainer.max_async_level}) and orchestrator max async level ({orchestrator.max_async_level}) are not the same. Please specify the same max async level for both."
f"Trainer on_policy ({trainer.on_policy}) and orchestrator on_policy ({orchestrator.on_policy}) do not match. Please specify the same value for both."
)


3 changes: 1 addition & 2 deletions tests/unit/orchestrator/test_scheduler.py
@@ -9,8 +9,7 @@

def make_scheduler() -> Scheduler:
scheduler = Scheduler.__new__(Scheduler)
scheduler.max_async_level = 1
scheduler.strict_async_level = False
scheduler.on_policy = False
scheduler.step = 9
scheduler.ckpt_step = 7
scheduler.config = SimpleNamespace(output_dir=Path("/tmp/prime-rl-test"))