Changes from 4 commits
8 changes: 6 additions & 2 deletions assets/lab/environments/AGENTS.md
@@ -576,7 +576,7 @@ class MyGameEnv(vf.MultiTurnEnv):
        return state.get("lives", 1) <= 0
```

`MultiTurnEnv` includes built-in stop conditions for errors, prompt length limits, and `max_turns` by default.
`MultiTurnEnv` includes built-in stop conditions for errors, prompt length limits, `max_turns`, and `max_total_completion_tokens` by default.

Execution order can be controlled with `priority` (higher runs first). This is useful for checking cheap conditions before expensive ones:
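The ordering semantics can be sketched framework-agnostically (the verifiers API for registering stop conditions is not shown in this excerpt; the condition functions and priority tuples below are illustrative only):

```python
def is_errored(state: dict) -> bool:  # cheap check
    return bool(state.get("error"))

def judge_says_done(state: dict) -> bool:  # expensive check (would call a model)
    raise RuntimeError("expensive check should not run when a cheap one fires")

# (priority, condition) pairs; higher priority runs first
conditions = [
    (0, judge_says_done),
    (10, is_errored),
]

def should_stop(state: dict) -> bool:
    # Sort by descending priority so cheap, high-priority checks short-circuit
    for _, cond in sorted(conditions, key=lambda pc: -pc[0]):
        if cond(state):
            return True
    return False

print(should_stop({"error": "boom"}))  # True; the expensive check never runs
```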

@@ -891,7 +891,11 @@ These require additional dependencies installed via extras (e.g., `uv add 'verif
Newer and more experimental environment classes include:

- **`GymEnv`** — universal runner for Gym-compatible environments (OpenAI Gym / Gymnasium API)
- **`CliAgentEnv`** — runs custom agent code inside sandboxes, intercepting API requests. Accepts sandbox configuration parameters including `docker_image`, `cpu_cores`, `memory_gb`, `disk_size_gb`, `gpu_count`, `timeout_minutes`, `environment_vars`, and `labels` for sandbox categorization. Also accepts retry tuning (like `max_retries`) and connection pooling ( like `sandbox_client_max_workers`) parameters via `SandboxMixin`
- **`CliAgentEnv`** — runs custom agent code inside sandboxes, intercepting API requests. Accepts sandbox configuration parameters including `docker_image`, `cpu_cores`, `memory_gb`, `disk_size_gb`, `gpu_count`, `gpu_type`, `timeout_minutes`, `environment_vars`, and `labels` for sandbox categorization. Also accepts retry tuning (like `max_retries`) and connection pooling (like `sandbox_client_max_workers`) parameters via `SandboxMixin`. Subclasses can override `get_sandbox_resources(state)` for per-instance resource allocation and `build_env_vars(state)` for custom environment variables (`PROTECTED_ENV_VARS` cannot be overridden). VMs are auto-enabled when `gpu_count > 0`
- **`ComposableEnv`** — `CliAgentEnv` subclass that separates *what to solve* (`TaskSet`) from *how to solve it* (`Harness`). Wire a task collection and an agent config together with zero subclassing. Delegates sandbox spec, instruction, setup, and env vars to the `TaskSet`; install script, run command, and system prompt to the `Harness`. Scoring is owned by per-taskset rubrics
- **`TaskSet`** / **`SandboxTaskSet`** — define task collections. `SandboxTaskSet` adds `SandboxSpec` (image, CPU, memory, GPU, timeout) per instance, a `setup(state)` hook, and `validate_instance(state)` for gold-patch validation. Key methods: `get_instruction(info)`, `get_rubric()`, `get_sandbox_spec(info)`, `get_env_vars()`. Includes `validate(n, concurrency)` for bulk validation and `filter()`/`take()` combinators
- **`Harness`** — agent-side config dataclass: `install_script`, `run_command`, `system_prompt`, `system_prompt_path`, `instruction_path`, `log_path`
- **`SandboxSpec`** — per-instance sandbox requirements: `image`, `cpu_cores`, `memory_gb`, `disk_size_gb`, `gpu_count`, `gpu_type`, `timeout_minutes`
- **`HarborEnv`** — loads Harbor-format agent benchmark tasks
- **`RLMEnv`** — implements [Recursive Language Models](https://alexzhang13.github.io/blog/2025/rlm/) for unbounded context processing via REPL-based decomposition and recursive sub-LLM calls
- **`OpenCodeEnv`** — runs [OpenCode](https://opencode.ai) CLI agents inside sandboxes with API call interception
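The `get_sandbox_resources(state)` override pattern described above can be sketched as follows. The base class here is a dependency-free stand-in (not the real `vf.CliAgentEnv`), and the resource keys mirror the constructor parameters listed for `CliAgentEnv`:

```python
from typing import Any

class CliAgentEnvStandIn:
    """Dependency-free stand-in for vf.CliAgentEnv, for illustration only."""

    def get_sandbox_resources(self, state: dict[str, Any]) -> dict[str, Any]:
        return {"cpu_cores": 2, "memory_gb": 4, "gpu_count": 0}

class MyAgentEnv(CliAgentEnvStandIn):
    def get_sandbox_resources(self, state: dict[str, Any]) -> dict[str, Any]:
        # Per-instance allocation: GPU tasks get a larger sandbox
        # (per the notes above, VMs are auto-enabled when gpu_count > 0)
        resources = super().get_sandbox_resources(state)
        if state.get("info", {}).get("needs_gpu"):
            resources.update({"gpu_count": 1, "gpu_type": "a100", "memory_gb": 32})
        return resources

env = MyAgentEnv()
print(env.get_sandbox_resources({"info": {"needs_gpu": True}}))
# {'cpu_cores': 2, 'memory_gb': 32, 'gpu_count': 1, 'gpu_type': 'a100'}
```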
8 changes: 6 additions & 2 deletions environments/AGENTS.md
@@ -576,7 +576,7 @@ class MyGameEnv(vf.MultiTurnEnv):
        return state.get("lives", 1) <= 0
```

`MultiTurnEnv` includes built-in stop conditions for errors, prompt length limits, and `max_turns` by default.
`MultiTurnEnv` includes built-in stop conditions for errors, prompt length limits, `max_turns`, and `max_total_completion_tokens` by default.

Execution order can be controlled with `priority` (higher runs first). This is useful for checking cheap conditions before expensive ones:

@@ -891,7 +891,11 @@ These require additional dependencies installed via extras (e.g., `uv add 'verif
Newer and more experimental environment classes include:

- **`GymEnv`** — universal runner for Gym-compatible environments (OpenAI Gym / Gymnasium API)
- **`CliAgentEnv`** — runs custom agent code inside sandboxes, intercepting API requests. Accepts sandbox configuration parameters including `docker_image`, `cpu_cores`, `memory_gb`, `disk_size_gb`, `gpu_count`, `timeout_minutes`, `environment_vars`, and `labels` for sandbox categorization. Also accepts retry tuning (like `max_retries`) and connection pooling ( like `sandbox_client_max_workers`) parameters via `SandboxMixin`
- **`CliAgentEnv`** — runs custom agent code inside sandboxes, intercepting API requests. Accepts sandbox configuration parameters including `docker_image`, `cpu_cores`, `memory_gb`, `disk_size_gb`, `gpu_count`, `gpu_type`, `timeout_minutes`, `environment_vars`, and `labels` for sandbox categorization. Also accepts retry tuning (like `max_retries`) and connection pooling (like `sandbox_client_max_workers`) parameters via `SandboxMixin`. Subclasses can override `get_sandbox_resources(state)` for per-instance resource allocation and `build_env_vars(state)` for custom environment variables (`PROTECTED_ENV_VARS` cannot be overridden). VMs are auto-enabled when `gpu_count > 0`
- **`ComposableEnv`** — `CliAgentEnv` subclass that separates *what to solve* (`TaskSet`) from *how to solve it* (`Harness`). Wire a task collection and an agent config together with zero subclassing. Delegates sandbox spec, instruction, setup, and env vars to the `TaskSet`; install script, run command, and system prompt to the `Harness`. Scoring is owned by per-taskset rubrics
- **`TaskSet`** / **`SandboxTaskSet`** — define task collections. `SandboxTaskSet` adds `SandboxSpec` (image, CPU, memory, GPU, timeout) per instance, a `setup(state)` hook, and `validate_instance(state)` for gold-patch validation. Key methods: `get_instruction(info)`, `get_rubric()`, `get_sandbox_spec(info)`, `get_env_vars()`. Includes `validate(n, concurrency)` for bulk validation and `filter()`/`take()` combinators
- **`Harness`** — agent-side config dataclass: `install_script`, `run_command`, `system_prompt`, `system_prompt_path`, `instruction_path`, `log_path`
- **`SandboxSpec`** — per-instance sandbox requirements: `image`, `cpu_cores`, `memory_gb`, `disk_size_gb`, `gpu_count`, `gpu_type`, `timeout_minutes`
- **`HarborEnv`** — loads Harbor-format agent benchmark tasks
- **`RLMEnv`** — implements [Recursive Language Models](https://alexzhang13.github.io/blog/2025/rlm/) for unbounded context processing via REPL-based decomposition and recursive sub-LLM calls
- **`OpenCodeEnv`** — runs [OpenCode](https://opencode.ai) CLI agents inside sandboxes with API call interception
178 changes: 178 additions & 0 deletions tests/test_interception_auth.py
@@ -0,0 +1,178 @@
"""Tests for per-rollout authentication on the interception server.

Verifies that:
- Requests with valid tokens are accepted
- Requests with invalid/missing tokens are rejected (401)
- Unregistered rollout IDs are rejected (404)
- Graceful fallback: rollouts registered without a token skip auth
"""

import asyncio
from typing import Any

import pytest
from aiohttp import ClientSession

from verifiers.utils.interception_utils import (
InterceptionServer,
generate_interception_token,
)


@pytest.fixture
async def server():
srv = InterceptionServer(port=0)
await srv.start()
yield srv
await srv.stop()


def _chat_payload(content: str = "hello") -> dict:
return {
"model": "test-model",
"messages": [{"role": "user", "content": content}],
}


async def _post(
base: str,
rollout_id: str,
token: str | None = None,
timeout: float = 0.5,
payload: Any | None = None,
):
"""POST to a rollout endpoint, return (status, body) or 'timeout'."""
headers = {}
if token is not None:
headers["Authorization"] = f"Bearer {token}"
try:
async with ClientSession() as session:
async with session.post(
f"{base}/rollout/{rollout_id}/v1/chat/completions",
json=_chat_payload() if payload is None else payload,
headers=headers,
timeout=__import__("aiohttp").ClientTimeout(total=timeout),
) as resp:
body = await resp.json()
return resp.status, body
except asyncio.TimeoutError:
return "accepted", None


@pytest.mark.asyncio
async def test_valid_token_accepted(server: InterceptionServer):
"""Request with the correct bearer token is accepted."""
token = generate_interception_token()
server.register_rollout("rollout_auth_ok", auth_token=token)
base = f"http://127.0.0.1:{server.port}"

result = await _post(base, "rollout_auth_ok", token=token)
# "accepted" means the server didn't reject — it's waiting for a model response
assert result[0] == "accepted" or result[0] == 200

server.unregister_rollout("rollout_auth_ok")


@pytest.mark.asyncio
async def test_missing_token_rejected(server: InterceptionServer):
"""Request with no Authorization header is rejected when auth is configured."""
token = generate_interception_token()
server.register_rollout("rollout_no_token", auth_token=token)
base = f"http://127.0.0.1:{server.port}"

status, body = await _post(base, "rollout_no_token", token=None)
assert status == 401
assert body["error"] == "Unauthorized"

server.unregister_rollout("rollout_no_token")


@pytest.mark.asyncio
async def test_wrong_token_rejected(server: InterceptionServer):
"""Request with an incorrect bearer token is rejected."""
token = generate_interception_token()
server.register_rollout("rollout_bad_token", auth_token=token)
base = f"http://127.0.0.1:{server.port}"

status, body = await _post(base, "rollout_bad_token", token="wrong-token")
assert status == 401
assert body["error"] == "Unauthorized"

server.unregister_rollout("rollout_bad_token")


@pytest.mark.asyncio
async def test_unknown_rollout_404(server: InterceptionServer):
"""Request to a non-existent rollout ID returns 404."""
base = f"http://127.0.0.1:{server.port}"

status, body = await _post(base, "rollout_nonexistent", token=None)
assert status == 404


@pytest.mark.asyncio
async def test_no_token_graceful_fallback(server: InterceptionServer):
"""Rollout registered without a token accepts any request (backwards compat)."""
server.register_rollout("rollout_no_auth")
base = f"http://127.0.0.1:{server.port}"

result = await _post(base, "rollout_no_auth", token=None)
# Should be accepted (not 401), waiting for model response
assert result[0] == "accepted" or result[0] == 200

server.unregister_rollout("rollout_no_auth")


@pytest.mark.asyncio
async def test_cross_rollout_blocked(server: InterceptionServer):
"""Token for rollout A cannot be used to access rollout B."""
token_a = generate_interception_token()
token_b = generate_interception_token()
server.register_rollout("rollout_a", auth_token=token_a)
server.register_rollout("rollout_b", auth_token=token_b)
base = f"http://127.0.0.1:{server.port}"

# Use A's token to access B's endpoint
status, body = await _post(base, "rollout_b", token=token_a)
assert status == 401, "Cross-rollout access should be rejected"

server.unregister_rollout("rollout_a")
server.unregister_rollout("rollout_b")


@pytest.mark.asyncio
async def test_missing_messages_rejected(server: InterceptionServer):
"""Authenticated requests without messages return a 400 instead of crashing."""
token = generate_interception_token()
server.register_rollout("rollout_missing_messages", auth_token=token)
base = f"http://127.0.0.1:{server.port}"

status, body = await _post(
base,
"rollout_missing_messages",
token=token,
payload={"model": "test-model"},
)
assert status == 400
assert body["error"] == "Request body must include 'messages'"

server.unregister_rollout("rollout_missing_messages")


@pytest.mark.asyncio
async def test_non_object_body_rejected(server: InterceptionServer):
"""Authenticated requests must send a JSON object."""
token = generate_interception_token()
server.register_rollout("rollout_non_object", auth_token=token)
base = f"http://127.0.0.1:{server.port}"

status, body = await _post(
base,
"rollout_non_object",
token=token,
payload=["not", "an", "object"],
)
assert status == 400
assert body["error"] == "Request body must be a JSON object"

server.unregister_rollout("rollout_non_object")
23 changes: 23 additions & 0 deletions tests/test_interception_utils.py
@@ -1,3 +1,5 @@
import pytest

from verifiers.types import (
    Response,
    ResponseMessage,
@@ -7,6 +9,7 @@
)
from verifiers.utils.interception_utils import (
    create_empty_completion,
    has_valid_bearer_auth,
    serialize_intercept_response,
)

@@ -61,3 +64,23 @@ def test_serialize_intercept_response_passthrough_native_chat_completion():
    assert payload["object"] == "chat.completion"
    assert payload["model"] == "native-model"
    assert len(payload["choices"]) == 1


def test_serialize_intercept_response_rejects_invalid_input_type():
    with pytest.raises(TypeError, match="Unsupported intercepted response type: str"):
        serialize_intercept_response("0")


@pytest.mark.parametrize(
    ("auth_header", "expected_token", "expected_valid"),
    [
        ("Bearer secret-token", "secret-token", True),
        ("Bearer wrong-token", "secret-token", False),
        ("", "secret-token", False),
        ("", None, True),
    ],
)
def test_has_valid_bearer_auth(
    auth_header: str, expected_token: str | None, expected_valid: bool
):
    assert has_valid_bearer_auth(auth_header, expected_token) is expected_valid
33 changes: 33 additions & 0 deletions tests/test_rlm_env.py
@@ -645,13 +645,16 @@ def _run_helper(
        argv: list[str],
        stdin_data: str = "",
        response_data: dict | None = None,
        env_overrides: dict[str, str] | None = None,
    ) -> tuple[str, str, int, dict | None]:
        helper_source = extract_bash_helper_source()
        stdout_buffer = io.StringIO()
        stderr_buffer = io.StringIO()
        env = {
            "RLM_ROOT_TOOL_URL": "http://example.invalid/",
        }
        if env_overrides:
            env.update(env_overrides)
        captured_payload: dict | None = None
        with patch("urllib.request.urlopen") as mock_urlopen:

@@ -664,6 +667,7 @@ def _capture_request(req, timeout=300):
                    "tool_name": data.get("tool_name"),
                    "args": args,
                    "kwargs": kwargs,
                    "headers": dict(req.header_items()),
                }
                return response

@@ -819,6 +823,35 @@ def test_llm_batch_output_json(self):
        parsed = json.loads(stdout.strip())
        assert parsed == ["first", "second"]

    def test_root_tool_auth_header_added_when_token_present(self):
        stdout, stderr, code, captured = self._run_helper(
            ["--tool", "other_tool", "--json", json.dumps({"args": [1]})],
            env_overrides={"RLM_AUTH_TOKEN": "secret-token"},
        )
        assert code == 0
        assert stderr == ""
        assert "ok" in stdout
        assert captured is not None
        assert captured["headers"]["Authorization"] == "Bearer secret-token"


class TestRLMAuthCheck:
    def test_check_rollout_auth_accepts_matching_bearer_token(self):
        env = build_env(make_dataset({}))
        request = MagicMock(headers={"Authorization": "Bearer secret-token"})

        assert env._check_rollout_auth(request, {"auth_token": "secret-token"}) is None

    def test_check_rollout_auth_rejects_wrong_bearer_token(self):
        env = build_env(make_dataset({}))
        request = MagicMock(headers={"Authorization": "Bearer wrong-token"})

        response = env._check_rollout_auth(request, {"auth_token": "secret-token"})

        assert response is not None
        assert response.status == 401
        assert json.loads(response.text)["error"] == "Unauthorized"
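What the bash helper presumably does when `RLM_AUTH_TOKEN` is set can be sketched like this (the function name and payload shape are hypothetical; only the header behavior is what the test above checks):

```python
import json
import urllib.request

def build_root_tool_request(url, payload, env):
    # Hypothetical helper: attach a bearer token only when RLM_AUTH_TOKEN is set
    headers = {"Content-Type": "application/json"}
    token = env.get("RLM_AUTH_TOKEN")
    if token:
        headers["Authorization"] = f"Bearer {token}"
    return urllib.request.Request(
        url, data=json.dumps(payload).encode(), headers=headers
    )

req = build_root_tool_request(
    "http://example.invalid/",
    {"tool_name": "other_tool", "args": [1]},
    {"RLM_AUTH_TOKEN": "secret-token"},
)
print(dict(req.header_items())["Authorization"])  # Bearer secret-token
```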


# =============================================================================
# 3. Initialization and Configuration
11 changes: 10 additions & 1 deletion verifiers/envs/experimental/cli_agent_env.py
@@ -30,6 +30,7 @@
from verifiers.utils.interception_utils import (
    InterceptionServer,
    deliver_response,
    generate_interception_token,
    synthesize_stream,
)
from verifiers.utils.logging_utils import print_time, truncate

@@ -188,6 +189,8 @@ async def setup_state(self, state: State) -> State:

        rollout_id = f"rollout_{uuid.uuid4().hex[:8]}"
        state["rollout_id"] = rollout_id
        auth_token = generate_interception_token()
        state["interception_auth_token"] = auth_token

        interception_server = self._require_interception_server()
        await interception_server.start()
@@ -227,7 +230,9 @@ async def setup_state(self, state: State) -> State:
        await self.create_sandbox(state, sandbox_request)

        # Register rollout for interception
        request_id_queue = interception_server.register_rollout(rollout_id)
        request_id_queue = interception_server.register_rollout(
            rollout_id, auth_token=auth_token
        )
        state["request_id_queue"] = request_id_queue
        state["agent_completed"] = False
@@ -261,6 +266,7 @@ def get_sandbox_resources(self, state: State) -> dict[str, Any]:
    PROTECTED_ENV_VARS = frozenset(
        {
            "OPENAI_BASE_URL",
            "OPENAI_API_KEY",
            "OPENAI_TIMEOUT",
            "OPENAI_REQUEST_TIMEOUT",
            "HTTPX_TIMEOUT",
@@ -272,6 +278,9 @@ async def build_env_vars(self, state: State) -> dict[str, str]:
        """Build environment variables for the sandbox. Override to add custom vars."""
        env_vars = dict(self.environment_vars) if self.environment_vars else {}
        env_vars["OPENAI_BASE_URL"] = state["interception_base_url"]
        auth_token = state.get("interception_auth_token")
        if auth_token:
            env_vars["OPENAI_API_KEY"] = auth_token
        env_vars.setdefault("OPENAI_TIMEOUT", "3600")
        env_vars.setdefault("OPENAI_REQUEST_TIMEOUT", "3600")
        env_vars.setdefault("HTTPX_TIMEOUT", "3600")