8 changes: 6 additions & 2 deletions assets/lab/environments/AGENTS.md
@@ -576,7 +576,7 @@ class MyGameEnv(vf.MultiTurnEnv):
return state.get("lives", 1) <= 0
```

`MultiTurnEnv` includes built-in stop conditions for errors, prompt length limits, and `max_turns` by default.
`MultiTurnEnv` includes built-in stop conditions for errors, prompt length limits, `max_turns`, and `max_total_completion_tokens` by default.

Execution order can be controlled with `priority` (higher runs first). This is useful for checking cheap conditions before expensive ones:
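For illustration, the ordering can be sketched standalone; the decorator name, `priority` attribute, and `should_stop` driver below are assumptions for the sketch, not the actual verifiers API:

```python
# Standalone sketch of priority-ordered stop conditions; names are
# illustrative stand-ins, not the real verifiers decorator.
def stop_condition(priority: int = 0):
    def wrap(fn):
        fn.priority = priority
        return fn
    return wrap

@stop_condition(priority=10)  # cheap dict lookup: check first
def out_of_lives(state: dict) -> bool:
    return state.get("lives", 1) <= 0

@stop_condition(priority=0)   # stand-in for an expensive check (e.g. a judge call)
def judge_says_done(state: dict) -> bool:
    return state.get("judge_verdict") == "done"

def should_stop(state: dict, conditions: list) -> bool:
    # Higher priority runs first; short-circuit on the first hit, so the
    # expensive condition is never evaluated once a cheap one fires.
    for cond in sorted(conditions, key=lambda f: f.priority, reverse=True):
        if cond(state):
            return True
    return False
```

With `lives == 0`, only the cheap condition ever runs.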

@@ -891,7 +891,11 @@ These require additional dependencies installed via extras (e.g., `uv add 'verif
Newer and more experimental environment classes include:

- **`GymEnv`** — universal runner for Gym-compatible environments (OpenAI Gym / Gymnasium API)
- **`CliAgentEnv`** — runs custom agent code inside sandboxes, intercepting API requests. Accepts sandbox configuration parameters including `docker_image`, `cpu_cores`, `memory_gb`, `disk_size_gb`, `gpu_count`, `timeout_minutes`, `environment_vars`, and `labels` for sandbox categorization. Also accepts retry tuning (like `max_retries`) and connection pooling ( like `sandbox_client_max_workers`) parameters via `SandboxMixin`
- **`CliAgentEnv`** — runs custom agent code inside sandboxes, intercepting API requests. Accepts sandbox configuration parameters including `docker_image`, `cpu_cores`, `memory_gb`, `disk_size_gb`, `gpu_count`, `gpu_type`, `timeout_minutes`, `environment_vars`, and `labels` for sandbox categorization. Also accepts retry tuning (like `max_retries`) and connection pooling (like `sandbox_client_max_workers`) parameters via `SandboxMixin`. Subclasses can override `get_sandbox_resources(state)` for per-instance resource allocation and `build_env_vars(state)` for custom environment variables (`PROTECTED_ENV_VARS` cannot be overridden). VMs are auto-enabled when `gpu_count > 0`
- **`ComposableEnv`** — `CliAgentEnv` subclass that separates *what to solve* (`TaskSet`) from *how to solve it* (`Harness`). Wire a task collection and an agent config together with zero subclassing. Delegates sandbox spec, instruction, setup, and env vars to the `TaskSet`; install script, run command, and system prompt to the `Harness`. Scoring is owned by per-taskset rubrics
- **`TaskSet`** / **`SandboxTaskSet`** — define task collections. `SandboxTaskSet` adds `SandboxSpec` (image, CPU, memory, GPU, timeout) per instance, a `setup(state)` hook, and `validate_instance(state)` for gold-patch validation. Key methods: `get_instruction(info)`, `get_rubric()`, `get_sandbox_spec(info)`, `get_env_vars()`. Includes `validate(n, concurrency)` for bulk validation and `filter()`/`take()` combinators
- **`Harness`** — agent-side config dataclass: `install_script`, `run_command`, `system_prompt`, `system_prompt_path`, `instruction_path`, `log_path`
- **`SandboxSpec`** — per-instance sandbox requirements: `image`, `cpu_cores`, `memory_gb`, `disk_size_gb`, `gpu_count`, `gpu_type`, `timeout_minutes`
- **`HarborEnv`** — loads Harbor-format agent benchmark tasks
- **`RLMEnv`** — implements [Recursive Language Models](https://alexzhang13.github.io/blog/2025/rlm/) for unbounded context processing via REPL-based decomposition and recursive sub-LLM calls
- **`OpenCodeEnv`** — runs [OpenCode](https://opencode.ai) CLI agents inside sandboxes with API call interception
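The `ComposableEnv` split can be pictured with stand-in dataclasses; these are illustrative stubs (names and fields beyond those listed above are assumptions), not the real verifiers classes:

```python
from dataclasses import dataclass, field

# Stubs for the real classes; the actual verifiers APIs carry much more
# (rubrics, sandbox specs, setup hooks, validation combinators).
@dataclass
class Harness:
    install_script: str
    run_command: str
    system_prompt: str = "You are a coding agent."

@dataclass
class TaskSet:
    instances: list = field(default_factory=list)

    def get_instruction(self, info: dict) -> str:
        return info["instruction"]

def compose(tasks: TaskSet, harness: Harness, info: dict) -> dict:
    # The delegation ComposableEnv performs: the task side supplies what
    # to solve, the agent side supplies how to run.
    return {
        "instruction": tasks.get_instruction(info),
        "install": harness.install_script,
        "run": harness.run_command,
        "system_prompt": harness.system_prompt,
    }

spec = compose(
    TaskSet(),
    Harness("pip install mytool", "mytool solve"),
    {"instruction": "Fix the failing test"},
)
```

Swapping the `Harness` changes the agent without touching the tasks, and vice versa; that is the zero-subclassing wiring described above.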
8 changes: 6 additions & 2 deletions environments/AGENTS.md
@@ -576,7 +576,7 @@ class MyGameEnv(vf.MultiTurnEnv):
return state.get("lives", 1) <= 0
```

`MultiTurnEnv` includes built-in stop conditions for errors, prompt length limits, and `max_turns` by default.
`MultiTurnEnv` includes built-in stop conditions for errors, prompt length limits, `max_turns`, and `max_total_completion_tokens` by default.

Execution order can be controlled with `priority` (higher runs first). This is useful for checking cheap conditions before expensive ones:

@@ -891,7 +891,11 @@ These require additional dependencies installed via extras (e.g., `uv add 'verif
Newer and more experimental environment classes include:

- **`GymEnv`** — universal runner for Gym-compatible environments (OpenAI Gym / Gymnasium API)
- **`CliAgentEnv`** — runs custom agent code inside sandboxes, intercepting API requests. Accepts sandbox configuration parameters including `docker_image`, `cpu_cores`, `memory_gb`, `disk_size_gb`, `gpu_count`, `timeout_minutes`, `environment_vars`, and `labels` for sandbox categorization. Also accepts retry tuning (like `max_retries`) and connection pooling ( like `sandbox_client_max_workers`) parameters via `SandboxMixin`
- **`CliAgentEnv`** — runs custom agent code inside sandboxes, intercepting API requests. Accepts sandbox configuration parameters including `docker_image`, `cpu_cores`, `memory_gb`, `disk_size_gb`, `gpu_count`, `gpu_type`, `timeout_minutes`, `environment_vars`, and `labels` for sandbox categorization. Also accepts retry tuning (like `max_retries`) and connection pooling (like `sandbox_client_max_workers`) parameters via `SandboxMixin`. Subclasses can override `get_sandbox_resources(state)` for per-instance resource allocation and `build_env_vars(state)` for custom environment variables (`PROTECTED_ENV_VARS` cannot be overridden). VMs are auto-enabled when `gpu_count > 0`
- **`ComposableEnv`** — `CliAgentEnv` subclass that separates *what to solve* (`TaskSet`) from *how to solve it* (`Harness`). Wire a task collection and an agent config together with zero subclassing. Delegates sandbox spec, instruction, setup, and env vars to the `TaskSet`; install script, run command, and system prompt to the `Harness`. Scoring is owned by per-taskset rubrics
- **`TaskSet`** / **`SandboxTaskSet`** — define task collections. `SandboxTaskSet` adds `SandboxSpec` (image, CPU, memory, GPU, timeout) per instance, a `setup(state)` hook, and `validate_instance(state)` for gold-patch validation. Key methods: `get_instruction(info)`, `get_rubric()`, `get_sandbox_spec(info)`, `get_env_vars()`. Includes `validate(n, concurrency)` for bulk validation and `filter()`/`take()` combinators
- **`Harness`** — agent-side config dataclass: `install_script`, `run_command`, `system_prompt`, `system_prompt_path`, `instruction_path`, `log_path`
- **`SandboxSpec`** — per-instance sandbox requirements: `image`, `cpu_cores`, `memory_gb`, `disk_size_gb`, `gpu_count`, `gpu_type`, `timeout_minutes`
- **`HarborEnv`** — loads Harbor-format agent benchmark tasks
- **`RLMEnv`** — implements [Recursive Language Models](https://alexzhang13.github.io/blog/2025/rlm/) for unbounded context processing via REPL-based decomposition and recursive sub-LLM calls
- **`OpenCodeEnv`** — runs [OpenCode](https://opencode.ai) CLI agents inside sandboxes with API call interception
178 changes: 178 additions & 0 deletions tests/test_interception_auth.py
@@ -0,0 +1,178 @@
"""Tests for per-rollout authentication on the interception server.

Verifies that:
- Requests with valid tokens are accepted
- Requests with invalid/missing tokens are rejected (401)
- Unregistered rollout IDs are rejected (404)
- Graceful fallback: rollouts registered without a token skip auth
"""

import asyncio
from typing import Any

import pytest
from aiohttp import ClientSession

from verifiers.utils.interception_utils import (
InterceptionServer,
generate_interception_token,
)


@pytest.fixture
async def server():
srv = InterceptionServer(port=0)
await srv.start()
yield srv
await srv.stop()


def _chat_payload(content: str = "hello") -> dict:
return {
"model": "test-model",
"messages": [{"role": "user", "content": content}],
}


async def _post(
    base: str,
    rollout_id: str,
    token: str | None = None,
    timeout: float = 0.5,
    payload: Any | None = None,
):
    """POST to a rollout endpoint; return (status, body), or ("accepted", None) on timeout."""
    from aiohttp import ClientTimeout  # local import keeps this helper self-contained

    headers = {}
    if token is not None:
        headers["Authorization"] = f"Bearer {token}"
    try:
        async with ClientSession() as session:
            async with session.post(
                f"{base}/rollout/{rollout_id}/v1/chat/completions",
                json=_chat_payload() if payload is None else payload,
                headers=headers,
                timeout=ClientTimeout(total=timeout),
            ) as resp:
                body = await resp.json()
                return resp.status, body
    except asyncio.TimeoutError:
        # No response arrived in time: the server accepted the request and
        # is holding it open while waiting for a model response to deliver.
        return "accepted", None


@pytest.mark.asyncio
async def test_valid_token_accepted(server: InterceptionServer):
    """Request with the correct bearer token is accepted."""
    token = generate_interception_token()
    server.register_rollout("rollout_auth_ok", auth_token=token)
    base = f"http://127.0.0.1:{server.port}"

    result = await _post(base, "rollout_auth_ok", token=token)
    # "accepted" means the server didn't reject; it's waiting for a model response
    assert result[0] in ("accepted", 200)

    server.unregister_rollout("rollout_auth_ok")


@pytest.mark.asyncio
async def test_missing_token_rejected(server: InterceptionServer):
    """Request with no Authorization header is rejected when auth is configured."""
    token = generate_interception_token()
    server.register_rollout("rollout_no_token", auth_token=token)
    base = f"http://127.0.0.1:{server.port}"

    status, body = await _post(base, "rollout_no_token", token=None)
    assert status == 401
    assert body["error"] == "Unauthorized"

    server.unregister_rollout("rollout_no_token")


@pytest.mark.asyncio
async def test_wrong_token_rejected(server: InterceptionServer):
    """Request with an incorrect bearer token is rejected."""
    token = generate_interception_token()
    server.register_rollout("rollout_bad_token", auth_token=token)
    base = f"http://127.0.0.1:{server.port}"

    status, body = await _post(base, "rollout_bad_token", token="wrong-token")
    assert status == 401
    assert body["error"] == "Unauthorized"

    server.unregister_rollout("rollout_bad_token")


@pytest.mark.asyncio
async def test_unknown_rollout_404(server: InterceptionServer):
    """Request to a non-existent rollout ID returns 404."""
    base = f"http://127.0.0.1:{server.port}"

    status, _body = await _post(base, "rollout_nonexistent", token=None)
    assert status == 404


@pytest.mark.asyncio
async def test_no_token_graceful_fallback(server: InterceptionServer):
    """Rollout registered without a token accepts any request (backwards compat)."""
    server.register_rollout("rollout_no_auth")
    base = f"http://127.0.0.1:{server.port}"

    result = await _post(base, "rollout_no_auth", token=None)
    # Should be accepted (not 401), waiting for a model response
    assert result[0] in ("accepted", 200)

    server.unregister_rollout("rollout_no_auth")


@pytest.mark.asyncio
async def test_cross_rollout_blocked(server: InterceptionServer):
    """Token for rollout A cannot be used to access rollout B."""
    token_a = generate_interception_token()
    token_b = generate_interception_token()
    server.register_rollout("rollout_a", auth_token=token_a)
    server.register_rollout("rollout_b", auth_token=token_b)
    base = f"http://127.0.0.1:{server.port}"

    # Use A's token to access B's endpoint
    status, _body = await _post(base, "rollout_b", token=token_a)
    assert status == 401, "Cross-rollout access should be rejected"

    server.unregister_rollout("rollout_a")
    server.unregister_rollout("rollout_b")


@pytest.mark.asyncio
async def test_missing_messages_rejected(server: InterceptionServer):
    """Authenticated requests without messages return a 400 instead of crashing."""
    token = generate_interception_token()
    server.register_rollout("rollout_missing_messages", auth_token=token)
    base = f"http://127.0.0.1:{server.port}"

    status, body = await _post(
        base,
        "rollout_missing_messages",
        token=token,
        payload={"model": "test-model"},
    )
    assert status == 400
    assert body["error"] == "Request body must include 'messages'"

    server.unregister_rollout("rollout_missing_messages")


@pytest.mark.asyncio
async def test_non_object_body_rejected(server: InterceptionServer):
    """Authenticated requests must send a JSON object."""
    token = generate_interception_token()
    server.register_rollout("rollout_non_object", auth_token=token)
    base = f"http://127.0.0.1:{server.port}"

    status, body = await _post(
        base,
        "rollout_non_object",
        token=token,
        payload=["not", "an", "object"],
    )
    assert status == 400
    assert body["error"] == "Request body must be a JSON object"

    server.unregister_rollout("rollout_non_object")
23 changes: 23 additions & 0 deletions tests/test_interception_utils.py
@@ -1,3 +1,5 @@
import pytest

from verifiers.types import (
    Response,
    ResponseMessage,
@@ -7,6 +9,7 @@
)
from verifiers.utils.interception_utils import (
    create_empty_completion,
    has_valid_bearer_auth,
    serialize_intercept_response,
)

@@ -61,3 +64,23 @@ def test_serialize_intercept_response_passthrough_native_chat_completion():
    assert payload["object"] == "chat.completion"
    assert payload["model"] == "native-model"
    assert len(payload["choices"]) == 1


def test_serialize_intercept_response_rejects_invalid_input_type():
    with pytest.raises(TypeError, match="Unsupported intercepted response type: str"):
        serialize_intercept_response("0")


@pytest.mark.parametrize(
    ("auth_header", "expected_token", "expected_valid"),
    [
        ("Bearer secret-token", "secret-token", True),
        ("Bearer wrong-token", "secret-token", False),
        ("", "secret-token", False),
        ("", None, True),
    ],
)
def test_has_valid_bearer_auth(
    auth_header: str, expected_token: str | None, expected_valid: bool
):
    assert has_valid_bearer_auth(auth_header, expected_token) is expected_valid
33 changes: 33 additions & 0 deletions tests/test_rlm_env.py
@@ -645,13 +645,16 @@ def _run_helper(
        argv: list[str],
        stdin_data: str = "",
        response_data: dict | None = None,
        env_overrides: dict[str, str] | None = None,
    ) -> tuple[str, str, int, dict | None]:
        helper_source = extract_bash_helper_source()
        stdout_buffer = io.StringIO()
        stderr_buffer = io.StringIO()
        env = {
            "RLM_ROOT_TOOL_URL": "http://example.invalid/",
        }
        if env_overrides:
            env.update(env_overrides)
        captured_payload: dict | None = None
        with patch("urllib.request.urlopen") as mock_urlopen:

@@ -664,6 +667,7 @@ def _capture_request(req, timeout=300):
                    "tool_name": data.get("tool_name"),
                    "args": args,
                    "kwargs": kwargs,
                    "headers": dict(req.header_items()),
                }
                return response

@@ -819,6 +823,35 @@ def test_llm_batch_output_json(self):
        parsed = json.loads(stdout.strip())
        assert parsed == ["first", "second"]

    def test_root_tool_auth_header_added_when_token_present(self):
        stdout, stderr, code, captured = self._run_helper(
            ["--tool", "other_tool", "--json", json.dumps({"args": [1]})],
            env_overrides={"RLM_AUTH_TOKEN": "secret-token"},
        )
        assert code == 0
        assert stderr == ""
        assert "ok" in stdout
        assert captured is not None
        assert captured["headers"]["Authorization"] == "Bearer secret-token"


class TestRLMAuthCheck:
    def test_check_rollout_auth_accepts_matching_bearer_token(self):
        env = build_env(make_dataset({}))
        request = MagicMock(headers={"Authorization": "Bearer secret-token"})

        assert env._check_rollout_auth(request, {"auth_token": "secret-token"}) is None

    def test_check_rollout_auth_rejects_wrong_bearer_token(self):
        env = build_env(make_dataset({}))
        request = MagicMock(headers={"Authorization": "Bearer wrong-token"})

        response = env._check_rollout_auth(request, {"auth_token": "secret-token"})

        assert response is not None
        assert response.status == 401
        assert json.loads(response.text)["error"] == "Unauthorized"


# =============================================================================
# 3. Initialization and Configuration
11 changes: 10 additions & 1 deletion verifiers/envs/experimental/cli_agent_env.py
@@ -30,6 +30,7 @@
from verifiers.utils.interception_utils import (
    InterceptionServer,
    deliver_response,
    generate_interception_token,
    synthesize_stream,
)
from verifiers.utils.logging_utils import print_time, truncate

@@ -188,6 +189,8 @@ async def setup_state(self, state: State) -> State:

        rollout_id = f"rollout_{uuid.uuid4().hex[:8]}"
        state["rollout_id"] = rollout_id
        auth_token = generate_interception_token()
        state["interception_auth_token"] = auth_token

        interception_server = self._require_interception_server()
        await interception_server.start()

@@ -227,7 +230,9 @@ async def setup_state(self, state: State) -> State:
        await self.create_sandbox(state, sandbox_request)

        # Register rollout for interception
        request_id_queue = interception_server.register_rollout(rollout_id)
        request_id_queue = interception_server.register_rollout(
            rollout_id, auth_token=auth_token
        )
        state["request_id_queue"] = request_id_queue
        state["agent_completed"] = False

@@ -261,6 +266,7 @@ def get_sandbox_resources(self, state: State) -> dict[str, Any]:

    PROTECTED_ENV_VARS = frozenset(
        {
            "OPENAI_BASE_URL",
            "OPENAI_API_KEY",
            "OPENAI_TIMEOUT",
            "OPENAI_REQUEST_TIMEOUT",
            "HTTPX_TIMEOUT",

@@ -272,6 +278,9 @@ async def build_env_vars(self, state: State) -> dict[str, str]:
        """Build environment variables for the sandbox. Override to add custom vars."""
        env_vars = dict(self.environment_vars) if self.environment_vars else {}
        env_vars["OPENAI_BASE_URL"] = state["interception_base_url"]
        auth_token = state.get("interception_auth_token")
        if auth_token:
            env_vars["OPENAI_API_KEY"] = auth_token
        env_vars.setdefault("OPENAI_TIMEOUT", "3600")
        env_vars.setdefault("OPENAI_REQUEST_TIMEOUT", "3600")
        env_vars.setdefault("HTTPX_TIMEOUT", "3600")
2 changes: 1 addition & 1 deletion verifiers/envs/experimental/opencode_env.py
@@ -358,7 +358,7 @@ def build_opencode_config(
            "name": "${OPENAI_MODEL%%/*}",
            "options": {
                "baseURL": "$OPENAI_BASE_URL",
                "apiKey": "intercepted",
                "apiKey": "${OPENAI_API_KEY:-intercepted}",
                "timeout": self.provider_timeout_ms,
            },
            "models": {
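The `${OPENAI_API_KEY:-intercepted}` form in the new config line is plain shell parameter expansion: use the variable when it is set and non-empty, otherwise fall back to the literal. A quick check:

```shell
# Falls back when the variable is unset or empty
unset OPENAI_API_KEY
echo "${OPENAI_API_KEY:-intercepted}"   # prints: intercepted

# Uses the real value once the sandbox exports it
export OPENAI_API_KEY=tok-123
echo "${OPENAI_API_KEY:-intercepted}"   # prints: tok-123
```

This keeps the old behavior as the default while letting the sandbox's per-rollout token flow through when `CliAgentEnv` exports it.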