Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions docs/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,11 @@ Token-level data for training.
```python
class RolloutTiming(TypedDict, total=False):
start_time: float
generation_ms: float
scoring_ms: float
total_ms: float
setup_s: float
generation_s: float
scoring_s: float
overhead_s: float
total_s: float
```

### TokenUsage
Expand Down
31 changes: 15 additions & 16 deletions packages/verifiers-rl/verifiers_rl/rl/trainer/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,24 +301,23 @@ async def generate_batch(self, batch_id: int) -> Batch:
masked_fraction = 1.0 - (valid_tokens / total_tokens)
metrics_dict["tokens/masked_fraction"] = float(masked_fraction)

generation_ms: list[float] = []
scoring_ms: list[float] = []
total_ms: list[float] = []
timing_fields = (
"setup_s",
"generation_s",
"scoring_s",
"overhead_s",
"total_s",
)
timing_accum: dict[str, list[float]] = {k: [] for k in timing_fields}
for output in outputs:
timing = output.get("timing", {})
if "generation_ms" in timing:
generation_ms.append(float(timing["generation_ms"]))
if "scoring_ms" in timing:
scoring_ms.append(float(timing["scoring_ms"]))
if "total_ms" in timing:
total_ms.append(float(timing["total_ms"]))

if generation_ms:
metrics_dict["timing/generation_ms"] = float(np.mean(generation_ms))
if scoring_ms:
metrics_dict["timing/scoring_ms"] = float(np.mean(scoring_ms))
if total_ms:
metrics_dict["timing/total_ms"] = float(np.mean(total_ms))
for key in timing_fields:
if key in timing:
timing_accum[key].append(float(timing[key]))

for key in timing_fields:
if timing_accum[key]:
metrics_dict[f"timing/{key}"] = float(np.mean(timing_accum[key]))

metrics_dict["wall_clock/generate_s"] = float(wall_clock_s)
errors = [output.get("error") for output in outputs]
Expand Down
6 changes: 3 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,9 +489,9 @@ def _make_state(
tool_defs: list[Tool] | None = None,
trajectory: list[TrajectoryStep] = [],
timing=RolloutTiming(
generation_ms=0.0,
scoring_ms=0.0,
total_ms=0.0,
generation_s=0.0,
scoring_s=0.0,
total_s=0.0,
),
foo: str = "bar", # custom field
**kwargs,
Expand Down
8 changes: 4 additions & 4 deletions tests/test_composable_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ async def test_composable_env_quotes_log_path_when_collecting_logs():
teardown=lambda: None,
)

state = {"sandbox_id": "sbx", "timing": {"total_ms": 0}}
state = {"sandbox_id": "sbx", "timing": {"total_s": 0}}

await env.post_rollout(state)

Expand Down Expand Up @@ -594,7 +594,7 @@ async def test_composable_env_collects_harness_metrics():
state = {
"sandbox_id": "sbx",
"info": {"id": 0},
"timing": {"total_ms": 0},
"timing": {"total_s": 0},
"trajectory": [],
}

Expand Down Expand Up @@ -633,7 +633,7 @@ async def test_composable_env_metrics_with_key_whitelist():
state = {
"sandbox_id": "sbx",
"info": {"id": 0},
"timing": {"total_ms": 0},
"timing": {"total_s": 0},
"trajectory": [],
}

Expand All @@ -659,7 +659,7 @@ async def test_composable_env_no_metrics_when_path_not_set():
state = {
"sandbox_id": "sbx",
"info": {"id": 0},
"timing": {"total_ms": 0},
"timing": {"total_s": 0},
"trajectory": [],
}

Expand Down
12 changes: 6 additions & 6 deletions tests/test_env_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ def func2(completion, **kwargs):
state["completion"] = "Test completion"
state["trajectory"] = []
state["timing"] = {
"generation_ms": 0.0,
"scoring_ms": 0.0,
"total_ms": 0.0,
"generation_s": 0.0,
"scoring_s": 0.0,
"total_s": 0.0,
"start_time": 0.0,
}
state["is_completed"] = False
Expand Down Expand Up @@ -119,9 +119,9 @@ async def test_env_group_rubric_unknown_task(self, mock_client, make_input):
state["completion"] = "Test"
state["trajectory"] = []
state["timing"] = {
"generation_ms": 0.0,
"scoring_ms": 0.0,
"total_ms": 0.0,
"generation_s": 0.0,
"scoring_s": 0.0,
"total_s": 0.0,
"start_time": 0.0,
}
state["is_completed"] = False
Expand Down
3 changes: 2 additions & 1 deletion tests/test_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ async def rollout(
from verifiers.utils.response_utils import parse_response_message

completion_messages = await parse_response_message(response)
from verifiers.types import TrajectoryStep
from verifiers.types import StepTiming, TrajectoryStep
from verifiers.utils.response_utils import parse_response_tokens

tokens = await parse_response_tokens(response)
Expand All @@ -60,6 +60,7 @@ async def rollout(
is_truncated=False,
trajectory_id=state["trajectory_id"],
extras={},
timing=StepTiming(model_s=0.0, env_s=0.0, turn_s=0.0),
)
state["trajectory"].append(trajectory_step)
state["is_completed"] = True
Expand Down
3 changes: 2 additions & 1 deletion tests/test_environment_extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ async def rollout(
response = await self.get_model_response(state=state, prompt=prompt_messages)
assert response is not None

from verifiers.types import TrajectoryStep
from verifiers.types import StepTiming, TrajectoryStep
from verifiers.utils.response_utils import (
parse_response_message,
parse_response_tokens,
Expand All @@ -76,6 +76,7 @@ async def rollout(
is_truncated=False,
trajectory_id=state["trajectory_id"],
extras={},
timing=StepTiming(model_s=0.0, env_s=0.0, turn_s=0.0),
)
state["trajectory"].append(trajectory_step)
state["is_completed"] = True
Expand Down
18 changes: 9 additions & 9 deletions tests/test_math_rubric.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ async def test_score_valid_answers(self, test_case, make_input):
state["completion"] = test_case["completion"]
state["trajectory"] = []
state["timing"] = {
"generation_ms": 0.0,
"scoring_ms": 0.0,
"total_ms": 0.0,
"generation_s": 0.0,
"scoring_s": 0.0,
"total_s": 0.0,
"start_time": 0.0,
}

Expand Down Expand Up @@ -81,9 +81,9 @@ async def test_score_invalid_answers(self, test_case, make_input):
state["completion"] = test_case["completion"]
state["trajectory"] = []
state["timing"] = {
"generation_ms": 0.0,
"scoring_ms": 0.0,
"total_ms": 0.0,
"generation_s": 0.0,
"scoring_s": 0.0,
"total_s": 0.0,
"start_time": 0.0,
}

Expand Down Expand Up @@ -114,9 +114,9 @@ async def test_timeout(self, timeout_seconds, make_input):
state["completion"] = completion
state["trajectory"] = []
state["timing"] = {
"generation_ms": 0.0,
"scoring_ms": 0.0,
"total_ms": 0.0,
"generation_s": 0.0,
"scoring_s": 0.0,
"total_s": 0.0,
"start_time": 0.0,
}

Expand Down
107 changes: 107 additions & 0 deletions tests/test_per_turn_timing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""Tests for per-turn StepTiming on TrajectoryStep."""

import pytest
from datasets import Dataset

from verifiers import Messages, MultiTurnEnv, Parser, Rubric, SingleTurnEnv, State


class TestSingleTurnStepTiming:
    """Per-step timing checks for single-turn rollouts."""

    @pytest.mark.asyncio
    async def test_single_turn_has_step_timing(self, mock_client, make_input):
        """SingleTurnEnv rollout produces a step with timing."""
        ds = Dataset.from_dict({"question": ["q1"], "answer": ["a1"]})
        env = SingleTurnEnv(
            client=mock_client,
            model="test-model",
            dataset=ds,
            rubric=Rubric(),
        )
        mock_client.set_default_response("hello")

        rollout_state = await env.rollout(
            input=make_input(prompt=[{"role": "user", "content": "q1"}], answer="a1"),
            client=mock_client,
            model="test-model",
        )

        # A single-turn rollout yields exactly one trajectory step.
        trajectory = rollout_state["trajectory"]
        assert len(trajectory) == 1
        only_step = trajectory[0]
        assert "timing" in only_step
        timing = only_step["timing"]
        # Model time was measured; with no env step following, env_s stays
        # zero and the whole turn equals the model call alone.
        assert timing["model_s"] > 0
        assert timing["env_s"] == 0.0
        assert timing["turn_s"] == timing["model_s"]

    @pytest.mark.asyncio
    async def test_timing_values_are_seconds(self, mock_client, make_input):
        """Assert timing values are small floats (seconds, not ms)."""
        ds = Dataset.from_dict({"question": ["q1"], "answer": ["a1"]})
        env = SingleTurnEnv(
            client=mock_client,
            model="test-model",
            dataset=ds,
            rubric=Rubric(),
        )
        mock_client.set_default_response("hello")

        rollout_state = await env.rollout(
            input=make_input(prompt=[{"role": "user", "content": "q1"}], answer="a1"),
            client=mock_client,
            model="test-model",
        )

        # Millisecond-scale values would be in the hundreds here; second-scale
        # values for a mocked call sit comfortably under 10.
        timing = rollout_state["trajectory"][0]["timing"]
        assert timing["model_s"] < 10
        assert timing["turn_s"] < 10


class TestMultiTurnStepTiming:
    """Per-step timing checks for multi-turn rollouts (env_s backfill)."""

    @pytest.mark.asyncio
    async def test_multi_turn_backfills_env_timing(self, mock_client, make_input):
        """In a 2-turn env, step 0 gets env_s backfilled > 0, last step has env_s == 0."""

        class FollowUpEnv(MultiTurnEnv):
            # Minimal env: always replies with one follow-up, capped at 2 turns.
            def __init__(self, **kwargs):
                super().__init__(max_turns=2, **kwargs)

            async def env_response(self, messages: Messages, state: State, **kwargs):
                return [{"role": "user", "content": "follow-up"}]

        ds = Dataset.from_dict({"question": ["q1"], "answer": ["a1"]})
        env = FollowUpEnv(
            client=mock_client,
            model="test-model",
            dataset=ds,
            parser=Parser(),
            rubric=Rubric(),
        )
        mock_client.set_default_response("response")

        rollout_state = await env.rollout(
            input=make_input(prompt=[{"role": "user", "content": "q1"}], answer="a1"),
            client=mock_client,
            model="test-model",
        )

        trajectory = rollout_state["trajectory"]
        assert len(trajectory) == 2
        first, last = trajectory

        # The get_prompt_messages call that produced step 1's prompt is charged
        # back to step 0, so its env_s is non-zero and its turn spans model + env.
        assert "timing" in first
        assert first["timing"]["env_s"] > 0
        assert first["timing"]["turn_s"] >= first["timing"]["model_s"]

        # Nothing follows the final step, so its env_s stays zero and its turn
        # collapses to the model time alone.
        assert "timing" in last
        assert last["timing"]["env_s"] == 0.0
        assert last["timing"]["turn_s"] == last["timing"]["model_s"]

        # Everything is measured in seconds: tiny floats for a mocked client.
        for entry in trajectory:
            assert entry["timing"]["model_s"] < 10
            assert entry["timing"]["env_s"] < 10
            assert entry["timing"]["turn_s"] < 10
2 changes: 1 addition & 1 deletion tests/test_rlm_composable_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ async def test_rlm_collects_logs_and_metrics(tmp_path):
state = {
"sandbox_id": "sbx",
"info": {"id": 0},
"timing": {"total_ms": 0},
"timing": {"total_s": 0},
"trajectory": [],
}

Expand Down
Loading
Loading