Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions assets/lab/environments/AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ class MyGameEnv(vf.MultiTurnEnv):
return state.get("lives", 1) <= 0
```

`MultiTurnEnv` includes built-in stop conditions for errors, prompt length limits, and `max_turns` by default.
`MultiTurnEnv` includes built-in stop conditions for errors, prompt length limits, `max_turns`, and `max_total_completion_tokens` by default.

Execution order can be controlled with `priority` (higher runs first). This is useful for checking cheap conditions before expensive ones:

Expand Down Expand Up @@ -891,7 +891,11 @@ These require additional dependencies installed via extras (e.g., `uv add 'verif
Newer and more experimental environment classes include:

- **`GymEnv`** — universal runner for Gym-compatible environments (OpenAI Gym / Gymnasium API)
- **`CliAgentEnv`** — runs custom agent code inside sandboxes, intercepting API requests. Accepts sandbox configuration parameters including `docker_image`, `cpu_cores`, `memory_gb`, `disk_size_gb`, `gpu_count`, `timeout_minutes`, `environment_vars`, and `labels` for sandbox categorization. Also accepts retry tuning (like `max_retries`) and connection pooling ( like `sandbox_client_max_workers`) parameters via `SandboxMixin`
- **`CliAgentEnv`** — runs custom agent code inside sandboxes, intercepting API requests. Accepts sandbox configuration parameters including `docker_image`, `cpu_cores`, `memory_gb`, `disk_size_gb`, `gpu_count`, `gpu_type`, `timeout_minutes`, `environment_vars`, and `labels` for sandbox categorization. Also accepts retry tuning (like `max_retries`) and connection pooling (like `sandbox_client_max_workers`) parameters via `SandboxMixin`. Subclasses can override `get_sandbox_resources(state)` for per-instance resource allocation and `build_env_vars(state)` for custom environment variables (`PROTECTED_ENV_VARS` cannot be overridden). VMs are auto-enabled when `gpu_count > 0`
- **`ComposableEnv`** — `CliAgentEnv` subclass that separates *what to solve* (`TaskSet`) from *how to solve it* (`Harness`). Wire a task collection and an agent config together with zero subclassing. Delegates sandbox spec, instruction, setup, and env vars to the `TaskSet`; install script, run command, and system prompt to the `Harness`. Scoring is owned by per-taskset rubrics
- **`TaskSet`** / **`SandboxTaskSet`** — define task collections. `SandboxTaskSet` adds `SandboxSpec` (image, CPU, memory, GPU, timeout) per instance, a `setup(state)` hook, and `validate_instance(state)` for gold-patch validation. Key methods: `get_instruction(info)`, `get_rubric()`, `get_sandbox_spec(info)`, `get_env_vars()`. Includes `validate(n, concurrency)` for bulk validation and `filter()`/`take()` combinators
- **`Harness`** — agent-side config dataclass: `install_script`, `run_command`, `system_prompt`, `system_prompt_path`, `instruction_path`, `log_path`
- **`SandboxSpec`** — per-instance sandbox requirements: `image`, `cpu_cores`, `memory_gb`, `disk_size_gb`, `gpu_count`, `gpu_type`, `timeout_minutes`
- **`HarborEnv`** — loads Harbor-format agent benchmark tasks
- **`RLMEnv`** — implements [Recursive Language Models](https://alexzhang13.github.io/blog/2025/rlm/) for unbounded context processing via REPL-based decomposition and recursive sub-LLM calls
- **`OpenCodeEnv`** — runs [OpenCode](https://opencode.ai) CLI agents inside sandboxes with API call interception
Expand Down
8 changes: 6 additions & 2 deletions environments/AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ class MyGameEnv(vf.MultiTurnEnv):
return state.get("lives", 1) <= 0
```

`MultiTurnEnv` includes built-in stop conditions for errors, prompt length limits, and `max_turns` by default.
`MultiTurnEnv` includes built-in stop conditions for errors, prompt length limits, `max_turns`, and `max_total_completion_tokens` by default.

Execution order can be controlled with `priority` (higher runs first). This is useful for checking cheap conditions before expensive ones:

Expand Down Expand Up @@ -891,7 +891,11 @@ These require additional dependencies installed via extras (e.g., `uv add 'verif
Newer and more experimental environment classes include:

- **`GymEnv`** — universal runner for Gym-compatible environments (OpenAI Gym / Gymnasium API)
- **`CliAgentEnv`** — runs custom agent code inside sandboxes, intercepting API requests. Accepts sandbox configuration parameters including `docker_image`, `cpu_cores`, `memory_gb`, `disk_size_gb`, `gpu_count`, `timeout_minutes`, `environment_vars`, and `labels` for sandbox categorization. Also accepts retry tuning (like `max_retries`) and connection pooling ( like `sandbox_client_max_workers`) parameters via `SandboxMixin`
- **`CliAgentEnv`** — runs custom agent code inside sandboxes, intercepting API requests. Accepts sandbox configuration parameters including `docker_image`, `cpu_cores`, `memory_gb`, `disk_size_gb`, `gpu_count`, `gpu_type`, `timeout_minutes`, `environment_vars`, and `labels` for sandbox categorization. Also accepts retry tuning (like `max_retries`) and connection pooling (like `sandbox_client_max_workers`) parameters via `SandboxMixin`. Subclasses can override `get_sandbox_resources(state)` for per-instance resource allocation and `build_env_vars(state)` for custom environment variables (`PROTECTED_ENV_VARS` cannot be overridden). VMs are auto-enabled when `gpu_count > 0`
- **`ComposableEnv`** — `CliAgentEnv` subclass that separates *what to solve* (`TaskSet`) from *how to solve it* (`Harness`). Wire a task collection and an agent config together with zero subclassing. Delegates sandbox spec, instruction, setup, and env vars to the `TaskSet`; install script, run command, and system prompt to the `Harness`. Scoring is owned by per-taskset rubrics
- **`TaskSet`** / **`SandboxTaskSet`** — define task collections. `SandboxTaskSet` adds `SandboxSpec` (image, CPU, memory, GPU, timeout) per instance, a `setup(state)` hook, and `validate_instance(state)` for gold-patch validation. Key methods: `get_instruction(info)`, `get_rubric()`, `get_sandbox_spec(info)`, `get_env_vars()`. Includes `validate(n, concurrency)` for bulk validation and `filter()`/`take()` combinators
- **`Harness`** — agent-side config dataclass: `install_script`, `run_command`, `system_prompt`, `system_prompt_path`, `instruction_path`, `log_path`
- **`SandboxSpec`** — per-instance sandbox requirements: `image`, `cpu_cores`, `memory_gb`, `disk_size_gb`, `gpu_count`, `gpu_type`, `timeout_minutes`
- **`HarborEnv`** — loads Harbor-format agent benchmark tasks
- **`RLMEnv`** — implements [Recursive Language Models](https://alexzhang13.github.io/blog/2025/rlm/) for unbounded context processing via REPL-based decomposition and recursive sub-LLM calls
- **`OpenCodeEnv`** — runs [OpenCode](https://opencode.ai) CLI agents inside sandboxes with API call interception
Expand Down
135 changes: 135 additions & 0 deletions tests/test_interception_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
"""Tests for per-rollout authentication on the interception server.

Verifies that:
- Requests with valid tokens are accepted
- Requests with invalid/missing tokens are rejected (401)
- Unregistered rollout IDs are rejected (404)
- Graceful fallback: rollouts registered without a token skip auth
"""

import asyncio

import pytest
from aiohttp import ClientSession

from verifiers.utils.interception_utils import (
InterceptionServer,
generate_interception_token,
)


@pytest.fixture
async def server():
srv = InterceptionServer(port=0)
await srv.start()
yield srv
await srv.stop()


def _chat_payload(content: str = "hello") -> dict:
return {
"model": "test-model",
"messages": [{"role": "user", "content": content}],
}


async def _post(
base: str, rollout_id: str, token: str | None = None, timeout: float = 0.5
):
"""POST to a rollout endpoint, return (status, body) or 'timeout'."""
headers = {}
if token is not None:
headers["Authorization"] = f"Bearer {token}"
try:
async with ClientSession() as session:
async with session.post(
f"{base}/rollout/{rollout_id}/v1/chat/completions",
json=_chat_payload(),
headers=headers,
timeout=__import__("aiohttp").ClientTimeout(total=timeout),
) as resp:
body = await resp.json()
return resp.status, body
except asyncio.TimeoutError:
return "accepted", None


@pytest.mark.asyncio
async def test_valid_token_accepted(server: InterceptionServer):
"""Request with the correct bearer token is accepted."""
token = generate_interception_token()
server.register_rollout("rollout_auth_ok", auth_token=token)
base = f"http://127.0.0.1:{server.port}"

result = await _post(base, "rollout_auth_ok", token=token)
# "accepted" means the server didn't reject — it's waiting for a model response
assert result[0] == "accepted" or result[0] == 200

server.unregister_rollout("rollout_auth_ok")


@pytest.mark.asyncio
async def test_missing_token_rejected(server: InterceptionServer):
"""Request with no Authorization header is rejected when auth is configured."""
token = generate_interception_token()
server.register_rollout("rollout_no_token", auth_token=token)
base = f"http://127.0.0.1:{server.port}"

status, body = await _post(base, "rollout_no_token", token=None)
assert status == 401
assert body["error"] == "Unauthorized"

server.unregister_rollout("rollout_no_token")


@pytest.mark.asyncio
async def test_wrong_token_rejected(server: InterceptionServer):
"""Request with an incorrect bearer token is rejected."""
token = generate_interception_token()
server.register_rollout("rollout_bad_token", auth_token=token)
base = f"http://127.0.0.1:{server.port}"

status, body = await _post(base, "rollout_bad_token", token="wrong-token")
assert status == 401
assert body["error"] == "Unauthorized"

server.unregister_rollout("rollout_bad_token")


@pytest.mark.asyncio
async def test_unknown_rollout_404(server: InterceptionServer):
"""Request to a non-existent rollout ID returns 404."""
base = f"http://127.0.0.1:{server.port}"

status, body = await _post(base, "rollout_nonexistent", token=None)
assert status == 404


@pytest.mark.asyncio
async def test_no_token_graceful_fallback(server: InterceptionServer):
"""Rollout registered without a token accepts any request (backwards compat)."""
server.register_rollout("rollout_no_auth")
base = f"http://127.0.0.1:{server.port}"

result = await _post(base, "rollout_no_auth", token=None)
# Should be accepted (not 401), waiting for model response
assert result[0] == "accepted" or result[0] == 200

server.unregister_rollout("rollout_no_auth")


@pytest.mark.asyncio
async def test_cross_rollout_blocked(server: InterceptionServer):
"""Token for rollout A cannot be used to access rollout B."""
token_a = generate_interception_token()
token_b = generate_interception_token()
server.register_rollout("rollout_a", auth_token=token_a)
server.register_rollout("rollout_b", auth_token=token_b)
base = f"http://127.0.0.1:{server.port}"

# Use A's token to access B's endpoint
status, body = await _post(base, "rollout_b", token=token_a)
assert status == 401, "Cross-rollout access should be rejected"

server.unregister_rollout("rollout_a")
server.unregister_rollout("rollout_b")
11 changes: 10 additions & 1 deletion verifiers/envs/experimental/cli_agent_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from verifiers.utils.interception_utils import (
InterceptionServer,
deliver_response,
generate_interception_token,
synthesize_stream,
)
from verifiers.utils.logging_utils import print_time, truncate
Expand Down Expand Up @@ -188,6 +189,8 @@ async def setup_state(self, state: State) -> State:

rollout_id = f"rollout_{uuid.uuid4().hex[:8]}"
state["rollout_id"] = rollout_id
auth_token = generate_interception_token()
state["interception_auth_token"] = auth_token
Comment thread
cursor[bot] marked this conversation as resolved.

interception_server = self._require_interception_server()
await interception_server.start()
Expand Down Expand Up @@ -227,7 +230,9 @@ async def setup_state(self, state: State) -> State:
await self.create_sandbox(state, sandbox_request)

# Register rollout for interception
request_id_queue = interception_server.register_rollout(rollout_id)
request_id_queue = interception_server.register_rollout(
rollout_id, auth_token=auth_token
)
state["request_id_queue"] = request_id_queue
state["agent_completed"] = False

Expand Down Expand Up @@ -261,6 +266,7 @@ def get_sandbox_resources(self, state: State) -> dict[str, Any]:
PROTECTED_ENV_VARS = frozenset(
{
"OPENAI_BASE_URL",
"OPENAI_API_KEY",
"OPENAI_TIMEOUT",
"OPENAI_REQUEST_TIMEOUT",
"HTTPX_TIMEOUT",
Expand All @@ -272,6 +278,9 @@ async def build_env_vars(self, state: State) -> dict[str, str]:
"""Build environment variables for the sandbox. Override to add custom vars."""
env_vars = dict(self.environment_vars) if self.environment_vars else {}
env_vars["OPENAI_BASE_URL"] = state["interception_base_url"]
auth_token = state.get("interception_auth_token")
if auth_token:
env_vars["OPENAI_API_KEY"] = auth_token
env_vars.setdefault("OPENAI_TIMEOUT", "3600")
env_vars.setdefault("OPENAI_REQUEST_TIMEOUT", "3600")
env_vars.setdefault("HTTPX_TIMEOUT", "3600")
Expand Down
31 changes: 31 additions & 0 deletions verifiers/envs/experimental/rlm_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import logging
import os
import re
import secrets
import shlex
import shutil
import sys
Expand Down Expand Up @@ -554,6 +555,8 @@ def _build_python_worker_script_template() -> str:
"",
'ROOT_TOOL_URL = os.environ.get("RLM_ROOT_TOOL_URL", "")',
'ROOT_TOOL_NAMES_RAW = os.environ.get("RLM_ROOT_TOOL_NAMES", "[]")',
'_RLM_AUTH_TOKEN = os.environ.get("RLM_AUTH_TOKEN", "")',
'_RLM_AUTH_HEADERS = {"Authorization": f"Bearer {_RLM_AUTH_TOKEN}"} if _RLM_AUTH_TOKEN else {}',
"try:",
" ROOT_TOOL_NAMES = json.loads(ROOT_TOOL_NAMES_RAW)",
"except Exception:",
Expand All @@ -572,6 +575,7 @@ def _build_python_worker_script_template() -> str:
" resp = requests.post(",
" ROOT_TOOL_URL,",
" json=payload,",
" headers=_RLM_AUTH_HEADERS,",
" timeout=SUB_LLM_TIMEOUT,",
" )",
" resp.raise_for_status()",
Expand Down Expand Up @@ -2619,6 +2623,7 @@ def _build_worker_env_vars(self, state: State) -> dict[str, str]:
"RLM_ROOT_TOOL_URL": state.get("root_tool_url", ""),
"RLM_ROOT_TOOL_NAMES": json.dumps(self.root_tool_names),
"RLM_SUB_LLM_TIMEOUT": str(self.sub_llm_timeout),
"RLM_AUTH_TOKEN": state.get("interception_auth_token", ""),
Comment thread
cursor[bot] marked this conversation as resolved.
}

def _compute_fs_metadata(
Expand Down Expand Up @@ -3237,6 +3242,10 @@ async def _handle_root_tool_request(self, request: Any) -> Any:
if not context:
return web.json_response({"error": "Rollout not found"}, status=404)

auth_error = self._check_rollout_auth(request, context)
if auth_error is not None:
return auth_error
Comment thread
cursor[bot] marked this conversation as resolved.

try:
request_body = await request.json()
except Exception as e:
Expand Down Expand Up @@ -3314,13 +3323,31 @@ async def _handle_root_tool_request(self, request: Any) -> Any:
response_body["result_repr"] = repr(result_value)
return web.json_response(response_body)

def _check_rollout_auth(self, request: Any, context: dict) -> web.Response | None:
"""Check bearer token for a rollout request. Returns error response or None."""
expected_token = context.get("auth_token")
if expected_token is not None:
auth_header = request.headers.get("Authorization", "")
bearer_token = (
auth_header.removeprefix("Bearer ")
if auth_header.startswith("Bearer ")
else ""
)
if not secrets.compare_digest(bearer_token, expected_token):
return web.json_response({"error": "Unauthorized"}, status=401)
return None
Comment thread
cursor[bot] marked this conversation as resolved.

async def _handle_sub_llm_request(self, request: Any) -> Any:
"""Handle sub-LLM requests from worker code."""
rollout_id = request.match_info["rollout_id"]
context = self.active_rollouts.get(rollout_id)
if not context:
return web.json_response({"error": "Rollout not found"}, status=404)

auth_error = self._check_rollout_auth(request, context)
if auth_error is not None:
return auth_error

try:
request_body = await request.json()
except Exception as e:
Expand Down Expand Up @@ -3452,11 +3479,15 @@ async def _setup_interception_and_register(
state["interception_url"] = interception_url
state["root_tool_url"] = root_tool_url

auth_token = secrets.token_hex(32)
state["interception_auth_token"] = auth_token

self.active_rollouts[rollout_id] = {
"client": state.get("client"),
"model": state.get("model"),
"sub_model": self.sub_model or state.get("model"),
"state": state,
"auth_token": auth_token,
}
return state

Expand Down
22 changes: 21 additions & 1 deletion verifiers/utils/interception_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import json
import logging
import secrets
import time
import uuid
from typing import Any, cast
Expand Down Expand Up @@ -99,10 +100,13 @@ async def stop(self) -> None:
self._site = None
self._app = None

def register_rollout(self, rollout_id: str) -> asyncio.Queue:
def register_rollout(
self, rollout_id: str, auth_token: str | None = None
) -> asyncio.Queue:
request_queue: asyncio.Queue = asyncio.Queue()
self.active_rollouts[rollout_id] = {
"request_id_queue": request_queue,
"auth_token": auth_token,
}
return request_queue

Expand Down Expand Up @@ -133,6 +137,17 @@ async def _handle_request(self, request: Any) -> Any:
if not context:
return web.json_response({"error": "Rollout not found"}, status=404)

expected_token = context.get("auth_token")
if expected_token is not None:
auth_header = request.headers.get("Authorization", "")
bearer_token = (
auth_header.removeprefix("Bearer ")
if auth_header.startswith("Bearer ")
else ""
)
if not secrets.compare_digest(bearer_token, expected_token):
return web.json_response({"error": "Unauthorized"}, status=401)

try:
request_body = await request.json()
except Exception as e:
Expand Down Expand Up @@ -230,6 +245,11 @@ async def _handle_streaming_response(
return response


def generate_interception_token() -> str:
"""Generate a cryptographically random token for rollout authentication."""
return secrets.token_hex(32)


def deliver_response(
intercept: dict,
response: Response | ChatCompletion | None,
Expand Down
Loading