Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion assets/lab/environments/AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ class MyGameEnv(vf.MultiTurnEnv):
return state.get("lives", 1) <= 0
```

`MultiTurnEnv` includes built-in stop conditions for errors, prompt length limits, `max_turns`, and `max_total_completion_tokens` by default.
`MultiTurnEnv` includes built-in stop conditions for errors, prompt length limits, `max_turns`, `timeout_seconds`, and `max_total_completion_tokens` by default.

Execution order can be controlled with `priority` (higher runs first). This is useful for checking cheap conditions before expensive ones:

Expand Down
2 changes: 1 addition & 1 deletion docs/environments.md
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ class MyGameEnv(vf.MultiTurnEnv):
return state.get("lives", 1) <= 0
```

`MultiTurnEnv` includes built-in stop conditions for errors, prompt length limits, `max_turns`, and `max_total_completion_tokens` by default.
`MultiTurnEnv` includes built-in stop conditions for errors, prompt length limits, `max_turns`, `timeout_seconds`, and `max_total_completion_tokens` by default.

Execution order can be controlled with `priority` (higher runs first). This is useful for checking cheap conditions before expensive ones:

Expand Down
4 changes: 2 additions & 2 deletions docs/evaluation.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ The `--env-args` flag passes arguments to your `load_environment()` function:
prime eval run my-env -a '{"difficulty": "hard", "num_examples": 100}'
```

The `--extra-env-kwargs` flag passes arguments directly to the environment constructor, useful for overriding defaults like `max_turns` which may not be exposed via `load_environment()`:
The `--extra-env-kwargs` flag passes arguments directly to the environment constructor, useful for overriding defaults like `max_turns` or setting rollout limits like `timeout_seconds` which may not be exposed via `load_environment()`:

```bash
prime eval run my-env -x '{"max_turns": 20}'
prime eval run my-env -x '{"max_turns": 20, "timeout_seconds": 600}'
Comment thread
cursor[bot] marked this conversation as resolved.
```

#### Executor autoscaling
Expand Down
9 changes: 7 additions & 2 deletions docs/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,12 @@ Single-response Q&A tasks. Inherits from `Environment`.

```python
class MultiTurnEnv(Environment):
def __init__(self, max_turns: int = -1, **kwargs): ...
def __init__(
self,
max_turns: int = -1,
timeout_seconds: float | None = None,
**kwargs,
): ...
```

Multi-turn interactions. Subclasses must implement `env_response`.
Expand All @@ -339,7 +344,7 @@ async def env_response(self, messages: Messages, state: State, **kwargs) -> Mess
"""Generate environment feedback after model turn."""
```

**Built-in stop conditions:** `has_error`, `prompt_too_long`, `max_turns_reached`, `max_total_completion_tokens_reached`, `has_final_env_response`
**Built-in stop conditions:** `has_error`, `prompt_too_long`, `max_turns_reached`, `timeout_reached`, `max_total_completion_tokens_reached`, `has_final_env_response`

**Hooks:**

Expand Down
2 changes: 1 addition & 1 deletion environments/AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ class MyGameEnv(vf.MultiTurnEnv):
return state.get("lives", 1) <= 0
```

`MultiTurnEnv` includes built-in stop conditions for errors, prompt length limits, `max_turns`, and `max_total_completion_tokens` by default.
`MultiTurnEnv` includes built-in stop conditions for errors, prompt length limits, `max_turns`, `timeout_seconds`, and `max_total_completion_tokens` by default.

Execution order can be controlled with `priority` (higher runs first). This is useful for checking cheap conditions before expensive ones:

Expand Down
25 changes: 25 additions & 0 deletions tests/test_eval_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,20 @@ def test_cli_temperature_not_added_when_none(monkeypatch, run_cli):
assert "temperature" not in sa


def test_cli_extra_env_kwargs_support_timeout_seconds(monkeypatch, run_cli):
captured = run_cli(
monkeypatch,
{
"extra_env_kwargs": {"timeout_seconds": 30, "foo": "bar"},
},
)

assert captured["configs"][0].extra_env_kwargs == {
"timeout_seconds": 30,
"foo": "bar",
}


def test_cli_headers_table_and_list_merge(monkeypatch, run_cli):
captured = run_cli(
monkeypatch,
Expand Down Expand Up @@ -872,6 +886,17 @@ def test_load_toml_config_global_values_with_per_eval_override():
assert result[1]["num_examples"] == 50 # per-eval override


def test_load_toml_config_with_extra_env_kwargs():
with tempfile.NamedTemporaryFile(suffix=".toml", delete=False, mode="w") as f:
f.write(
'[[eval]]\nenv_id = "env1"\n[eval.extra_env_kwargs]\ntimeout_seconds = 600\n'
)
f.flush()
result = load_toml_config(Path(f.name))

assert result[0]["extra_env_kwargs"] == {"timeout_seconds": 600}


def test_load_toml_config_invalid_global_field():
"""Invalid global field raises ValueError."""
with tempfile.NamedTemporaryFile(suffix=".toml", delete=False, mode="w") as f:
Expand Down
67 changes: 67 additions & 0 deletions tests/test_multiturn_env.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
"""Tests for the MultiTurnEnv class."""

import asyncio
import time

import pytest
from datasets import Dataset

Expand All @@ -12,6 +15,7 @@ class TestMultiTurnEnv:
def test_multiturn_env_initialization(self, mock_multiturn_env):
"""Test MultiTurnEnv initialization."""
assert mock_multiturn_env.max_turns == 3
assert mock_multiturn_env.timeout_seconds is None
assert mock_multiturn_env.message_type == "chat" # Default from parent

def test_multiturn_env_default_max_turns(self, mock_client, sample_chat_dataset):
Expand All @@ -26,6 +30,31 @@ def test_multiturn_env_default_max_turns(self, mock_client, sample_chat_dataset)
rubric=Rubric(),
)
assert env.max_turns == -1 # Default value
assert env.timeout_seconds is None

@pytest.mark.asyncio
async def test_timeout_reached_stop_condition(
self, mock_client, sample_chat_dataset
):
"""Test the timeout_reached stop condition."""
from tests.conftest import SimpleMultiTurnEnv

env = SimpleMultiTurnEnv(
client=mock_client,
model="test-model",
dataset=sample_chat_dataset,
parser=Parser(),
rubric=Rubric(),
timeout_seconds=10.0,
)

state: State = {"timing": {"start_time": time.time()}}
assert await env.timeout_reached(state) is False
assert state.get("timed_out") is None

state = {"timing": {"start_time": time.time() - 20}}
assert await env.timeout_reached(state) is True
assert state["timed_out"] is True

@pytest.mark.asyncio
async def test_basic_multiturn_rollout(self, mock_multiturn_env, make_input):
Expand Down Expand Up @@ -103,6 +132,44 @@ async def test_max_turns_limiting(self, mock_multiturn_env_max_turns, make_input
assert completion[1]["role"] == "user"
assert completion[2]["role"] == "assistant"

@pytest.mark.asyncio
async def test_timeout_seconds_limits_rollout(
self, mock_client, sample_chat_dataset, make_input
):
"""Test that rollout stops when the wall-clock timeout is reached."""

class SlowMultiTurnEnv(MultiTurnEnv):
async def env_response(self, messages, state, **kwargs): # type: ignore[override]
await asyncio.sleep(0.05)
return [{"role": "user", "content": "Continue"}]

env = SlowMultiTurnEnv(
client=mock_client,
model="test-model",
dataset=sample_chat_dataset,
parser=Parser(),
rubric=Rubric(),
timeout_seconds=0.01,
)
mock_client.set_default_response("Still going")

prompt = [{"role": "user", "content": "Start conversation"}]
state = await env.rollout(
input=make_input(prompt=prompt, answer="target_answer"),
client=mock_client,
model="test-model",
)

assert len(state["trajectory"]) == 1
assert state["timed_out"] is True
assert state["is_completed"] is True
assert state["is_truncated"] is True
assert state["stop_condition"] == "timeout_reached"
completion = state["completion"]
assert len(completion) == 1
assert completion[0]["role"] == "assistant"
assert completion[0]["content"] == "Still going"

@pytest.mark.asyncio
async def test_override_is_completed_respects_max_turns(
self, mock_client, sample_chat_dataset, make_input
Expand Down
5 changes: 4 additions & 1 deletion verifiers/envs/experimental/cli_agent_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,14 +269,17 @@ async def get_docker_image(self, state: State) -> str:

def get_sandbox_resources(self, state: State) -> dict[str, Any]:
"""Get sandbox resource allocation. Override for per-instance resources."""
timeout_seconds = self.timeout_seconds
if timeout_seconds is None:
timeout_seconds = 0.0
return {
"cpu_cores": self.cpu_cores,
"memory_gb": self.memory_gb,
"disk_size_gb": self.disk_size_gb,
"gpu_count": self.gpu_count,
"gpu_type": None,
"vm": self.gpu_count > 0,
"timeout_minutes": math.ceil(self.timeout_seconds / 60),
"timeout_minutes": math.ceil(timeout_seconds / 60),
Comment thread
cursor[bot] marked this conversation as resolved.
Outdated
}

# Keys set by build_env_vars that subclasses must not override.
Expand Down
51 changes: 49 additions & 2 deletions verifiers/envs/multiturn_env.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
import logging
import time
from abc import abstractmethod
from contextlib import suppress
from typing import final

import verifiers as vf
Expand Down Expand Up @@ -35,9 +37,15 @@ async def num_turns(self, state: State) -> int:


class MultiTurnEnv(vf.Environment):
def __init__(self, max_turns: int = -1, **kwargs):
def __init__(
self,
max_turns: int = -1,
timeout_seconds: float | None = None,
**kwargs,
):
super().__init__(**kwargs)
self.max_turns = max_turns
self.timeout_seconds = timeout_seconds
self.max_total_completion_tokens: int = -1

self.add_rubric(MultiTurnMonitorRubric())
Expand Down Expand Up @@ -67,6 +75,15 @@ async def prompt_too_long(self, state: State) -> bool:
async def max_turns_reached(self, state: State) -> bool:
return len(state["trajectory"]) >= self.max_turns and self.max_turns > 0

@vf.stop
async def timeout_reached(self, state: State) -> bool:
if self.timeout_seconds is None:
return False
if time.time() - state["timing"]["start_time"] <= self.timeout_seconds:
Comment thread
xeophon marked this conversation as resolved.
Outdated
return False
Comment on lines +79 to +84
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Enforce timeout during in-flight rollout operations

This timeout check only runs when stop conditions are polled, but rollout() can spend unbounded time awaiting setup_state, get_prompt_messages, or get_model_response between those polls. As a result, a hung/slow setup or model call can exceed timeout_seconds by a large margin (or until an external timeout fires), so the new wall-clock timeout does not reliably cap rollout runtime.

Useful? React with 👍 / 👎.

state["timed_out"] = True
return True
Comment thread
xeophon marked this conversation as resolved.
Outdated

@vf.stop
async def max_total_completion_tokens_reached(self, state: State) -> bool:
if self.max_total_completion_tokens <= 0:
Expand Down Expand Up @@ -151,7 +168,11 @@ async def rollout(
sampling_args: SamplingArgs | None = None,
) -> State:
state = await self.init_state(input, client, model, sampling_args)
try:
rollout_task: asyncio.Task[State] | None = None

async def run_rollout_loop() -> State:
nonlocal state

try:
state = await self.setup_state(state)
except vf.Error as e:
Expand All @@ -175,6 +196,32 @@ async def rollout(
state["error"] = e
await self.render_completion(state)
return state

try:
if self.timeout_seconds is None:
return await run_rollout_loop()

rollout_task = asyncio.create_task(run_rollout_loop())
done, _ = await asyncio.wait({rollout_task}, timeout=self.timeout_seconds)
if rollout_task in done:
return await rollout_task

rollout_task.cancel()
Comment thread
xeophon marked this conversation as resolved.
Outdated
with suppress(asyncio.CancelledError):
await rollout_task

state["timed_out"] = True
state["is_completed"] = True
state["is_truncated"] = True
state["stop_condition"] = "timeout_reached"
Comment thread
xeophon marked this conversation as resolved.
Outdated
await self._render_timing(state)
await self._cleanup(state)
await self.render_completion(state)
return state
Comment thread
cursor[bot] marked this conversation as resolved.
except asyncio.CancelledError:
if rollout_task is not None and not rollout_task.done():
rollout_task.cancel()
with suppress(asyncio.CancelledError):
await rollout_task
await self._cleanup(state)
raise
Loading