1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@

Documenting **breaking** configuration changes — renamed, removed, or moved fields that require users to update existing configs.

- **`orchestrator.max_async_level` / `orchestrator.strict_async_level` / `trainer.max_async_level` removed; `orchestrator.on_policy` (NEW)**: The numeric async-level knobs are gone. Async training now caps the orchestrator at most one step ahead of the trainer and drops rollouts older than `orchestrator.max_off_policy_steps`. Set `orchestrator.on_policy = true` (default `false`) to force fully synchronous on-policy RL — debug-only, significantly slower. Existing configs using `max_async_level` or `strict_async_level` must delete those fields. (2026-04-20)
- **`model.attn = "eager"` (NEW option)**: Added `eager` as a valid value for the `model.attn` field. Required for GPT-OSS models on non-Hopper GPUs, since the only flash attention kernel GPT-OSS supports (`kernels-community/vllm-flash-attn3`) is Hopper-only. A clear error message is raised at model load time if GPT-OSS is used without `eager` on unsupported hardware. Also added `kernels` as a core dependency. (2026-04-05)
- **`[[orchestrator.env]]` → `[[orchestrator.train.env]]`**: Training environments and sampling are now configured under `[orchestrator.train]`. The old `[[orchestrator.env]]` and `[orchestrator.sampling]` paths are auto-translated with a deprecation warning and will be removed in a future release. (2026-04-09)
- **Per-env sampling overrides (NEW)**: Both `TrainEnvConfig` and `EvalEnvConfig` now accept a `[sampling]` section for per-env overrides. Unset fields inherit from the group-level sampling config (`[orchestrator.train.sampling]` or `[orchestrator.eval.sampling]`). (2026-04-09)
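
The migrations above can be combined; the following is a hypothetical config sketch of the new layout (env ids, model-free field names like `id` and `temperature`, and the `max_off_policy_steps` value are illustrative placeholders, not defaults from the repo):

```toml
[orchestrator]
on_policy = false            # default; set true only for on-policy debugging
max_off_policy_steps = 2     # rollouts with a larger off-policy gap are dropped

[orchestrator.train.sampling] # group-level defaults, was [orchestrator.sampling]
temperature = 1.0

[[orchestrator.train.env]]    # was [[orchestrator.env]]
id = "reverse-text"

[[orchestrator.train.env]]
id = "math"

# Per-env override attached to the env above; unset fields inherit
# from [orchestrator.train.sampling]
[orchestrator.train.env.sampling]
temperature = 0.7
```

Note that in TOML a plain `[orchestrator.train.env.sampling]` header after a `[[orchestrator.train.env]]` entry attaches the sub-table to the most recently declared array element, which is what makes the per-env override syntax work.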
1 change: 0 additions & 1 deletion configs/ci/integration/rl_multi_run/trainer.toml
@@ -1,6 +1,5 @@
# Trainer config for multi-run RL integration test
max_concurrent_runs = 2
max_async_level = 5

[model]
name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT"
1 change: 0 additions & 1 deletion configs/debug/orch.toml
@@ -1,5 +1,4 @@
max_steps = 5
max_async_level = 5
batch_size = 16

[sampling]
1 change: 0 additions & 1 deletion configs/debug/rl/train.toml
@@ -1,5 +1,4 @@
max_steps = 5
max_async_level = 5

[data.fake]
batch_size = 2
8 changes: 5 additions & 3 deletions docs/async.md
@@ -1,6 +1,8 @@
# Asynchronous Training

PRIME-RL implements asynchronous off-policy training, instead of the traditional synchronous on-policy training. This means that we allow inference to generate rollouts from a stale policy up to $k$ (in the code we call this `max_async_level`) steps ahead of the trainer. With `k=1` and trainer and inference step timings being equal, this allows to run without any idle time on either the trainer or inference. By default, we set `k=2` to allow overlap with a weight broadcast over the Internet, which is needed for decentralized training.
PRIME-RL implements asynchronous off-policy training, instead of the traditional synchronous on-policy training. This means that we allow inference to generate rollouts from a stale policy while the trainer progresses. The orchestrator always serves rollouts from the latest available policy, and off-policy rollouts beyond `max_off_policy_steps` are dropped.

For debugging, you can set `on_policy = true` to force fully synchronous on-policy RL. In this mode, the orchestrator blocks until the trainer's checkpoint for the current step is ready. This is significantly slower and is intended only for testing.
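The checkpoint the orchestrator waits for can be sketched as follows (a simplification of the scheduler's `_compute_next_ckpt_step`; `latest_ckpt_step` stands in for scanning the broadcast directory, and the function name is illustrative):

```python
def next_ckpt_step(step: int, latest_ckpt_step: int, on_policy: bool) -> int:
    """Which trainer checkpoint the orchestrator should use at `step`."""
    if on_policy:
        # Block until the trainer has produced the checkpoint for this exact step.
        return step
    # Async mode: use the freshest checkpoint available, but require at least
    # step - 1, so the orchestrator never races more than one step ahead.
    return max(max(step - 1, 0), latest_ckpt_step)

# At step 5 with the trainer's latest checkpoint at step 3, async mode waits
# for checkpoint 4 (the 1-step cap); on-policy mode waits for checkpoint 5.
```

If a newer checkpoint than `step - 1` is already on disk, async mode simply uses it, which is why the orchestrator "always serves rollouts from the latest available policy."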

![Two-Step Off-Policy Training](assets/two-step-off-policy.png)

@@ -34,6 +36,6 @@ where $\mu$ refers to the policy that generated the rollout, $\pi$ refers to the
PRIME-RL uses a global training step $n=1,2,3,\dots$ that is used to tag artifacts:

- **Trainer**: Produces policy $\pi_n$ with weights $\theta_n$ from rollouts $(x_n, y_n)$
- **Inference**: Produces rollouts $(x_n, y_n)$ from policy $\pi_{max(0, n-k)}$
- **Inference**: Produces rollouts $(x_n, y_n)$ from the latest available policy $\pi_m$ with $m \le n$

Here, $k$ is the `max_async_level` parameter, which defaults to 2. Note that we use 0-indexed steps to cleanly indicate that at each step, the divergence off-policy gap is at most $k$ steps.
Rollouts whose off-policy gap $n - m$ exceeds `max_off_policy_steps` are dropped.
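The staleness cap in that last sentence amounts to a simple filter. A minimal sketch (the `Rollout` class and its `ckpt_step` field are illustrative, not the actual rollout schema):

```python
from dataclasses import dataclass

@dataclass
class Rollout:
    ckpt_step: int  # m: the step of the policy that generated this rollout

def filter_stale(rollouts: list[Rollout], step: int, max_off_policy_steps: int) -> list[Rollout]:
    """Keep rollouts whose off-policy gap (step - m) is within the cap; drop the rest."""
    return [r for r in rollouts if step - r.ckpt_step <= max_off_policy_steps]

# With max_off_policy_steps=2 at trainer step n=5, rollouts from policies
# m >= 3 are kept; anything older is dropped.
```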
22 changes: 3 additions & 19 deletions src/prime_rl/configs/orchestrator.py
@@ -1001,25 +1001,17 @@ class OrchestratorConfig(BaseConfig):
),
] = 8

max_async_level: Annotated[
int,
Field(
ge=0,
description="Maximum number of steps the inference can be ahead of training. If 0, will degenerate to synchronous on-policy RL. If >=1, training and inference will be overlapped.",
),
samsja marked this conversation as resolved.
] = 1

strict_async_level: Annotated[
on_policy: Annotated[
bool,
Field(
description="Whether to strictly enforce the max async level. If True, will always ensure that the policy used for generating rollouts is exactly `max_async_level` steps ahead of training. If False, any policy that is at most `max_async_level` steps ahead of training is allowed, i.e. we always use the latest available policy.",
description="Debug-only flag to force fully synchronous on-policy RL. If True, the orchestrator waits for the trainer to produce a checkpoint at the current step before generating rollouts. This is significantly slower than async training and is intended only for testing/debugging on-policy behavior.",
),
] = False

bench: Annotated[
bool,
Field(
description="Whether to run in benchmark mode. It will automatically set the maximum number of steps to run to 5, max async level to ~infinity and disable W&B.",
description="Whether to run in benchmark mode. It will automatically set the maximum number of steps to run to 5 and disable W&B.",
),
] = False

@@ -1086,13 +1078,6 @@ def validate_unique_filter_types(self):
raise ValueError(f"Duplicate filter types: {types}. Each filter type may only appear once.")
return self

@model_validator(mode="after")
mikasenghaas marked this conversation as resolved.
def nccl_max_async_level(self):
if self.weight_broadcast.type == "nccl":
if not self.max_async_level == 1:
raise ValueError("max_async_level must be 1 for NCCL broadcast")
return self

@model_validator(mode="after")
def resolve_batching(self):
has_rollout_batch = self.batch_size is not None
@@ -1136,7 +1121,6 @@ def resolve_batching(self):
def auto_setup_bench(self):
if self.bench:
self.max_steps = 4 # Run for 1 warmup step + 3 evaluation steps
self.max_async_level = int(1e9) # Never wait for RL weight checkpoints
Bench mode lost "never wait" async override

Medium Severity

The auto_setup_bench method previously set self.max_async_level = int(1e9) with the explicit comment "Never wait for RL weight checkpoints." This line was removed without any replacement. Bench mode now uses the hardcoded async level of 1, meaning the orchestrator will block waiting for trainer checkpoints at step 2+. This defeats the purpose of benchmark mode, which is to measure orchestrator throughput without trainer bottlenecks.

Reviewed by Cursor Bugbot for commit 7c7817b.

Member

oh this is valid i think. if we hardcode async level 1 then the benchmark mode will stop at the async barrier of 1. maybe we can circumvent this by setting enable_policy_updates=False


# Disable evaluation
self.eval = None
19 changes: 0 additions & 19 deletions src/prime_rl/configs/rl.py
@@ -42,7 +42,6 @@
from prime_rl.utils.logger import get_logger
from prime_rl.utils.validation import (
validate_shared_ckpt_config,
validate_shared_max_async_level,
validate_shared_max_steps,
validate_shared_model_name,
validate_shared_output_dir,
@@ -330,13 +329,6 @@ class RLConfig(BaseConfig):
),
] = None

max_async_level: Annotated[
int | None,
Field(
description="The async level to use. If None, will fallback to the async level specified on submodule configs."
),
] = None

weight_broadcast: Annotated[
SharedWeightBroadcastConfig | None, Field(description="The weight broadcast config.")
] = None
@@ -607,17 +599,6 @@ def auto_setup_max_steps(self):

return self

@model_validator(mode="after")
def auto_setup_async_level(self):
"""Auto-setup shared async level for trainer and orchestrator."""
if self.max_async_level is not None:
self.trainer.max_async_level = self.max_async_level
self.orchestrator.max_async_level = self.max_async_level

validate_shared_max_async_level(self.trainer, self.orchestrator)

return self

@model_validator(mode="after")
def auto_setup_seq_len(self):
"""Auto-setup shared seq_len for trainer and orchestrator.
14 changes: 0 additions & 14 deletions src/prime_rl/configs/trainer.py
@@ -784,14 +784,6 @@ class TrainerConfig(BaseConfig):
),
] = None

max_async_level: Annotated[
int,
Field(
ge=0,
description="Maximum number of steps that inference can be ahead of training. Determines how 'off-policy' the inference engines can be. Higher values yield better throughput through async execution, but may yield lower performance. If 0, will be fully synchronous.",
),
] = 1

enable_router_replay: Annotated[
bool,
Field(
@@ -908,12 +900,6 @@ def validate_lora_adapter_saving(self):
)
return self

@model_validator(mode="after")
def validate_weight_broadcast_type(self):
if self.weight_broadcast.type == "nccl" and self.max_async_level != 1:
raise ValueError("NCCL weight broadcast only works with async level 1")
return self

@model_validator(mode="after")
def validate_opt_and_fsdp_offload(self):
if self.optim.type == "muon" and self.model.fsdp_cpu_offload:
5 changes: 2 additions & 3 deletions src/prime_rl/orchestrator/orchestrator.py
@@ -234,9 +234,8 @@ async def orchestrate(config: OrchestratorConfig):
buffer=buffer,
inference_pool=inference_pool,
max_inflight_rollouts=config.max_inflight_rollouts,
max_async_level=config.max_async_level,
max_off_policy_steps=config.max_off_policy_steps,
strict_async_level=config.strict_async_level,
on_policy=config.on_policy,
tasks_per_minute=config.tasks_per_minute,
enable_policy_updates=enable_policy_updates,
lora_name=config.model.lora.name if config.model.lora else None,
@@ -485,7 +484,7 @@ async def orchestrate(config: OrchestratorConfig):

# VLM: build image cache in a thread so it doesn't block the event loop.
# This lets the scheduler continue servicing inflight rollout requests
# and — with max_async_level >= 2 — overlap with the next batch's inference.
# and overlap with the next batch's inference.
if is_vlm:
vlm_cache = await asyncio.to_thread(build_vlm_image_cache, train_rollouts, processor)
mm_token_type_ids_mapping = {}
26 changes: 13 additions & 13 deletions src/prime_rl/orchestrator/scheduler.py
@@ -62,9 +62,8 @@
buffer: Buffer,
config: OrchestratorConfig,
max_inflight_rollouts: int,
max_async_level: int,
max_off_policy_steps: int,
strict_async_level: bool,
on_policy: bool,
tasks_per_minute: int | None,
enable_policy_updates: bool = True,
lora_name: str | None = None,
@@ -81,9 +80,8 @@ def __init__(
self.token_batch_size = config.token_batch_size
self.rollouts_per_example = config.rollouts_per_example
self.max_inflight_rollouts = max_inflight_rollouts
self.max_async_level = max_async_level
self.max_off_policy_steps = max_off_policy_steps
self.strict_async_level = strict_async_level
self.on_policy = on_policy
self.enable_policy_updates = enable_policy_updates
self.lora_name = lora_name
self.model_name = self.config.model.name
@@ -274,17 +272,19 @@ async def update_policy_loop(self):

def _compute_next_ckpt_step(self) -> int:
latest_ckpt_step = get_latest_ckpt_step(get_broadcast_dir(self.config.output_dir)) or 0
async_away_ckpt_step = max(self.step - self.max_async_level, 0)
if self.strict_async_level:
return async_away_ckpt_step
if self.on_policy:
Member

we still need the async_away_ckpt_step but hardcoded with max_async_level=1. otherwise the orchestrator will race away from the trainer

return self.step
# Cap the orchestrator at most 1 step ahead of the trainer so it doesn't race away.
async_away_ckpt_step = max(self.step - 1, 0)
return max(async_away_ckpt_step, latest_ckpt_step)

async def _apply_policy_update(self, next_ckpt_step: int) -> None:
async_away_ckpt_step = max(self.step - self.max_async_level, 0)
if next_ckpt_step == async_away_ckpt_step:
latest_ckpt_step = get_latest_ckpt_step(get_broadcast_dir(self.config.output_dir)) or 0
if next_ckpt_step > latest_ckpt_step:
reason = "on-policy mode, on_policy=true" if self.on_policy else "1 step ahead"
self.logger.info(
f"Orchestrator paused: waiting for trainer process to complete checkpoint {next_ckpt_step} "
f"(>{self.max_async_level} step(s) ahead). Training is progressing normally."
f"Orchestrator paused: waiting for trainer to complete checkpoint {next_ckpt_step} "
f"({reason}). Training is progressing normally."
)
self.checkpoint_ready.clear()
wait_for_ckpt_start_time = time.perf_counter()
@@ -379,8 +379,8 @@ async def generate_batch(self, step: int) -> list[vf.RolloutOutput]:
if self.update_policy_task is not None:
await safe_cancel(self.update_policy_task)

# Manually check the async barrier before starting the step, then re-create the update policy loop
# This ensures that we respect max_async_level, while still listening for policy updates mid-step
# Manually check the async barrier before starting the step, then re-create the update policy loop.
# In on_policy mode this also blocks until the matching on-policy checkpoint is ready.
await self.maybe_update_policy()
self.update_policy_task = asyncio.create_task(self.update_policy_loop())
else:
Expand Down
3 changes: 1 addition & 2 deletions src/prime_rl/trainer/rl/broadcast/filesystem.py
@@ -109,11 +109,10 @@ def _notify_orchestrator(self, save_dir: Path):
stable_file = save_dir / "STABLE"
stable_file.touch()

def maybe_clean(self, max_async_level: int, interval_to_keep: int | None):
def maybe_clean(self, interval_to_keep: int | None):
for idx in self.multi_run_manager.used_idxs:
maybe_clean(
get_broadcast_dir(self.multi_run_manager.get_run_dir(idx)),
self.multi_run_manager.progress[idx].step,
max_async_level,
interval_to_keep,
)
10 changes: 6 additions & 4 deletions src/prime_rl/trainer/rl/train.py
@@ -236,20 +236,22 @@ def load_run_checkpoint(_optimizer, idx: int) -> None:
is_last_step = config.max_steps is not None and progress.step == config.max_steps

# Broadcast weights at every step, (except step 0, because no need to broadcast the base model)
# Also, with NCCL broadcast, we do not broadcast weights the last async level step as the orchestrator is already finished and will not initialize the receive on the inference; for filesystem broadcast, we do "broadcast" until the final step to allow to resume from the broadcast directory
# Also, with NCCL broadcast, we do not broadcast weights on the final step, since the orchestrator has already finished and will not initiate the receive on the inference side; for filesystem broadcast, we keep "broadcasting" until the final step so runs can resume from the broadcast directory
if weight_broadcast is None:
broadcast_weights_time = 0
else:
last_async_level_steps = config.max_steps and progress.step >= config.max_steps - config.max_async_level
if progress.step > 0 and (not last_async_level_steps or config.weight_broadcast.type == "filesystem"):
# Async mode stays at most one step behind the trainer, so the orchestrator
# would never consume a broadcast emitted on the final step.
is_tail_step = config.max_steps and progress.step >= config.max_steps - 1
if progress.step > 0 and (not is_tail_step or config.weight_broadcast.type == "filesystem"):
broadcast_weights_start_time = time.perf_counter()
weight_broadcast.broadcast_weights(model, step=progress.step)
broadcast_weights_time = time.perf_counter() - broadcast_weights_start_time
# Clean up old broadcast directories (unless at ckpt interval if using filesystem weight broadcast)
ckpt_interval = config.ckpt and config.ckpt.interval
interval_to_keep = ckpt_interval if config.weight_broadcast.type == "filesystem" else None
if config.weight_broadcast.type == "filesystem":
weight_broadcast.maybe_clean(config.max_async_level, interval_to_keep)
weight_broadcast.maybe_clean(interval_to_keep)
else:
broadcast_weights_time = 0
# Usually the broadcast will set this. If broadcast is skipped, we need to reset this here.
5 changes: 3 additions & 2 deletions src/prime_rl/trainer/utils.py
@@ -434,9 +434,10 @@ def step(self):
self.step_num += 1


def maybe_clean(path: Path, step: int, async_level: int, interval_to_keep: int | None) -> None:
def maybe_clean(path: Path, step: int, interval_to_keep: int | None) -> None:
logger = get_logger()
step = max(step - (async_level + 1), 0) # Consider deleting async_level + 1 steps ago
# Keep the two most recent broadcasts (current step and previous), delete everything older.
step = max(step - 2, 0)
candidate_path_to_delete = get_step_path(path, step)
keep = bool(interval_to_keep and step % interval_to_keep == 0)
logger.debug(f"Considering deleting path {candidate_path_to_delete}")
10 changes: 0 additions & 10 deletions src/prime_rl/utils/validation.py
@@ -94,16 +94,6 @@ def validate_shared_max_steps(
)


def validate_shared_max_async_level(
trainer: TrainerConfig,
orchestrator: OrchestratorConfig,
) -> None:
if trainer.max_async_level != orchestrator.max_async_level:
raise ValueError(
f"Trainer max async level ({trainer.max_async_level}) and orchestrator max async level ({orchestrator.max_async_level}) are not the same. Please specify the same max async level for both."
)


def validate_shared_tokenizer(
trainer: TrainerConfig,
orchestrator: OrchestratorConfig,
3 changes: 1 addition & 2 deletions tests/unit/orchestrator/test_scheduler.py
@@ -9,8 +9,7 @@

def make_scheduler() -> Scheduler:
scheduler = Scheduler.__new__(Scheduler)
scheduler.max_async_level = 1
scheduler.strict_async_level = False
scheduler.on_policy = False
scheduler.step = 9
scheduler.ckpt_step = 7
scheduler.config = SimpleNamespace(output_dir=Path("/tmp/prime-rl-test"))