41 changes: 41 additions & 0 deletions tests/test_environment.py
@@ -1,5 +1,6 @@
"""Tests for the base Environment class."""

import time
from unittest.mock import AsyncMock, Mock, patch

import pytest
@@ -185,6 +186,46 @@ def test_get_dataset(self, sample_dataset):
        subset = env.get_dataset(n=1)
        assert len(subset) == 1

    def test_get_eval_dataset_wraps_builder_errors(self):
        """Test eval dataset builder errors include environment context."""

        def failing_builder():
            raise RuntimeError("Dataset 'cais/hle' is gated")

        env = SimpleEnvironment(
            eval_dataset=failing_builder,
            env_id="hle",
            parser=Parser(),
            rubric=Rubric(),
        )

        with pytest.raises(
            RuntimeError,
            match="Failed to build evaluation dataset for environment 'hle': Dataset 'cais/hle' is gated",
        ):
            env.get_eval_dataset()

    def test_get_eval_dataset_timeout_raises_clear_error(self):
        """Test slow eval dataset builders fail with a timeout instead of hanging forever."""

        def slow_builder():
            time.sleep(0.2)
            return Dataset.from_dict({"question": ["q"], "answer": ["a"]})

        env = SimpleEnvironment(
            eval_dataset=slow_builder,
            env_id="hle",
            parser=Parser(),
            rubric=Rubric(),
            dataset_build_timeout_seconds=0.01,
        )

        with pytest.raises(
            RuntimeError,
            match="Building evaluation dataset for environment 'hle' timed out after 10ms",
        ):
            env.get_eval_dataset()

    @pytest.mark.asyncio
    async def test_get_model_response_chat(self, mock_client, make_input):
        """Test get_model_response with chat format."""
43 changes: 43 additions & 0 deletions tests/test_run_evaluation.py
@@ -0,0 +1,43 @@
from unittest.mock import AsyncMock, patch

import pytest

from verifiers.types import ClientConfig, EvalConfig
from verifiers.utils.eval_utils import run_evaluation


@pytest.mark.asyncio
async def test_run_evaluation_builds_dataset_before_starting_env_server():
    order: list[tuple[str, int]] = []

    class FakeEnv:
        def set_kwargs(self, **kwargs):
            return None

        def get_eval_dataset(self, n: int = -1, seed=None):
            order.append(("get_eval_dataset", n))
            raise RuntimeError("dataset unavailable")

        start_server = AsyncMock()
        stop_server = AsyncMock()

    fake_env = FakeEnv()
    config = EvalConfig(
        env_id="hle",
        env_args={},
        env_dir_path="./environments",
        model="openai/gpt-4.1-mini",
        client_config=ClientConfig(),
        sampling_args={},
        num_examples=10,
        rollouts_per_example=3,
        max_concurrent=1,
    )

    with patch("verifiers.utils.eval_utils.vf.load_environment", return_value=fake_env):
        with pytest.raises(RuntimeError, match="dataset unavailable"):
            await run_evaluation(config)

    assert order == [("get_eval_dataset", 1)]
    fake_env.start_server.assert_not_awaited()
    fake_env.stop_server.assert_not_awaited()
94 changes: 92 additions & 2 deletions verifiers/envs/environment.py
@@ -5,7 +5,10 @@
import json
import logging
import multiprocessing as mp
import os
import queue
import signal
import threading
import time
import uuid
import warnings
@@ -71,6 +74,7 @@
    with_sem,
)
from verifiers.utils.error_utils import ErrorChain
from verifiers.utils.logging_utils import print_time
from verifiers.utils.message_utils import normalize_messages
from verifiers.utils.save_utils import (
    GenerateOutputsBuilder,
@@ -87,6 +91,12 @@
from verifiers.workers.client.env_client import EnvClient

_MESSAGE_TYPE_UNSET = object()
_DATASET_BUILD_TIMEOUT_ENV_VAR = "VF_DATASET_BUILD_TIMEOUT"
_DEFAULT_DATASET_BUILD_TIMEOUT_SECONDS = 300.0


class DatasetBuildError(RuntimeError):
    """Raised when building an environment dataset fails or times out."""


class Environment(ABC):
@@ -112,6 +122,7 @@ def __init__(
        max_seq_len: int | None = None,
        score_rollouts: bool = True,
        pass_threshold: float = 0.5,
        dataset_build_timeout_seconds: float | None = None,
        **kwargs,
    ):
        if message_type is _MESSAGE_TYPE_UNSET:
@@ -148,6 +159,9 @@ def __init__(

        self.set_score_rollouts(score_rollouts)
        self.pass_threshold = pass_threshold
        self.dataset_build_timeout_seconds = self._resolve_dataset_build_timeout(
            dataset_build_timeout_seconds
        )

        self.env_client: EnvClient | None = None
        self.env_server_process: BaseProcess | None = None
@@ -393,13 +407,86 @@ def _format_dataset_source(self, dataset: Dataset) -> Dataset:
            map_kwargs=self.map_kwargs,
        )

    def _resolve_dataset_build_timeout(
        self, dataset_build_timeout_seconds: float | None
    ) -> float | None:
        if dataset_build_timeout_seconds is not None:
            return (
                None
                if dataset_build_timeout_seconds <= 0
                else dataset_build_timeout_seconds
            )

        raw_timeout = os.getenv(_DATASET_BUILD_TIMEOUT_ENV_VAR)
        if raw_timeout is not None:
            try:
                parsed_timeout = float(raw_timeout)
            except ValueError:
                self.logger.warning(
                    "Invalid %s=%r; using default %.0fs",
                    _DATASET_BUILD_TIMEOUT_ENV_VAR,
                    raw_timeout,
                    _DEFAULT_DATASET_BUILD_TIMEOUT_SECONDS,
                )
            else:
                return None if parsed_timeout <= 0 else parsed_timeout

        return _DEFAULT_DATASET_BUILD_TIMEOUT_SECONDS
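
Resolution precedence, as implemented above: an explicit constructor argument wins, then VF_DATASET_BUILD_TIMEOUT, then the 300s default; any non-positive value disables the timeout (resolves to None), and an unparseable env var falls back to the default. A self-contained sketch of the same rules, with illustrative names:

import os

_DEFAULT = 300.0  # mirrors _DEFAULT_DATASET_BUILD_TIMEOUT_SECONDS

def resolve_timeout(explicit: float | None) -> float | None:
    # A None result means "no timeout"; this mirrors the method above.
    if explicit is not None:
        return None if explicit <= 0 else explicit
    raw = os.getenv("VF_DATASET_BUILD_TIMEOUT")
    if raw is not None:
        try:
            parsed = float(raw)
        except ValueError:
            return _DEFAULT  # invalid value: fall back to the default
        else:
            return None if parsed <= 0 else parsed
    return _DEFAULT

assert resolve_timeout(5.0) == 5.0   # constructor argument wins
assert resolve_timeout(0.0) is None  # non-positive disables the timeout
os.environ["VF_DATASET_BUILD_TIMEOUT"] = "-1"
assert resolve_timeout(None) is None  # the env var can disable it too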

    def _build_dataset_from_source(
        self,
        source: DatasetBuilder,
        *,
        source_name: str,
    ) -> Dataset:
        timeout_seconds = self.dataset_build_timeout_seconds
        if timeout_seconds is None:
            return source()

        result_queue: queue.SimpleQueue[Dataset | BaseException] = queue.SimpleQueue()

        def build_dataset() -> None:
            try:
                result_queue.put(source())
            except BaseException as exc:  # pragma: no cover - exercised via caller
                result_queue.put(exc)

        builder_thread = threading.Thread(
            target=build_dataset,
            name=f"dataset-builder-{source_name.replace(' ', '-')}",
            daemon=True,
        )
        builder_thread.start()
        builder_thread.join(timeout_seconds)

        if builder_thread.is_alive():
            raise DatasetBuildError(
                f"Building {source_name} for environment '{self.env_id or self.__class__.__name__}' "
                f"timed out after {print_time(timeout_seconds)}. "
                f"Check dataset access and network reachability, or increase "
                f"{_DATASET_BUILD_TIMEOUT_ENV_VAR}."
            )

        result = result_queue.get()
        if isinstance(result, BaseException):
            if isinstance(result, DatasetBuildError):
                raise result
            raise DatasetBuildError(
                f"Failed to build {source_name} for environment '{self.env_id or self.__class__.__name__}': {result}"
            ) from result

        return result
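
One consequence of this design worth noting: on timeout the builder thread is abandoned rather than cancelled (Python threads cannot be forcibly killed), so a hung builder keeps running in the background; daemon=True only guarantees it will not block interpreter shutdown. End to end, the behavior looks roughly like this sketch (SimpleEnvironment, Parser, and Rubric are the test helpers used above; the exact constructor arguments are assumptions):

import time
from datasets import Dataset

def hung_builder():
    time.sleep(3600)  # e.g. a stalled network call to a gated dataset
    return Dataset.from_dict({"question": ["q"], "answer": ["a"]})

# SimpleEnvironment, Parser, Rubric as in the tests above (imports assumed).
env = SimpleEnvironment(
    eval_dataset=hung_builder,
    env_id="demo",
    parser=Parser(),
    rubric=Rubric(),
    dataset_build_timeout_seconds=1.0,
)

# Raises DatasetBuildError after ~1s; the daemon thread lingers until exit.
env.get_eval_dataset()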

    def build_dataset(self) -> Dataset | None:
        """Build and cache the training dataset from source if needed."""
        if self.dataset is not None:
            return self.dataset
        if self.dataset_source is None:
            return None
-       built = self.dataset_source()
+       built = self._build_dataset_from_source(
+           self.dataset_source,
+           source_name="training dataset",
+       )
        self.dataset = self._format_dataset_source(built)
        return self.dataset

@@ -409,7 +496,10 @@ def build_eval_dataset(self) -> Dataset | None:
            return self.eval_dataset
        if self.eval_dataset_source is None:
            return None
-       built = self.eval_dataset_source()
+       built = self._build_dataset_from_source(
+           self.eval_dataset_source,
+           source_name="evaluation dataset",
+       )
        self.eval_dataset = self._format_dataset_source(built)
        return self.eval_dataset
4 changes: 4 additions & 0 deletions verifiers/utils/eval_utils.py
@@ -734,6 +734,10 @@ async def run_evaluation(

    results_path = config.resume_path or get_eval_results_path(config)

    logger.info(f"Preparing evaluation dataset for {config.env_id}")
    vf_env.get_eval_dataset(n=1)
    logger.info(f"Evaluation dataset ready for {config.env_id}")

    try:
        if not config.disable_env_server:
            extra_env_kwargs = dict(config.extra_env_kwargs)
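
The n=1 probe builds the evaluation dataset (and caches it via build_eval_dataset) before any env server process is started, so gated-dataset and timeout failures surface immediately rather than after server spin-up; this is exactly the ordering test_run_evaluation_builds_dataset_before_starting_env_server pins down.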