Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions tests/test_interception_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,76 @@ async def fake_write(data: bytes) -> None:

assert isinstance(state["error"], StreamInterrupted)
assert "ConnectionResetError" in str(state["error"])


async def test_keepalive_emitted_during_idle(monkeypatch):
    """While the chunk queue stays empty, the streaming handler should
    periodically write SSE keepalive comments so idle timeouts in upstream
    intermediaries never fire."""
    # Shrink the keepalive period so the test can observe several cycles fast.
    monkeypatch.setattr(interception_utils, "KEEPALIVE_INTERVAL_SECONDS", 0.05)

    server = InterceptionServer(port=0)
    state: dict = {}
    server.register_rollout("r1", state=state)

    captured: list[bytes] = []

    async def record_write(payload: bytes) -> None:
        captured.append(payload)

    stub_response = MagicMock()
    stub_response.prepare = AsyncMock()
    stub_response.write = AsyncMock(side_effect=record_write)
    stub_response.write_eof = AsyncMock()
    monkeypatch.setattr(
        interception_utils.web, "StreamResponse", lambda **_: stub_response
    )

    queue: asyncio.Queue = asyncio.Queue()  # no chunks -> pure idle window
    future: asyncio.Future = asyncio.Future()
    intercept = {
        "chunk_queue": queue,
        "response_future": future,
    }

    handler_task = asyncio.create_task(
        server._handle_streaming_response(MagicMock(), "r1", intercept)
    )
    await asyncio.sleep(0.2)  # several keepalive intervals elapse here

    # Unblock the handler cleanly: resolve the future, then send the EOF sentinel.
    future.set_result(None)
    await queue.put(None)
    await handler_task

    keepalives = [w for w in captured if w == b": keepalive\n\n"]
    assert keepalives, f"expected at least one keepalive write, got writes={captured}"


async def test_keepalive_write_failure_surfaces_to_state(monkeypatch):
    """When a keepalive write blows up (the upstream TCP connection is
    already gone), the handler must record a ``StreamInterrupted`` in
    ``state["error"]`` carrying elapsed-time instrumentation."""
    monkeypatch.setattr(interception_utils, "KEEPALIVE_INTERVAL_SECONDS", 0.05)

    server = InterceptionServer(port=0)
    state: dict = {}
    server.register_rollout("r1", state=state)

    broken_response = MagicMock()
    broken_response.prepare = AsyncMock()
    broken_response.write = AsyncMock(
        side_effect=ConnectionResetError("tunnel died")
    )
    broken_response.write_eof = AsyncMock()
    monkeypatch.setattr(
        interception_utils.web, "StreamResponse", lambda **_: broken_response
    )

    # A queue that never yields forces the handler onto the keepalive path,
    # where the first write attempt raises.
    intercept = {
        "chunk_queue": asyncio.Queue(),
        "response_future": asyncio.Future(),
    }

    await server._handle_streaming_response(MagicMock(), "r1", intercept)

    err = state["error"]
    assert isinstance(err, StreamInterrupted)
    text = str(err)
    assert "keepalive write failed" in text
    assert "ConnectionResetError" in text
50 changes: 47 additions & 3 deletions verifiers/utils/interception_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
logger = logging.getLogger(__name__)


KEEPALIVE_INTERVAL_SECONDS = 10.0


class StreamInterrupted(InfraError):
"""Raised when the intercepted streaming response to the agent is cut short.

Expand Down Expand Up @@ -226,9 +229,45 @@ async def _handle_streaming_response(
)
await response.prepare(http_request)

start = time.monotonic()
# Reuse a single get() task across keepalive cycles instead of
# recreating it each iteration. ``asyncio.wait_for`` on Python
# 3.10/3.11 has a race where a timeout cancels an inner task that
# may have already dequeued an item, silently dropping it.
# ``asyncio.wait`` does not cancel its tasks on timeout, so a
# pending ``get()`` task carries forward safely.
get_task: asyncio.Task | None = None
try:
while True:
chunk_dict = await chunk_queue.get()
if get_task is None:
get_task = asyncio.ensure_future(chunk_queue.get())
done, _ = await asyncio.wait(
{get_task}, timeout=KEEPALIVE_INTERVAL_SECONDS
)
if get_task not in done:
# Idle window — emit SSE keepalive comment to keep
# intermediaries (tunnel, LB, kube-proxy) from closing
# the connection during the long vLLM wait.
try:
await response.write(b": keepalive\n\n")
except Exception as e:
waited_s = time.monotonic() - start
logger.debug(
f"[{rollout_id}] Streaming error during keepalive "
f"after {waited_s:.1f}s: {e}"
)
self._set_rollout_error(
rollout_id,
StreamInterrupted(
f"keepalive write failed after {waited_s:.1f}s: "
f"{type(e).__name__}: {e}"
),
)
return response
continue

chunk_dict = get_task.result()
get_task = None

if chunk_dict is None:
await response.write(b"data: [DONE]\n\n")
Expand All @@ -240,14 +279,19 @@ async def _handle_streaming_response(
except asyncio.CancelledError:
logger.debug(f"[{rollout_id}] Streaming cancelled")
except Exception as e:
logger.error(f"[{rollout_id}] Streaming error: {e}")
waited_s = time.monotonic() - start
logger.debug(f"[{rollout_id}] Streaming error after {waited_s:.1f}s: {e}")
Comment thread
rasdani marked this conversation as resolved.
Outdated
self._set_rollout_error(
rollout_id,
StreamInterrupted(
f"Interception stream to agent interrupted: {type(e).__name__}: {e}"
f"stream write failed after {waited_s:.1f}s: "
f"{type(e).__name__}: {e}"
),
)
return response
finally:
if get_task is not None and not get_task.done():
get_task.cancel()

try:
await response_future
Expand Down
Loading