Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 84 additions & 7 deletions tests/test_env_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,27 @@
ZMQEnvServer,
)

_DEFAULT_CLIENT_AUTH = object()

def make_client(address: str = "tcp://127.0.0.1:5555", **kwargs) -> ZMQEnvClient:

def make_client(
address: str = "tcp://127.0.0.1:5555",
auth_token: str | None = None,
**kwargs,
) -> ZMQEnvClient:
"""Create a ZMQEnvClient with health checks disabled by default."""
kwargs.setdefault("health_check_interval", 0)
return ZMQEnvClient(address=address, **kwargs)
return ZMQEnvClient(address=address, auth_token=auth_token, **kwargs)


def make_mock_server(address: str) -> ZMQEnvServer:
def make_mock_server(address: str, auth_token: str | None = None) -> ZMQEnvServer:
"""Create a ZMQEnvServer with a mocked environment (no real env loading)."""
with patch("verifiers.serve.server.env_server.vf") as mock_vf:
mock_env = MagicMock()
mock_env._teardown = AsyncMock()
mock_vf.load_environment.return_value = mock_env
mock_vf.setup_logging = MagicMock()
return ZMQEnvServer(env_id="test", address=address)
return ZMQEnvServer(env_id="test", address=address, auth_token=auth_token)


def make_rollout_request() -> RunRolloutRequest:
Expand Down Expand Up @@ -77,7 +83,12 @@ def make_pending_request(


@contextlib.asynccontextmanager
async def run_server_and_client():
async def run_server_and_client(
*,
auth_token: str | None = None,
client_auth_token: object | str | None = _DEFAULT_CLIENT_AUTH,
client_health_check_interval: float = 0,
):
"""Start a mock ZMQ server and connected client, tearing both down on exit.

The router's worker spawning is mocked out so no subprocesses are created.
Expand All @@ -87,7 +98,7 @@ async def run_server_and_client():
port = get_free_port()
address = f"tcp://127.0.0.1:{port}"

server = make_mock_server(address)
server = make_mock_server(address, auth_token=auth_token)

# Mock out worker lifecycle — we don't want real subprocesses in unit tests
server.router.start_workers = MagicMock()
Expand All @@ -98,7 +109,14 @@ async def run_server_and_client():
server_loop = asyncio.create_task(server.serve(stop_event=stop_event))
await asyncio.sleep(0.1) # let server bind and start polling

client = make_client(address=address)
resolved_client_auth = (
auth_token if client_auth_token is _DEFAULT_CLIENT_AUTH else client_auth_token
)
client = make_client(
address=address,
auth_token=resolved_client_auth,
health_check_interval=client_health_check_interval,
)

try:
yield server, client
Expand Down Expand Up @@ -406,3 +424,62 @@ async def test_dispatch_called_with_correct_frames(self):
client_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await client_task


class TestZMQAuth:
    """Tests for optional transport-level auth on the ZMQ env server.

    Every test starts a mock server and a client through
    ``run_server_and_client`` and then inspects
    ``server.router.dispatch_request`` to see whether the request got past
    the server's auth check.
    """

    @pytest.mark.asyncio
    async def test_valid_token_dispatches_request(self):
        """Matching client/server tokens allow requests through to the router."""
        async with run_server_and_client(auth_token="shared-token") as (server, client):
            client_task = asyncio.create_task(
                client.send_request(
                    make_rollout_request(), RunRolloutResponse, timeout=30
                )
            )

            try:
                # Give the server loop time to receive and route the request.
                await asyncio.sleep(0.3)

                assert server.router.dispatch_request.call_count == 1
            finally:
                # Always reap the in-flight request, even when the assertion
                # above fails, so no pending task outlives the test.
                client_task.cancel()
                with contextlib.suppress(asyncio.CancelledError):
                    await client_task

    @pytest.mark.asyncio
    async def test_missing_token_rejected_before_dispatch(self):
        """An auth-enabled server rejects unauthenticated requests."""
        async with run_server_and_client(
            auth_token="shared-token",
            client_auth_token=None,
        ) as (server, client):
            with pytest.raises(RuntimeError, match="Unauthorized"):
                await client.send_request(
                    make_rollout_request(), RunRolloutResponse, timeout=1
                )

            assert server.router.dispatch_request.call_count == 0

    @pytest.mark.asyncio
    async def test_wrong_token_rejected_before_dispatch(self):
        """Wrong tokens are rejected without reaching the router."""
        async with run_server_and_client(
            auth_token="shared-token",
            client_auth_token="wrong-token",
        ) as (server, client):
            with pytest.raises(RuntimeError, match="Unauthorized"):
                await client.send_request(
                    make_rollout_request(), RunRolloutResponse, timeout=1
                )

            assert server.router.dispatch_request.call_count == 0

    @pytest.mark.asyncio
    async def test_health_checks_include_auth_token(self):
        """Health probes succeed when the client carries the shared token."""
        async with run_server_and_client(
            auth_token="shared-token",
            client_health_check_interval=0.05,
        ) as (_server, client):
            await client.wait_for_server_startup(timeout=1.0)
39 changes: 38 additions & 1 deletion tests/test_environment_extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import asyncio
import json
from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from datasets import Dataset
Expand Down Expand Up @@ -670,3 +670,40 @@ async def test_generate_resume_raises_on_metadata_mismatch(
model="test-model",
results_path=results_path,
)


@pytest.mark.asyncio
async def test_start_server_wires_shared_auth_token(make_dummy_env, mock_client):
    """``start_server`` passes one generated token to both server and client.

    Process creation, port selection and ``ZMQEnvClient`` are patched, so no
    subprocess or socket is created. The test checks that the ``auth_token``
    given to the server process kwargs is a non-empty string and that the
    same value, together with the same address, is given to the client.
    """
    env = make_dummy_env(mock_client)

    client_instance = AsyncMock()
    client_instance.wait_for_server_startup = AsyncMock()
    process = MagicMock()
    ctx = MagicMock()
    ctx.Process.return_value = process

    try:
        with (
            patch("verifiers.envs.environment.get_free_port", return_value=4321),
            patch("verifiers.envs.environment.mp.get_context", return_value=ctx),
            patch(
                "verifiers.envs.environment.ZMQEnvClient", return_value=client_instance
            ) as mock_env_client,
        ):
            await env.start_server()

        process_kwargs = ctx.Process.call_args.kwargs["kwargs"]
        auth_token = process_kwargs["auth_token"]

        # Server side: a non-empty string token and the patched port's address.
        assert isinstance(auth_token, str)
        assert auth_token
        assert process_kwargs["address"] == "tcp://127.0.0.1:4321"
        assert ctx.Process.call_args.kwargs["target"].__qualname__.endswith("run_server")
        assert client_instance.wait_for_server_startup.await_count == 1
        assert env.env_server_process is process
        assert env.env_client is client_instance
        # Client side: exactly the same token and address as the server.
        assert mock_env_client.call_args.kwargs["auth_token"] == auth_token
        assert mock_env_client.call_args.kwargs["address"] == "tcp://127.0.0.1:4321"
    finally:
        # Close the death-pipe writer even when an assertion fails, so the
        # pipe is not leaked into later tests.
        writer = getattr(env, "death_pipe_writer", None)
        if writer is not None:
            writer.close()
            env.death_pipe_writer = None
4 changes: 4 additions & 0 deletions verifiers/envs/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import logging
import multiprocessing as mp
import secrets
import signal
import time
import uuid
Expand Down Expand Up @@ -1288,6 +1289,7 @@ async def start_server(
from verifiers.serve import ZMQEnvServer

address = address or f"tcp://127.0.0.1:{get_free_port()}"
auth_token = secrets.token_hex(32)
extra_env_kwargs = extra_env_kwargs or {}

# Death pipe: parent keeps writer, children monitor reader.
Expand All @@ -1311,6 +1313,7 @@ async def start_server(
),
kwargs=dict(
address=address,
auth_token=auth_token,
num_workers=num_workers,
death_pipe=death_pipe_reader,
),
Expand All @@ -1321,6 +1324,7 @@ async def start_server(
death_pipe_reader.close()
self.env_client = ZMQEnvClient(
address=address,
auth_token=auth_token,
health_check_interval=health_check_interval,
startup_timeout=startup_timeout,
recovery_timeout=recovery_timeout,
Expand Down
26 changes: 21 additions & 5 deletions verifiers/serve/client/zmq_env_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,14 @@ class ZMQEnvClient(EnvClient):

DEFAULT_REQUEST_TIMEOUT = 36_000 # 10h

def __init__(self, address: str = "tcp://127.0.0.1:5000", **kwargs):
def __init__(
self,
address: str = "tcp://127.0.0.1:5000",
auth_token: str | None = None,
**kwargs,
):
super().__init__(address=address, **kwargs)
self.auth_token = auth_token.encode() if auth_token is not None else None

# ZMQ context
self.ctx = zmq.asyncio.Context()
Expand Down Expand Up @@ -142,10 +148,17 @@ async def close(self) -> None:
self.socket.close()
self.ctx.term()

def _build_request_frames(self, request_id: bytes, payload: bytes) -> list[bytes]:
    """Assemble the multipart frames for one outgoing message.

    Without a configured auth token the message is ``[request_id, payload]``.
    With one, the token is sent as its own frame between the request id and
    the payload.
    """
    frames = [request_id]
    if self.auth_token is not None:
        frames.append(self.auth_token)
    frames.append(payload)
    return frames

async def send_cancel(self, request_id: str) -> None:
"""Send a cancel signal (empty payload) to the server for a request."""
try:
await self.socket.send_multipart([request_id.encode(), b""])
await self.socket.send_multipart(
self._build_request_frames(request_id.encode(), b"")
)
except BaseException:
pass

Expand Down Expand Up @@ -293,7 +306,9 @@ async def send_request(
async with self.pending_lock:
self.pending_requests[request_id] = pending_req

await self.socket.send_multipart([request_id.encode(), payload_bytes])
await self.socket.send_multipart(
self._build_request_frames(request_id.encode(), payload_bytes)
)

try:
raw_response = await asyncio.wait_for(future, timeout=effective_timeout)
Expand Down Expand Up @@ -359,7 +374,8 @@ def run_health_check_thread(self):
forwarded to the event loop via ``call_soon_threadsafe``.

Uses a DEALER socket on the main address (same port as requests).
Sends ``b"ping"`` as the payload; the server responds inline.
Sends ``b"ping"`` as the payload; the server responds inline. When
configured, the auth token is sent as a dedicated frame.
"""
ctx = zmq.Context()
sock = ctx.socket(zmq.DEALER)
Expand All @@ -379,7 +395,7 @@ def run_health_check_thread(self):
while not self.stop_health_thread.is_set():
is_healthy = False
try:
sock.send_multipart([b"health", b"ping"])
sock.send_multipart(self._build_request_frames(b"health", b"ping"))
frames = sock.recv_multipart()
if len(frames) == 2:
resp = msgpack.unpackb(frames[1], raw=False)
Expand Down
53 changes: 48 additions & 5 deletions verifiers/serve/server/zmq_env_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@

Health checks are handled inline on the ROUTER socket — clients send a
``b"ping"`` payload and receive a pre-serialized health response back on the
same connection. No separate port is needed.
same connection. No separate port is needed. When ``auth_token`` is set,
clients must include it as a dedicated frame ahead of the payload.
"""

import asyncio
import secrets

import msgpack
import zmq
Expand All @@ -23,13 +25,26 @@
# Sentinel payload used by health-check probes.
_HEALTH_PING = b"ping"

# Pre-serialized auth failure response for regular requests and pings.
# It is packed once at import time, so rejecting a message needs no
# per-message serialization. The serve loop sends it back as the payload
# frame when the auth check fails on a message with a non-empty payload;
# that check runs before the health-ping check.
_UNAUTHORIZED_RESPONSE = msgpack.packb(
{"success": False, "error": "Unauthorized"},
use_bin_type=True,
)


class ZMQEnvServer(EnvServer):
"""ZMQ ROUTER frontend + EnvRouter worker pool."""

def __init__(self, *args, address: str = "tcp://127.0.0.1:5000", **kwargs):
def __init__(
self,
*args,
address: str = "tcp://127.0.0.1:5000",
auth_token: str | None = None,
**kwargs,
):
super().__init__(*args, **kwargs)
self.address = address
self.auth_token = auth_token.encode() if auth_token is not None else None

# Client-facing ROUTER socket (also serves health checks)
self.ctx = zmq.asyncio.Context()
Expand All @@ -49,6 +64,24 @@ async def send_response(
except zmq.ZMQError as e:
self.logger.warning(f"Failed to forward response: {e}")

def _parse_frames(
self, frames: list[bytes]
) -> tuple[bytes, bytes, bytes | None, bytes] | None:
if len(frames) == 3:
client_id, request_id, payload = frames
return client_id, request_id, None, payload
if len(frames) == 4:
client_id, request_id, auth_token, payload = frames
return client_id, request_id, auth_token, payload
return None

def _is_authorized(self, auth_token: bytes | None) -> bool:
if self.auth_token is None:
return True
if auth_token is None:
return False
return secrets.compare_digest(auth_token, self.auth_token)

async def serve(self, stop_event: asyncio.Event | None = None) -> None:
self.logger.info(f"ZMQEnvServer started on {self.address}")

Expand Down Expand Up @@ -79,12 +112,22 @@ async def serve(self, stop_event: asyncio.Event | None = None) -> None:

if self.frontend in events:
frames = await self.frontend.recv_multipart()
if len(frames) != 3:
parsed = self._parse_frames(frames)
if parsed is None:
self.logger.warning(
f"Invalid message: expected 3 frames, got {len(frames)}"
f"Invalid message: expected 3 or 4 frames, got {len(frames)}"
)
else:
client_id, request_id, payload = frames
client_id, request_id, auth_token, payload = parsed
if not self._is_authorized(auth_token):
if payload != b"":
try:
await self.frontend.send_multipart(
[client_id, request_id, _UNAUTHORIZED_RESPONSE]
)
except zmq.ZMQError:
pass
continue
if payload == _HEALTH_PING:
# Health check — respond immediately
try:
Expand Down
Loading