-
Notifications
You must be signed in to change notification settings - Fork 44
Expand file tree
/
Copy pathbench_fp8_quant.py
More file actions
65 lines (48 loc) · 2.52 KB
/
Copy pathbench_fp8_quant.py
File metadata and controls
65 lines (48 loc) · 2.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from typing import Optional
import pytest
import torch
from benchmarks.benchmark_base import BenchmarkBase, BenchmarkReport
from tileops.ops import FP8QuantOp
from workloads.fp8_quant import FP8QuantTest
class _FP8QuantTestBaseline(FP8QuantTest):
    """FP8QuantTest extended with a plain-PyTorch reference used as the benchmark baseline."""

    def ref_program(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Quantize ``input_tensor`` to float8_e4m3fn with one scale per last-dim row.

        Args:
            input_tensor: Tensor to quantize. The last dimension is reduced to
                form each scale; the expected layout is
                (batch, seq_len_kv, kv_group, index_dim).

        Returns:
            ``(scale, quantized)`` where ``scale`` has the input's shape with
            the last dimension removed (same dtype as the input) and
            ``quantized`` has the input's shape with dtype
            ``torch.float8_e4m3fn``.
        """
        fp8_max = 448.0  # largest finite value representable in float8_e4m3fn

        # Per-row absolute maximum. The 1e-4 floor keeps an all-zero row from
        # yielding a zero scale (and a division by zero below).
        row_amax = input_tensor.abs().amax(dim=-1, keepdim=True)
        row_amax = row_amax.clamp(min=1e-4)
        scale = row_amax / fp8_max

        # Clamp into the representable FP8 range before casting; presumably
        # guards against rounding in the divide overshooting +/-448.
        scaled = (input_tensor / scale).clamp(min=-fp8_max, max=fp8_max)
        quantized = scaled.to(torch.float8_e4m3fn)

        return scale.squeeze(-1), quantized
class FP8QuantBenchmark(BenchmarkBase[FP8QuantTest]):
    """FLOP and memory-traffic model for the FP8 per-row quantization benchmark."""

    def calculate_flops(self) -> Optional[float]:
        """Return the estimated operation count for one quantization pass.

        The model charges 2 ops per element, 1 op per row and 4 more ops per
        element. NOTE(review): presumably 2 covers abs + max-reduction, the
        per-row term the scale computation, and 4 the divide/clamp/cast path;
        confirm against the kernel.
        """
        t = self.workload
        num_rows = t.batch * t.seq_len_kv * t.kv_group
        num_elems = num_rows * t.index_dim
        return 2 * num_elems + num_rows + 4 * num_elems

    def calculate_memory(self) -> Optional[float]:
        """Return the estimated bytes moved by one quantization pass.

        Counts the read of the input tensor, the write of the FP8 output
        (one byte per element) and the write of one scale per row. Previously
        only the input read was counted, which understated traffic.
        """
        t = self.workload
        num_rows = t.batch * t.seq_len_kv * t.kv_group
        num_elems = num_rows * t.index_dim
        input_bytes = num_elems * t.in_dtype.itemsize
        output_bytes = num_elems * torch.float8_e4m3fn.itemsize
        # NOTE(review): the scale is counted in the input dtype, matching the
        # reference implementation; confirm the dtype FP8QuantOp writes.
        scale_bytes = num_rows * t.in_dtype.itemsize
        return input_bytes + output_bytes + scale_bytes
# Benchmark grid. Each entry is (test id, (batch, seq_len_kv, kv_group,
# index_dim, in_dtype, tune)). Every case quantizes the same number of
# elements (524,288), so timings are comparable across shapes.
_FP8_QUANT_BENCH_PARAMS = [
    pytest.param(*case_args, id=case_id)
    for case_id, case_args in (
        ("mainstream-fp16", (1, 8192, 1, 64, torch.float16, True)),
        ("mainstream-bf16", (1, 8192, 1, 64, torch.bfloat16, True)),
        ("wider-index", (1, 4096, 1, 128, torch.float32, True)),
        ("long-sequence", (1, 16384, 1, 32, torch.float32, True)),
    )
]
@pytest.mark.parametrize("batch, seq_len_kv, kv_group, index_dim, in_dtype, tune",
                         _FP8_QUANT_BENCH_PARAMS)
def test_fp8_quant_bench(batch: int, seq_len_kv: int, kv_group: int, index_dim: int,
                         in_dtype: torch.dtype, tune: bool) -> None:
    """Profile FP8QuantOp and the PyTorch reference on one parameter set.

    Builds the workload and its inputs, then profiles the TileOPs operator
    and ``ref_program`` on the same inputs. Each result is recorded with
    ``BenchmarkReport`` under its own tag ("tileops" / "torch-ref").

    Args:
        batch, seq_len_kv, kv_group, index_dim: Workload shape; the reference
            treats the input as (batch, seq_len_kv, kv_group, index_dim).
        in_dtype: Dtype of the tensor being quantized.
        tune: Forwarded to FP8QuantOp. Presumably enables kernel autotuning
            (TODO confirm).
    """
    test = _FP8QuantTestBaseline(batch, seq_len_kv, kv_group, index_dim, in_dtype)
    bm = FP8QuantBenchmark(test)
    inputs = test.gen_inputs()
    op = FP8QuantOp(batch=batch,
                    seq_len_kv=seq_len_kv,
                    kv_group=kv_group,
                    index_dim=index_dim,
                    in_dtype=in_dtype,
                    tune=tune)
    result = bm.profile(op, *inputs)
    # NOTE: locals() is handed to the report, so the names and set of local
    # variables here are part of what record() receives; rename with care.
    BenchmarkReport.record(op, locals(), result, tag="tileops")
    result_bl = bm.profile(test.ref_program, *inputs)
    # The baseline is also recorded against `op`; only the tag distinguishes it.
    BenchmarkReport.record(op, locals(), result_bl, tag="torch-ref")
if __name__ == "__main__":
    # pytest.main returns the session exit code instead of exiting. Propagate
    # it so a failing benchmark run gives a non-zero process status.
    raise SystemExit(pytest.main([__file__, "-vvs"]))