-
Notifications
You must be signed in to change notification settings - Fork 262
feat: replace max_async_level with on_policy debug flag #2328
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
19b7c55
5449932
56d2a90
51a3265
7c7817b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,4 @@ | ||
| max_steps = 5 | ||
| max_async_level = 5 | ||
| batch_size = 16 | ||
|
|
||
| [sampling] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,4 @@ | ||
| max_steps = 5 | ||
| max_async_level = 5 | ||
|
|
||
| [data.fake] | ||
| batch_size = 2 | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1001,25 +1001,17 @@ class OrchestratorConfig(BaseConfig): | |
| ), | ||
| ] = 8 | ||
|
|
||
| max_async_level: Annotated[ | ||
| int, | ||
| Field( | ||
| ge=0, | ||
| description="Maximum number of steps the inference can be ahead of training. If 0, will degenerate to synchronous on-policy RL. If >=1, training and inference will be overlapped.", | ||
| ), | ||
| ] = 1 | ||
|
|
||
| strict_async_level: Annotated[ | ||
| on_policy: Annotated[ | ||
| bool, | ||
| Field( | ||
| description="Whether to strictly enforce the max async level. If True, will always ensure that the policy used for generating rollouts is exactly `max_async_level` steps ahead of training. If False, any policy that is at most `max_async_level` steps ahead of training is allowed, i.e. we always use the latest available policy.", | ||
| description="Debug-only flag to force fully synchronous on-policy RL. If True, the orchestrator waits for the trainer to produce a checkpoint at the current step before generating rollouts. This is significantly slower than async training and is intended only for testing/debugging on-policy behavior.", | ||
| ), | ||
| ] = False | ||
|
|
||
| bench: Annotated[ | ||
| bool, | ||
| Field( | ||
| description="Whether to run in benchmark mode. It will automatically set the maximum number of steps to run to 5, max async level to ~infinity and disable W&B.", | ||
| description="Whether to run in benchmark mode. It will automatically set the maximum number of steps to run to 5 and disable W&B.", | ||
| ), | ||
| ] = False | ||
|
|
||
|
|
@@ -1086,13 +1078,6 @@ def validate_unique_filter_types(self): | |
| raise ValueError(f"Duplicate filter types: {types}. Each filter type may only appear once.") | ||
| return self | ||
|
|
||
| @model_validator(mode="after") | ||
|
mikasenghaas marked this conversation as resolved.
|
||
| def nccl_max_async_level(self): | ||
| if self.weight_broadcast.type == "nccl": | ||
| if not self.max_async_level == 1: | ||
| raise ValueError("max_async_level must be 1 for NCCL broadcast") | ||
| return self | ||
|
|
||
| @model_validator(mode="after") | ||
| def resolve_batching(self): | ||
| has_rollout_batch = self.batch_size is not None | ||
|
|
@@ -1136,7 +1121,6 @@ def resolve_batching(self): | |
| def auto_setup_bench(self): | ||
| if self.bench: | ||
| self.max_steps = 4 # Run for 1 warmup step + 3 evaluation steps | ||
| self.max_async_level = int(1e9) # Never wait for RL weight checkpoints | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Bench mode lost the "never wait" async override — Medium Severity. Reviewed by Cursor Bugbot for commit 7c7817b. Configure here.
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. oh this is valid i think. if we hardcode async level 1 then the benchmark mode will stop at the async barrier of 1. maybe we can circumvent this by setting |
||
|
|
||
| # Disable evaluation | ||
| self.eval = None | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -42,9 +42,9 @@ | |
| from prime_rl.utils.logger import get_logger | ||
| from prime_rl.utils.validation import ( | ||
| validate_shared_ckpt_config, | ||
| validate_shared_max_async_level, | ||
| validate_shared_max_steps, | ||
| validate_shared_model_name, | ||
| validate_shared_on_policy, | ||
| validate_shared_output_dir, | ||
| validate_shared_tokenizer, | ||
| validate_shared_wandb_config, | ||
|
|
@@ -330,10 +330,10 @@ class RLConfig(BaseConfig): | |
| ), | ||
| ] = None | ||
|
|
||
| max_async_level: Annotated[ | ||
| int | None, | ||
| on_policy: Annotated[ | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. i think we can remove this from the shared config? the trainer doesn't need to know about the async level anymore. thus far, it needed to to know at which point it can start cleaning broadcast checkpoints. since now async level is implicitly <=1, we are fine with cleaning dirs >=2 steps away |
||
| bool | None, | ||
| Field( | ||
| description="The async level to use. If None, will fallback to the async level specified on submodule configs." | ||
| description="Debug-only flag to force fully synchronous on-policy RL. If None, falls back to the value on submodule configs. Significantly slower than async training." | ||
| ), | ||
| ] = None | ||
|
|
||
|
|
@@ -608,13 +608,13 @@ def auto_setup_max_steps(self): | |
| return self | ||
|
|
||
| @model_validator(mode="after") | ||
| def auto_setup_async_level(self): | ||
| """Auto-setup shared async level for trainer and orchestrator.""" | ||
| if self.max_async_level is not None: | ||
| self.trainer.max_async_level = self.max_async_level | ||
| self.orchestrator.max_async_level = self.max_async_level | ||
| def auto_setup_on_policy(self): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. can remove this as well |
||
| """Auto-setup shared on_policy flag for trainer and orchestrator.""" | ||
| if self.on_policy is not None: | ||
| self.trainer.on_policy = self.on_policy | ||
| self.orchestrator.on_policy = self.on_policy | ||
|
|
||
| validate_shared_max_async_level(self.trainer, self.orchestrator) | ||
| validate_shared_on_policy(self.trainer, self.orchestrator) | ||
|
|
||
| return self | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -784,13 +784,12 @@ class TrainerConfig(BaseConfig): | |
| ), | ||
| ] = None | ||
|
|
||
| max_async_level: Annotated[ | ||
| int, | ||
| on_policy: Annotated[ | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. can remove |
||
| bool, | ||
| Field( | ||
| ge=0, | ||
| description="Maximum number of steps that inference can be ahead of training. Determines how 'off-policy' the inference engines can be. Higher values yield better throughput through async execution, but may yield lower performance. If 0, will be fully synchronous.", | ||
| description="Debug-only flag to force fully synchronous on-policy RL. When True, the trainer broadcasts weights every step (including the final one) and the orchestrator blocks until the matching checkpoint is available. Significantly slower than async training.", | ||
| ), | ||
| ] = 1 | ||
| ] = False | ||
|
|
||
| enable_router_replay: Annotated[ | ||
| bool, | ||
|
|
@@ -908,12 +907,6 @@ def validate_lora_adapter_saving(self): | |
| ) | ||
| return self | ||
|
|
||
| @model_validator(mode="after") | ||
| def validate_weight_broadcast_type(self): | ||
| if self.weight_broadcast.type == "nccl" and self.max_async_level != 1: | ||
| raise ValueError("NCCL weight broadcast only works with async level 1") | ||
| return self | ||
|
|
||
| @model_validator(mode="after") | ||
| def validate_opt_and_fsdp_offload(self): | ||
| if self.optim.type == "muon" and self.model.fsdp_cpu_offload: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -62,9 +62,8 @@ def __init__( | |
| buffer: Buffer, | ||
| config: OrchestratorConfig, | ||
| max_inflight_rollouts: int, | ||
| max_async_level: int, | ||
| max_off_policy_steps: int, | ||
| strict_async_level: bool, | ||
| on_policy: bool, | ||
| tasks_per_minute: int | None, | ||
| enable_policy_updates: bool = True, | ||
| lora_name: str | None = None, | ||
|
|
@@ -81,9 +80,8 @@ def __init__( | |
| self.token_batch_size = config.token_batch_size | ||
| self.rollouts_per_example = config.rollouts_per_example | ||
| self.max_inflight_rollouts = max_inflight_rollouts | ||
| self.max_async_level = max_async_level | ||
| self.max_off_policy_steps = max_off_policy_steps | ||
| self.strict_async_level = strict_async_level | ||
| self.on_policy = on_policy | ||
| self.enable_policy_updates = enable_policy_updates | ||
| self.lora_name = lora_name | ||
| self.model_name = self.config.model.name | ||
|
|
@@ -274,17 +272,23 @@ async def update_policy_loop(self): | |
|
|
||
| def _compute_next_ckpt_step(self) -> int: | ||
| latest_ckpt_step = get_latest_ckpt_step(get_broadcast_dir(self.config.output_dir)) or 0 | ||
| async_away_ckpt_step = max(self.step - self.max_async_level, 0) | ||
| if self.strict_async_level: | ||
| return async_away_ckpt_step | ||
| return max(async_away_ckpt_step, latest_ckpt_step) | ||
| if self.on_policy: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. we still need the |
||
| return self.step | ||
| # Bound the orchestrator/trainer gap to max_off_policy_steps so the orchestrator can't | ||
| # race ahead indefinitely when the trainer is slow to broadcast. When latest_ckpt_step | ||
| # lags below this bound, we return the bound to trigger a wait path. | ||
| min_required_ckpt_step = max(self.step - self.max_off_policy_steps, 0) | ||
| return max(min_required_ckpt_step, latest_ckpt_step) | ||
|
cursor[bot] marked this conversation as resolved.
Outdated
|
||
|
|
||
| async def _apply_policy_update(self, next_ckpt_step: int) -> None: | ||
| async_away_ckpt_step = max(self.step - self.max_async_level, 0) | ||
| if next_ckpt_step == async_away_ckpt_step: | ||
| latest_ckpt_step = get_latest_ckpt_step(get_broadcast_dir(self.config.output_dir)) or 0 | ||
| if next_ckpt_step > latest_ckpt_step: | ||
| reason = ( | ||
| "on-policy mode, on_policy=true" if self.on_policy else f">{self.max_off_policy_steps} step(s) ahead" | ||
| ) | ||
| self.logger.info( | ||
| f"Orchestrator paused: waiting for trainer process to complete checkpoint {next_ckpt_step} " | ||
| f"(>{self.max_async_level} step(s) ahead). Training is progressing normally." | ||
| f"Orchestrator paused: waiting for trainer to complete checkpoint {next_ckpt_step} " | ||
| f"({reason}). Training is progressing normally." | ||
| ) | ||
| self.checkpoint_ready.clear() | ||
| wait_for_ckpt_start_time = time.perf_counter() | ||
|
|
@@ -379,8 +383,8 @@ async def generate_batch(self, step: int) -> list[vf.RolloutOutput]: | |
| if self.update_policy_task is not None: | ||
| await safe_cancel(self.update_policy_task) | ||
|
|
||
| # Manually check the async barrier before starting the step, then re-create the update policy loop | ||
| # This ensures that we respect max_async_level, while still listening for policy updates mid-step | ||
| # Manually check the async barrier before starting the step, then re-create the update policy loop. | ||
| # In on_policy mode this also blocks until the matching on-policy checkpoint is ready. | ||
| await self.maybe_update_policy() | ||
| self.update_policy_task = asyncio.create_task(self.update_policy_loop()) | ||
| else: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -94,13 +94,13 @@ def validate_shared_max_steps( | |
| ) | ||
|
|
||
|
|
||
| def validate_shared_max_async_level( | ||
| def validate_shared_on_policy( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. can remove |
||
| trainer: TrainerConfig, | ||
| orchestrator: OrchestratorConfig, | ||
| ) -> None: | ||
| if trainer.max_async_level != orchestrator.max_async_level: | ||
| if trainer.on_policy != orchestrator.on_policy: | ||
| raise ValueError( | ||
| f"Trainer max async level ({trainer.max_async_level}) and orchestrator max async level ({orchestrator.max_async_level}) are not the same. Please specify the same max async level for both." | ||
| f"Trainer on_policy ({trainer.on_policy}) and orchestrator on_policy ({orchestrator.on_policy}) do not match. Please specify the same value for both." | ||
| ) | ||
|
|
||
|
|
||
|
|
||


Uh oh!
There was an error while loading. Please reload this page.