-
Notifications
You must be signed in to change notification settings - Fork 44
Expand file tree
/
Copy pathbench_vector_norm.py
More file actions
107 lines (79 loc) · 3.72 KB
/
Copy pathbench_vector_norm.py
File metadata and controls
107 lines (79 loc) · 3.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""Benchmarks for vector norm ops (l1_norm, l2_norm, inf_norm).
Measures latency, TFLOPS, and DRAM bandwidth against PyTorch baselines.
Workload shapes and roofline formulas are loaded from the ops manifest (tileops/manifest/).
"""
import pytest
import torch
from benchmarks.benchmark_base import BenchmarkReport, ManifestBenchmark, workloads_to_params
from tileops.ops.reduction.inf_norm import InfNormFwdOp
from tileops.ops.reduction.l1_norm import L1NormFwdOp
from tileops.ops.reduction.l2_norm import L2NormFwdOp
from workloads.vector_norm import InfNormTest, L1NormTest, L2NormTest
# ===================================================================
# Op name constants
# ===================================================================
_L1_NORM_OP = "L1NormFwdOp"
_L2_NORM_OP = "L2NormFwdOp"
_INF_NORM_OP = "InfNormFwdOp"
# ===================================================================
# Shared helpers
# ===================================================================

# Low-precision input dtypes for which the PyTorch baseline asks
# torch.linalg.vector_norm to accumulate in fp32 via its ``dtype=`` argument
# instead of upcasting the whole input tensor with ``x.float()`` first.
_LOWP_DTYPES = (torch.float16, torch.bfloat16)


def _profile_or_skip(bm, op, inputs):
    """Return ``bm.profile(op, *inputs)``, skipping the test for unsupported shapes.

    A ``ValueError`` whose message contains "No configurations to tune" is
    turned into ``pytest.skip``; any other ``ValueError`` is re-raised.
    """
    try:
        return bm.profile(op, *inputs)
    except ValueError as exc:
        if "No configurations to tune" in str(exc):
            pytest.skip(f"Kernel does not support this shape: {exc}")
        raise


def _torch_vector_norm_baseline(norm_ord):
    """Build the PyTorch reference for the ``norm_ord`` vector norm over ``dim=-1``.

    The returned callable reduces in float32 and casts the result back to the
    input tensor's dtype.
    """

    def baseline_fn(x):
        if x.dtype in _LOWP_DTYPES:
            # Let vector_norm promote fp16/bf16 -> fp32 as part of the reduction
            # (on CUDA it can read the low-precision input directly) instead of
            # materializing an fp32 copy of ``x``; that extra full read+write of
            # the input would otherwise be charged to the baseline's timing.
            out = torch.linalg.vector_norm(x, ord=norm_ord, dim=-1, dtype=torch.float32)
        else:
            # fp32: ``x.float()`` is a no-op. Other dtypes keep the explicit cast
            # because vector_norm's ``dtype=`` rejects narrowing (e.g. fp64 -> fp32).
            out = torch.linalg.vector_norm(x.float(), ord=norm_ord, dim=-1)
        return out.to(x.dtype)

    return baseline_fn


# NOTE(review): BenchmarkReport.record receives locals(); the test bodies below
# deliberately keep the same local names (shape, dtype, test, inputs, op, bm,
# result, baseline_fn, result_bl) in case the report keys on them — confirm.


# ===================================================================
# L1 Norm benchmarks
# ===================================================================
@pytest.mark.parametrize("shape, dtype", workloads_to_params(_L1_NORM_OP))
def test_l1_norm_bench(shape: tuple, dtype: torch.dtype) -> None:
    """Benchmark L1NormFwdOp (dim=-1) against torch.linalg.vector_norm(ord=1)."""
    test = L1NormTest(shape, dtype)
    inputs = test.gen_inputs()
    op = L1NormFwdOp(dtype=dtype, dim=-1)
    bm = ManifestBenchmark(_L1_NORM_OP, op, test)

    result = _profile_or_skip(bm, op, inputs)
    BenchmarkReport.record(op, locals(), result, tag="tileops")

    baseline_fn = _torch_vector_norm_baseline(1)
    result_bl = bm.profile(baseline_fn, *inputs)
    BenchmarkReport.record(op, locals(), result_bl, tag="torch")


# ===================================================================
# L2 Norm benchmarks
# ===================================================================
@pytest.mark.parametrize("shape, dtype", workloads_to_params(_L2_NORM_OP))
def test_l2_norm_bench(shape: tuple, dtype: torch.dtype) -> None:
    """Benchmark L2NormFwdOp (dim=-1) against torch.linalg.vector_norm(ord=2)."""
    test = L2NormTest(shape, dtype)
    inputs = test.gen_inputs()
    op = L2NormFwdOp(dtype=dtype, dim=-1)
    bm = ManifestBenchmark(_L2_NORM_OP, op, test)

    result = _profile_or_skip(bm, op, inputs)
    BenchmarkReport.record(op, locals(), result, tag="tileops")

    baseline_fn = _torch_vector_norm_baseline(2)
    result_bl = bm.profile(baseline_fn, *inputs)
    BenchmarkReport.record(op, locals(), result_bl, tag="torch")


# ===================================================================
# Inf Norm benchmarks
# ===================================================================
@pytest.mark.parametrize("shape, dtype", workloads_to_params(_INF_NORM_OP))
def test_inf_norm_bench(shape: tuple, dtype: torch.dtype) -> None:
    """Benchmark InfNormFwdOp (dim=-1) against torch.linalg.vector_norm(ord=inf)."""
    test = InfNormTest(shape, dtype)
    inputs = test.gen_inputs()
    op = InfNormFwdOp(dtype=dtype, dim=-1)
    bm = ManifestBenchmark(_INF_NORM_OP, op, test)

    result = _profile_or_skip(bm, op, inputs)
    BenchmarkReport.record(op, locals(), result, tag="tileops")

    baseline_fn = _torch_vector_norm_baseline(float("inf"))
    result_bl = bm.profile(baseline_fn, *inputs)
    BenchmarkReport.record(op, locals(), result_bl, tag="torch")
if __name__ == "__main__":
    # Forward pytest's exit status so a scripted/CI invocation of this file
    # fails when any benchmark test fails (pytest.main only returns the code).
    raise SystemExit(pytest.main([__file__, "-vvs"]))