Skip to content

Commit 72c5d4e

Browse files
authored
Extend discrepancy check unit test for latency tuple
1 parent b21b4f6 commit 72c5d4e

1 file changed

Lines changed: 29 additions & 0 deletions

File tree

test/passes/onnx/test_discrepancy_check.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,3 +192,32 @@ def test_measure_speedup_skips_when_timing_iterations_is_zero(self):
192192
assert result is None
193193
ref_model.assert_not_called()
194194
session.run.assert_not_called()
195+
196+
def test_measure_speedup_returns_latencies_and_speedup(self):
197+
import torch
198+
199+
from olive.passes.onnx.discrepancy_check import OnnxDiscrepancyCheck
200+
201+
pass_instance = OnnxDiscrepancyCheck.__new__(OnnxDiscrepancyCheck)
202+
ref_model = MagicMock()
203+
session = MagicMock()
204+
input_data = {"input_ids": torch.tensor([[1, 2, 3]], dtype=torch.int64)}
205+
dataloader = [(input_data, None)]
206+
207+
with (
208+
patch("olive.common.utils.format_data", return_value={"input_ids": [1, 2, 3]}),
209+
patch("olive.passes.onnx.discrepancy_check.time.perf_counter", side_effect=[10.0, 14.0, 20.0, 22.0]),
210+
):
211+
result = pass_instance._measure_speedup(
212+
ref_model=ref_model,
213+
session=session,
214+
dataloader=dataloader,
215+
io_config=MagicMock(),
216+
torch_device=torch.device("cpu"),
217+
warmup_iterations=1,
218+
timing_iterations=2,
219+
)
220+
221+
assert result == (2.0, 1.0, 2.0)
222+
assert ref_model.call_count == 3
223+
assert session.run.call_count == 3

0 commit comments

Comments
 (0)