-
Notifications
You must be signed in to change notification settings - Fork 535
Add wall-clock timeout support to MultiTurnEnv #1166
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,8 @@ | ||
| import asyncio | ||
| import logging | ||
| import time | ||
| from abc import abstractmethod | ||
| from contextlib import suppress | ||
| from typing import final | ||
|
|
||
| import verifiers as vf | ||
|
|
@@ -35,9 +37,15 @@ async def num_turns(self, state: State) -> int: | |
|
|
||
|
|
||
| class MultiTurnEnv(vf.Environment): | ||
| def __init__(self, max_turns: int = -1, **kwargs): | ||
| def __init__( | ||
| self, | ||
| max_turns: int = -1, | ||
| timeout_seconds: float | None = None, | ||
| **kwargs, | ||
| ): | ||
| super().__init__(**kwargs) | ||
| self.max_turns = max_turns | ||
| self.timeout_seconds = timeout_seconds | ||
| self.max_total_completion_tokens: int = -1 | ||
|
|
||
| self.add_rubric(MultiTurnMonitorRubric()) | ||
|
|
@@ -67,6 +75,15 @@ async def prompt_too_long(self, state: State) -> bool: | |
| async def max_turns_reached(self, state: State) -> bool: | ||
| return len(state["trajectory"]) >= self.max_turns and self.max_turns > 0 | ||
|
|
||
| @vf.stop | ||
| async def timeout_reached(self, state: State) -> bool: | ||
| if self.timeout_seconds is None: | ||
| return False | ||
| if time.time() - state["timing"]["start_time"] <= self.timeout_seconds: | ||
|
xeophon marked this conversation as resolved.
Outdated
|
||
| return False | ||
|
Comment on lines +79 to +84
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
This timeout check only runs when stop conditions are polled, but [the remainder of this review comment was truncated in this capture]. Useful? React with 👍 / 👎. |
||
| state["timed_out"] = True | ||
| return True | ||
|
xeophon marked this conversation as resolved.
Outdated
|
||
|
|
||
| @vf.stop | ||
| async def max_total_completion_tokens_reached(self, state: State) -> bool: | ||
| if self.max_total_completion_tokens <= 0: | ||
|
|
@@ -151,7 +168,11 @@ async def rollout( | |
| sampling_args: SamplingArgs | None = None, | ||
| ) -> State: | ||
| state = await self.init_state(input, client, model, sampling_args) | ||
| try: | ||
| rollout_task: asyncio.Task[State] | None = None | ||
|
|
||
| async def run_rollout_loop() -> State: | ||
| nonlocal state | ||
|
|
||
| try: | ||
| state = await self.setup_state(state) | ||
| except vf.Error as e: | ||
|
|
@@ -175,6 +196,32 @@ async def rollout( | |
| state["error"] = e | ||
| await self.render_completion(state) | ||
| return state | ||
|
|
||
| try: | ||
| if self.timeout_seconds is None: | ||
| return await run_rollout_loop() | ||
|
|
||
| rollout_task = asyncio.create_task(run_rollout_loop()) | ||
| done, _ = await asyncio.wait({rollout_task}, timeout=self.timeout_seconds) | ||
| if rollout_task in done: | ||
| return await rollout_task | ||
|
|
||
| rollout_task.cancel() | ||
|
xeophon marked this conversation as resolved.
Outdated
|
||
| with suppress(asyncio.CancelledError): | ||
| await rollout_task | ||
|
|
||
| state["timed_out"] = True | ||
| state["is_completed"] = True | ||
| state["is_truncated"] = True | ||
| state["stop_condition"] = "timeout_reached" | ||
|
xeophon marked this conversation as resolved.
Outdated
|
||
| await self._render_timing(state) | ||
| await self._cleanup(state) | ||
| await self.render_completion(state) | ||
| return state | ||
|
cursor[bot] marked this conversation as resolved.
|
||
| except asyncio.CancelledError: | ||
| if rollout_task is not None and not rollout_task.done(): | ||
| rollout_task.cancel() | ||
| with suppress(asyncio.CancelledError): | ||
| await rollout_task | ||
| await self._cleanup(state) | ||
| raise | ||
Uh oh!
There was an error while loading. Please reload this page.