Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions tests/test_interception_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
)
from verifiers.utils import interception_utils
from verifiers.utils.interception_utils import (
InterceptionError,
InterceptionServer,
StreamInterrupted,
create_empty_completion,
Expand Down Expand Up @@ -131,3 +132,96 @@ async def fake_write(data: bytes) -> None:

assert isinstance(state["error"], StreamInterrupted)
assert "ConnectionResetError" in str(state["error"])


async def test_streaming_response_future_failure_surfaces_to_state(monkeypatch):
    """If the model call underlying the stream fails (e.g. vLLM raised and
    ``synthesize_stream(error=X)`` was called), the ``response_future`` await
    at the end of ``_handle_streaming_response`` raises. Previously that was
    only logged at debug, letting the agent see a clean ``data: [DONE]`` and
    exit 0 with an empty trajectory. Now it must funnel into ``state['error']``
    as ``StreamInterrupted`` so the rollout halts visibly."""
    srv = InterceptionServer(port=0)
    rollout_state: dict = {}
    srv.register_rollout("r1", state=rollout_state)

    # Capture every SSE chunk the handler writes to the (mocked) response.
    captured: list[bytes] = []

    async def record_write(data: bytes) -> None:
        captured.append(data)

    stream_resp = MagicMock()
    stream_resp.prepare = AsyncMock()
    stream_resp.write = AsyncMock(side_effect=record_write)
    stream_resp.write_eof = AsyncMock()
    monkeypatch.setattr(
        interception_utils.web, "StreamResponse", lambda **_: stream_resp
    )

    # Empty stream: a single None sentinel terminates chunk consumption.
    queue: asyncio.Queue = asyncio.Queue()
    await queue.put(None)

    # The underlying model call failed before producing a response.
    fut: asyncio.Future = asyncio.Future()
    fut.set_exception(RuntimeError("vLLM raised"))

    entry = {"chunk_queue": queue, "response_future": fut}

    await srv._handle_streaming_response(MagicMock(), "r1", entry)

    err = rollout_state["error"]
    assert isinstance(err, StreamInterrupted), (
        f"expected StreamInterrupted, got {type(rollout_state.get('error'))}"
    )
    err_text = str(err)
    assert "RuntimeError" in err_text
    assert "vLLM raised" in err_text
    # The agent-facing stream still terminates cleanly with a DONE marker.
    assert any(w == b"data: [DONE]\n\n" for w in captured), captured
    stream_resp.write_eof.assert_awaited()


async def test_non_streaming_response_future_failure_surfaces_to_state(monkeypatch):
    """Non-streaming counterpart: if the model call fails and
    ``deliver_response`` sets the future's exception, the non-streaming
    branch of ``_handle_request`` re-raises when awaiting it. That failure
    must funnel into ``state['error']`` as ``InterceptionError`` so the
    rollout halts visibly (HTTP 500 still returned to the client)."""
    srv = InterceptionServer(port=0)
    rollout_state: dict = {}
    srv.register_rollout("r1", state=rollout_state)

    req = MagicMock()
    req.match_info = {"rollout_id": "r1"}
    req.json = AsyncMock(
        return_value={"stream": False, "messages": [], "model": "test"}
    )
    req.headers = {}

    # Stub out aiohttp's json_response so we can inspect status/body directly.
    def stub_json_response(data, status=200):
        return MagicMock(_body=data, status=status)

    monkeypatch.setattr(interception_utils.web, "json_response", stub_json_response)

    task = asyncio.create_task(srv._handle_request(req))

    # Poll until the handler has registered its intercept (bounded wait).
    attempts = 0
    while not srv.intercepts and attempts < 50:
        await asyncio.sleep(0.01)
        attempts += 1
    assert srv.intercepts, "handler did not register intercept"

    entry = next(iter(srv.intercepts.values()))
    interception_utils.deliver_response(
        entry, None, error=RuntimeError("vLLM raised")
    )

    resp = await task

    assert resp.status == 500
    err = rollout_state["error"]
    assert isinstance(err, InterceptionError), (
        f"expected InterceptionError, got {type(rollout_state.get('error'))}"
    )
    err_text = str(err)
    assert "intercepted request failed" in err_text
    assert "RuntimeError" in err_text
    assert "vLLM raised" in err_text
29 changes: 27 additions & 2 deletions verifiers/utils/interception_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,16 @@ class StreamInterrupted(InfraError):
"""


class InterceptionError(InfraError):
    """Non-streaming intercepted request could not be fulfilled.

    Kept separate from ``StreamInterrupted`` so rubrics and metrics can
    distinguish the two failure shapes: a streaming cut leaves the agent
    holding a truncated SSE body, whereas a non-streaming failure is
    surfaced to the agent's OpenAI client as an HTTP 500 — i.e. a normal
    API error.
    """


class InterceptionServer:
"""
HTTP server that intercepts API requests from agents.
Expand Down Expand Up @@ -201,7 +211,14 @@ async def _handle_request(self, request: Any) -> Any:
return web.json_response({"error": "Rollout cancelled"}, status=499)
except Exception as e:
logger.debug(
f"[{rollout_id}] Rollout error surfaced in non-streaming request: {type(e).__name__}: {e}"
f"[{rollout_id}] Rollout error surfaced in non-streaming "
f"request: {type(e).__name__}: {e}"
)
self._set_rollout_error(
rollout_id,
InterceptionError(
f"intercepted request failed: {type(e).__name__}: {e}"
),
)
return web.json_response({"error": str(e)}, status=500)

Expand Down Expand Up @@ -251,10 +268,18 @@ async def _handle_streaming_response(

try:
await response_future
except BaseException as e:
except asyncio.CancelledError:
raise
except Exception as e:
logger.debug(
f"[{rollout_id}] Rollout error surfaced in stream: {type(e).__name__}: {e}"
)
self._set_rollout_error(
rollout_id,
StreamInterrupted(
f"streaming response_future failed: {type(e).__name__}: {e}"
),
)

try:
await response.write_eof()
Expand Down
Loading