-
Notifications
You must be signed in to change notification settings - Fork 262
feat: replace max_async_level with on_policy debug flag #2328
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
19b7c55
5449932
56d2a90
51a3265
7c7817b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,4 @@ | ||
| max_steps = 5 | ||
| max_async_level = 5 | ||
| batch_size = 16 | ||
|
|
||
| [sampling] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,4 @@ | ||
| max_steps = 5 | ||
| max_async_level = 5 | ||
|
|
||
| [data.fake] | ||
| batch_size = 2 | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1001,25 +1001,17 @@ class OrchestratorConfig(BaseConfig): | |
| ), | ||
| ] = 8 | ||
|
|
||
| max_async_level: Annotated[ | ||
| int, | ||
| Field( | ||
| ge=0, | ||
| description="Maximum number of steps the inference can be ahead of training. If 0, will degenerate to synchronous on-policy RL. If >=1, training and inference will be overlapped.", | ||
| ), | ||
| ] = 1 | ||
|
|
||
| strict_async_level: Annotated[ | ||
| on_policy: Annotated[ | ||
| bool, | ||
| Field( | ||
| description="Whether to strictly enforce the max async level. If True, will always ensure that the policy used for generating rollouts is exactly `max_async_level` steps ahead of training. If False, any policy that is at most `max_async_level` steps ahead of training is allowed, i.e. we always use the latest available policy.", | ||
| description="Debug-only flag to force fully synchronous on-policy RL. If True, the orchestrator waits for the trainer to produce a checkpoint at the current step before generating rollouts. This is significantly slower than async training and is intended only for testing/debugging on-policy behavior.", | ||
| ), | ||
| ] = False | ||
|
|
||
| bench: Annotated[ | ||
| bool, | ||
| Field( | ||
| description="Whether to run in benchmark mode. It will automatically set the maximum number of steps to run to 5, max async level to ~infinity and disable W&B.", | ||
| description="Whether to run in benchmark mode. It will automatically set the maximum number of steps to run to 5 and disable W&B.", | ||
| ), | ||
| ] = False | ||
|
|
||
|
|
@@ -1086,13 +1078,6 @@ def validate_unique_filter_types(self): | |
| raise ValueError(f"Duplicate filter types: {types}. Each filter type may only appear once.") | ||
| return self | ||
|
|
||
| @model_validator(mode="after") | ||
|
mikasenghaas marked this conversation as resolved.
|
||
| def nccl_max_async_level(self): | ||
| if self.weight_broadcast.type == "nccl": | ||
| if not self.max_async_level == 1: | ||
| raise ValueError("max_async_level must be 1 for NCCL broadcast") | ||
| return self | ||
|
|
||
| @model_validator(mode="after") | ||
| def resolve_batching(self): | ||
| has_rollout_batch = self.batch_size is not None | ||
|
|
@@ -1136,7 +1121,6 @@ def resolve_batching(self): | |
| def auto_setup_bench(self): | ||
| if self.bench: | ||
| self.max_steps = 4 # Run for 1 warmup step + 3 evaluation steps | ||
| self.max_async_level = int(1e9) # Never wait for RL weight checkpoints | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Bench mode lost the "never wait" async override — Medium Severity. Reviewed by Cursor Bugbot for commit 7c7817b. Configure here.
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Oh, this is valid, I think. If we hardcode async level 1, then the benchmark mode will stop at the async barrier of 1. Maybe we can circumvent this by setting |
||
|
|
||
| # Disable evaluation | ||
| self.eval = None | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -62,9 +62,8 @@ def __init__( | |
| buffer: Buffer, | ||
| config: OrchestratorConfig, | ||
| max_inflight_rollouts: int, | ||
| max_async_level: int, | ||
| max_off_policy_steps: int, | ||
| strict_async_level: bool, | ||
| on_policy: bool, | ||
| tasks_per_minute: int | None, | ||
| enable_policy_updates: bool = True, | ||
| lora_name: str | None = None, | ||
|
|
@@ -81,9 +80,8 @@ def __init__( | |
| self.token_batch_size = config.token_batch_size | ||
| self.rollouts_per_example = config.rollouts_per_example | ||
| self.max_inflight_rollouts = max_inflight_rollouts | ||
| self.max_async_level = max_async_level | ||
| self.max_off_policy_steps = max_off_policy_steps | ||
| self.strict_async_level = strict_async_level | ||
| self.on_policy = on_policy | ||
| self.enable_policy_updates = enable_policy_updates | ||
| self.lora_name = lora_name | ||
| self.model_name = self.config.model.name | ||
|
|
@@ -274,17 +272,19 @@ async def update_policy_loop(self): | |
|
|
||
| def _compute_next_ckpt_step(self) -> int: | ||
| latest_ckpt_step = get_latest_ckpt_step(get_broadcast_dir(self.config.output_dir)) or 0 | ||
| async_away_ckpt_step = max(self.step - self.max_async_level, 0) | ||
| if self.strict_async_level: | ||
| return async_away_ckpt_step | ||
| if self.on_policy: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We still need the |
||
| return self.step | ||
| # Cap the orchestrator at most 1 step ahead of the trainer so it doesn't race away. | ||
| async_away_ckpt_step = max(self.step - 1, 0) | ||
| return max(async_away_ckpt_step, latest_ckpt_step) | ||
|
|
||
| async def _apply_policy_update(self, next_ckpt_step: int) -> None: | ||
| async_away_ckpt_step = max(self.step - self.max_async_level, 0) | ||
| if next_ckpt_step == async_away_ckpt_step: | ||
| latest_ckpt_step = get_latest_ckpt_step(get_broadcast_dir(self.config.output_dir)) or 0 | ||
| if next_ckpt_step > latest_ckpt_step: | ||
| reason = "on-policy mode, on_policy=true" if self.on_policy else "1 step ahead" | ||
| self.logger.info( | ||
| f"Orchestrator paused: waiting for trainer process to complete checkpoint {next_ckpt_step} " | ||
| f"(>{self.max_async_level} step(s) ahead). Training is progressing normally." | ||
| f"Orchestrator paused: waiting for trainer to complete checkpoint {next_ckpt_step} " | ||
| f"({reason}). Training is progressing normally." | ||
| ) | ||
| self.checkpoint_ready.clear() | ||
| wait_for_ckpt_start_time = time.perf_counter() | ||
|
|
@@ -379,8 +379,8 @@ async def generate_batch(self, step: int) -> list[vf.RolloutOutput]: | |
| if self.update_policy_task is not None: | ||
| await safe_cancel(self.update_policy_task) | ||
|
|
||
| # Manually check the async barrier before starting the step, then re-create the update policy loop | ||
| # This ensures that we respect max_async_level, while still listening for policy updates mid-step | ||
| # Manually check the async barrier before starting the step, then re-create the update policy loop. | ||
| # In on_policy mode this also blocks until the matching on-policy checkpoint is ready. | ||
| await self.maybe_update_policy() | ||
| self.update_policy_task = asyncio.create_task(self.update_policy_loop()) | ||
| else: | ||
|
|
||


Uh oh!
There was an error while loading. Please reload this page.