-
Notifications
You must be signed in to change notification settings - Fork 44
Expand file tree
/
Copy pathbench_argreduce.py
More file actions
106 lines (82 loc) · 3.82 KB
/
Copy pathbench_argreduce.py
File metadata and controls
106 lines (82 loc) · 3.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""Benchmarks for argreduce ops (argmax, argmin).
Measures latency, TFLOPS, and DRAM bandwidth against PyTorch baselines.
Workload shapes, dtypes, and op-call parameters (e.g. ``dim``) are loaded
from the ops manifest (``tileops/manifest/``) — the benchmark must not
hard-code op parameters that are declared on manifest workload entries.
"""
import pytest
import torch
from benchmarks.benchmark_base import BenchmarkReport, ManifestBenchmark, workloads_to_params
from tileops.ops.reduction.argmax import ArgmaxFwdOp
from tileops.ops.reduction.argmin import ArgminFwdOp
from workloads.argreduce import ArgmaxTest, ArgminTest
# Op names as used by this module: each is passed to ``workloads_to_params``
# to build the parametrization and to ``ManifestBenchmark`` alongside the op.
_ARGMAX_OP = "ArgmaxFwdOp"
_ARGMIN_OP = "ArgminFwdOp"
def _is_unsupported_large_argreduce_error(exc: Exception) -> bool:
"""Return True for known staged-rollout large-N argreduce failures."""
msg = str(exc)
return (
"scalable vector" in msg
or "No configurations to tune" in msg
or (
"A single row requires" in msg
and "shared memory" in msg
and "exceeds" in msg
)
)
# ===================================================================
# Argmax benchmarks
# ===================================================================
@pytest.mark.parametrize("shape, dtype, extra", workloads_to_params(_ARGMAX_OP, include_extra=True))
def test_argmax_bench(shape: tuple, dtype: torch.dtype, extra: dict) -> None:
    """Profile ``ArgmaxFwdOp`` and an ``x.argmax`` baseline for one workload.

    Parametrized from ``workloads_to_params(_ARGMAX_OP, include_extra=True)``.
    Both profiles are passed to ``BenchmarkReport.record``, tagged
    ``"tileops"`` and ``"torch"`` respectively.

    Args:
        shape: Workload shape passed to ``ArgmaxTest``.
        dtype: Dtype passed to both ``ArgmaxTest`` and ``ArgmaxFwdOp``.
        extra: Keyword arguments forwarded to ``ArgmaxFwdOp``. Must contain
            ``"dim"``, which the baseline call reuses.
    """
    workload = ArgmaxTest(shape, dtype)
    inputs = workload.gen_inputs()
    op = ArgmaxFwdOp(dtype=dtype, **extra)
    bm = ManifestBenchmark(_ARGMAX_OP, op, workload)
    # FIXME(staged-rollout): ArgreduceKernel skips large-N manifest workloads
    #
    # Broken invariant: benchmark must execute all manifest workload shapes
    # Why: the current single-tile shared-memory kernel cannot fit lm-head
    # N=102400 rows (204800 bytes per fp16/bf16 row exceeds 49152 bytes).
    # Cleanup: remove try/skip once ArgreduceKernel has a tiled-N path.
    try:
        result = bm.profile(op, *inputs)
    except Exception as exc:
        # Only the known large-N failures are turned into a skip; pytest.skip
        # raises, so a skipped workload records neither result below.
        if _is_unsupported_large_argreduce_error(exc):
            pytest.skip(f"Kernel does not support this shape: {exc}")
        # Anything else is a genuine failure and must propagate.
        raise
    # NOTE(review): locals() is handed to the report, so the local variable
    # names in this function are presumably part of what gets recorded --
    # confirm against BenchmarkReport.record before renaming or reordering.
    BenchmarkReport.record(op, locals(), result, tag="tileops")
    # Reuse the same ``dim`` the op was constructed with for the baseline.
    dim = extra["dim"]
    def baseline_fn(x):
        return x.argmax(dim=dim)
    result_bl = bm.profile(baseline_fn, *inputs)
    BenchmarkReport.record(op, locals(), result_bl, tag="torch")
# ===================================================================
# Argmin benchmarks
# ===================================================================
@pytest.mark.parametrize("shape, dtype, extra", workloads_to_params(_ARGMIN_OP, include_extra=True))
def test_argmin_bench(shape: tuple, dtype: torch.dtype, extra: dict) -> None:
    """Profile ``ArgminFwdOp`` and an ``x.argmin`` baseline for one workload.

    Parametrized from ``workloads_to_params(_ARGMIN_OP, include_extra=True)``.
    Both profiles are passed to ``BenchmarkReport.record``, tagged
    ``"tileops"`` and ``"torch"`` respectively.

    Args:
        shape: Workload shape passed to ``ArgminTest``.
        dtype: Dtype passed to both ``ArgminTest`` and ``ArgminFwdOp``.
        extra: Keyword arguments forwarded to ``ArgminFwdOp``. Must contain
            ``"dim"``, which the baseline call reuses.
    """
    workload = ArgminTest(shape, dtype)
    inputs = workload.gen_inputs()
    op = ArgminFwdOp(dtype=dtype, **extra)
    bm = ManifestBenchmark(_ARGMIN_OP, op, workload)
    # FIXME(staged-rollout): ArgreduceKernel skips large-N manifest workloads
    #
    # Broken invariant: benchmark must execute all manifest workload shapes
    # Why: the current single-tile shared-memory kernel cannot fit lm-head
    # N=102400 rows (204800 bytes per fp16/bf16 row exceeds 49152 bytes).
    # Cleanup: remove try/skip once ArgreduceKernel has a tiled-N path.
    try:
        result = bm.profile(op, *inputs)
    except Exception as exc:
        # Only the known large-N failures are turned into a skip; pytest.skip
        # raises, so a skipped workload records neither result below.
        if _is_unsupported_large_argreduce_error(exc):
            pytest.skip(f"Kernel does not support this shape: {exc}")
        # Anything else is a genuine failure and must propagate.
        raise
    # NOTE(review): locals() is handed to the report, so the local variable
    # names in this function are presumably part of what gets recorded --
    # confirm against BenchmarkReport.record before renaming or reordering.
    BenchmarkReport.record(op, locals(), result, tag="tileops")
    # Reuse the same ``dim`` the op was constructed with for the baseline.
    dim = extra["dim"]
    def baseline_fn(x):
        return x.argmin(dim=dim)
    result_bl = bm.profile(baseline_fn, *inputs)
    BenchmarkReport.record(op, locals(), result_bl, tag="torch")
if __name__ == "__main__":
    # Allow running this file directly: hand it to pytest with the ``-vvs``
    # flags instead of requiring a separate pytest invocation.
    pytest.main(args=[__file__, "-vvs"])