Skip to content

Commit 97c06cf

Browse files
authored
Merge branch 'main' into calebmkim/stack/1
2 parents f212074 + 82bee2f commit 97c06cf

71 files changed

Lines changed: 8775 additions & 811 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
e2b56015f5107caf4fecbe58273ea5d5ad53de27
1+
013936a6640107c22632debc47379a14e8e2501b

.github/matrix.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
"runner": "mt-l-x86iamx-22-225-h100",
5151
"python-version": "3.12",
5252
"ref-eager": false,
53-
"image": "pytorch/pytorch:2.11.0-cuda13.0-cudnn9-devel",
53+
"image": "nvidia/cuda:13.1.0-devel-ubuntu24.04",
5454
"runtime-version": "cu130",
5555
"container-options": "--gpus all",
5656
"pytorch-version": "pytorch-nightly",
@@ -61,7 +61,7 @@
6161
"runner": "mt-l-x86iamx-88-900-h100-4",
6262
"python-version": "3.12",
6363
"ref-eager": false,
64-
"image": "pytorch/pytorch:2.11.0-cuda13.0-cudnn9-devel",
64+
"image": "nvidia/cuda:13.1.0-devel-ubuntu24.04",
6565
"runtime-version": "cu130",
6666
"container-options": "--gpus all",
6767
"pytorch-version": "pytorch-nightly",

.github/workflows/benchmark.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ jobs:
6969

7070
steps:
7171
- name: Run NVIDIA command
72-
if: startsWith(inputs.image, 'nvidia') || (startsWith(inputs.image, 'pytorch') && contains(inputs.image, 'cuda'))
72+
if: startsWith(inputs.image, 'nvidia')
7373
run: |
7474
echo "Detected NVIDIA image"
7575
nvidia-smi || echo "nvidia-smi not found"
@@ -122,7 +122,7 @@ jobs:
122122
./scripts/install_cute.sh
123123
124124
- name: CUDA Compute Check
125-
if: startsWith(inputs.image, 'nvidia') || (startsWith(inputs.image, 'pytorch') && contains(inputs.image, 'cuda'))
125+
if: startsWith(inputs.image, 'nvidia')
126126
run: |
127127
source .venv/bin/activate
128128
python -c "

.github/workflows/benchmark_dispatch.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ jobs:
8181
with:
8282
runner: mt-l-x86iamx-22-225-h100
8383
python-version: "3.12"
84-
image: pytorch/pytorch:2.11.0-cuda13.0-cudnn9-devel
84+
image: nvidia/cuda:13.1.0-devel-ubuntu24.04
8585
runtime-version: cu130
8686
container-options: --gpus all
8787
alias: h100

.github/workflows/benchmark_tpu.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ jobs:
110110
cd -
111111
rm -rf /tmp/torch_tpu
112112
# Verify
113-
python -c "from torch_tpu import api; print(f'TPU device: {api.tpu_device()}')"
113+
python -c "import torch, sys; print('Success') if torch.tpu.is_available() else (print('(Torch)TPU not available'), sys.exit(1))"
114114
115115
- name: Run TPU Benchmark
116116
run: |

.github/workflows/test.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ jobs:
4949

5050
steps:
5151
- name: Run NVIDIA command
52-
if: startsWith(matrix.image, 'nvidia') || (startsWith(matrix.image, 'pytorch') && contains(matrix.image, 'cuda'))
52+
if: startsWith(matrix.image, 'nvidia')
5353
run: |
5454
echo "Detected NVIDIA image"
5555
nvidia-smi || echo "nvidia-smi not found"
@@ -235,7 +235,7 @@ jobs:
235235
cd -
236236
rm -rf /tmp/torch_tpu
237237
# Verify
238-
python -c "from torch_tpu import api; print(f'TPU device: {api.tpu_device()}')"
238+
python -c "import torch, sys; print('Success') if torch.tpu.is_available() else (print('(Torch)TPU not available'), sys.exit(1))"
239239
240240
- name: Install Pallas interpret dependencies
241241
if: matrix.alias == 'pallas-interpret'
@@ -250,7 +250,7 @@ jobs:
250250
./scripts/install_cute.sh
251251
252252
- name: CUDA Compute Check
253-
if: startsWith(matrix.image, 'nvidia') || (startsWith(matrix.image, 'pytorch') && contains(matrix.image, 'cuda'))
253+
if: startsWith(matrix.image, 'nvidia')
254254
run: |
255255
source .venv/bin/activate
256256
python -c "
@@ -271,7 +271,7 @@ jobs:
271271
"
272272
273273
- name: Inductor Worker Check
274-
if: startsWith(matrix.image, 'nvidia') || (startsWith(matrix.image, 'pytorch') && contains(matrix.image, 'cuda'))
274+
if: startsWith(matrix.image, 'nvidia')
275275
run: |
276276
source .venv/bin/activate
277277
python -c "

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ site
8888
tags
8989
TAGS
9090
torch
91-
triton
91+
/triton
9292
*.user
9393
uv.lock
9494
venv

benchmarks/cute/compare_matmul_backends.py

Lines changed: 108 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,10 @@
192192
"silu",
193193
"gelu",
194194
"residual_add",
195+
# FP8 RowWise scaled_mm: out = scale_a[m] * scale_b[n] * (a_fp8 @ b_fp8).
196+
# The rowwise scale is fused into the epilogue. Intended for --dtype
197+
# float8_e4m3fn; the ATen timing baseline is torch._scaled_mm.
198+
"scaled_mm",
195199
)
196200
QUACK_TUNE_CHOICES = ("off", "brief")
197201
# Brief tuning covers the documented default, larger cluster/swizzle variants,
@@ -725,9 +729,14 @@ def _dtype_from_name(name: str) -> torch.dtype:
725729
"float16": torch.float16,
726730
"bfloat16": torch.bfloat16,
727731
"float32": torch.float32,
732+
"float8_e4m3fn": torch.float8_e4m3fn,
728733
}[name]
729734

730735

736+
def _is_fp8(dtype: torch.dtype) -> bool:
737+
return dtype == torch.float8_e4m3fn
738+
739+
731740
def _tflops(m: int, n: int, k: int, ms: float) -> float:
732741
return (2.0 * m * n * k) / (ms * 1e9)
733742

@@ -832,20 +841,41 @@ def _make_inputs(
832841
seed: int,
833842
) -> tuple[torch.Tensor, torch.Tensor]:
834843
torch.manual_seed(seed)
844+
if _is_fp8(dtype):
845+
# fp8 has a tiny dynamic range, so build the operands in f32 and cast.
846+
# b is laid out column-major (K-contiguous), the layout the tcgen05
847+
# fp8 path and torch._scaled_mm expect for the second operand.
848+
a = (torch.randn((m, k), device="cuda") * 0.4).to(dtype)
849+
b = (torch.randn((k, n), device="cuda") * 0.4).to(dtype).T.contiguous().T
850+
return a, b
835851
a = torch.randn((m, k), device="cuda", dtype=dtype)
836852
b = torch.randn((k, n), device="cuda", dtype=dtype) / math.sqrt(k)
837853
return a, b
838854

839855

856+
def _make_scales(
857+
args: argparse.Namespace,
858+
) -> tuple[torch.Tensor, torch.Tensor]:
859+
"""Per-row (scale_a [m,1]) and per-column (scale_b [1,n]) f32 rowwise scales
860+
for the ``scaled_mm`` epilogue. Non-trivial (random) values so a broadcast
861+
bug actually surfaces in the correctness check."""
862+
scale_a = (torch.rand((args.m, 1), device="cuda") + 0.5).to(torch.float32)
863+
scale_b = (torch.rand((1, args.n), device="cuda") + 0.5).to(torch.float32)
864+
return scale_a, scale_b
865+
866+
840867
def _make_epilogue_inputs(
841868
args: argparse.Namespace, dtype: torch.dtype
842869
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
843870
bias = None
844871
residual = None
872+
# fp8 epilogue aux tensors (bias/residual) are kept in the *output* dtype
873+
# (bf16); only the matmul operands are fp8.
874+
aux_dtype = torch.bfloat16 if _is_fp8(dtype) else dtype
845875
if args.epilogue in ("bias", "bias_relu", "bias_residual_gelu"):
846-
bias = torch.randn((args.n,), device="cuda", dtype=dtype)
876+
bias = torch.randn((args.n,), device="cuda", dtype=aux_dtype)
847877
if args.epilogue in ("bias_residual_gelu", "residual_add"):
848-
residual = torch.randn((args.m, args.n), device="cuda", dtype=dtype)
878+
residual = torch.randn((args.m, args.n), device="cuda", dtype=aux_dtype)
849879
return bias, residual
850880

851881

@@ -857,9 +887,20 @@ def _make_matmul_problem(
857887
dtype = _dtype_from_name(args.dtype)
858888
a, b = _make_inputs(args.m, args.n, args.k, dtype, seed=args.seed)
859889
bias, residual = _make_epilogue_inputs(args, dtype)
890+
# Stash the rowwise scales on args so the (a, b, bias, residual) tuple
891+
# threaded through every impl stays unchanged. Only the scaled_mm path reads
892+
# them, via _scaled_mm_scales().
893+
if args.epilogue == "scaled_mm":
894+
args._scale_a, args._scale_b = _make_scales(args)
860895
return dtype, a, b, bias, residual
861896

862897

898+
def _scaled_mm_scales(
899+
args: argparse.Namespace,
900+
) -> tuple[torch.Tensor, torch.Tensor]:
901+
return args._scale_a, args._scale_b
902+
903+
863904
def _apply_epilogue(
864905
args: argparse.Namespace,
865906
acc: torch.Tensor,
@@ -889,6 +930,11 @@ def _apply_epilogue(
889930
if args.epilogue == "residual_add":
890931
assert residual is not None
891932
return acc + residual
933+
if args.epilogue == "scaled_mm":
934+
# out = scale_a[m] * scale_b[n] * acc, cast to bf16. The scale is folded
935+
# on the f32 accumulator before the cast (matches the fused epilogue).
936+
scale_a, scale_b = _scaled_mm_scales(args)
937+
return (acc.float() * scale_a * scale_b).to(torch.bfloat16)
892938
raise AssertionError(f"unhandled epilogue {args.epilogue!r}")
893939

894940

@@ -900,14 +946,25 @@ def _matmul_expected(
900946
residual: torch.Tensor | None,
901947
dtype: torch.dtype,
902948
) -> torch.Tensor:
903-
return _apply_epilogue(args, a @ b, bias, residual, dtype)
949+
# acc is f32; for fp8 inputs the product is computed in f32 to mirror the
950+
# tensor-core accumulate before any epilogue (scaled_mm, activation, ...).
951+
acc = (a.float() @ b.float()) if _is_fp8(dtype) else (a @ b)
952+
return _apply_epilogue(args, acc, bias, residual, dtype)
904953

905954

906955
def _check_close(
907956
actual: torch.Tensor, expected: torch.Tensor, dtype: torch.dtype
908957
) -> None:
909958
if dtype == torch.float32:
910959
torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)
960+
elif _is_fp8(dtype):
961+
# fp8 (e4m3) operands carry ~2 decimal digits, so the GEMM accumulates
962+
# substantial quantization error; use a relative-error tolerance like
963+
# the scaled_mm unit checks.
964+
ref_max = expected.float().abs().max().item() + 1e-12
965+
rel = (actual.float() - expected.float()).abs().max().item() / ref_max
966+
if rel > 0.1:
967+
raise AssertionError(f"fp8 mismatch: rel_err={rel:.4f} > 0.1")
911968
else:
912969
# bf16/fp16 GEMMs accumulate enough rounding noise that benchmark
913970
# smoke tests need a looser threshold than unit tests.
@@ -1020,7 +1077,17 @@ def _result(
10201077

10211078
def _benchmark_aten(args: argparse.Namespace) -> dict[str, Any]:
10221079
dtype, a, b, bias, residual = _make_matmul_problem(args)
1023-
fn = lambda: _apply_epilogue(args, a @ b, bias, residual, dtype) # noqa: E731
1080+
if args.epilogue == "scaled_mm":
1081+
# ATen's fp8 rowwise GEMM is torch._scaled_mm — the SOTA baseline to
1082+
# time against (a dequantized f32 matmul would be a misleadingly slow
1083+
# reference). _apply_epilogue still owns the scaled_mm *semantics* for
1084+
# the correctness reference in _matmul_expected.
1085+
scale_a, scale_b = _scaled_mm_scales(args)
1086+
fn = lambda: torch._scaled_mm( # noqa: E731
1087+
a, b, scale_a, scale_b, use_fast_accum=False, out_dtype=torch.bfloat16
1088+
)
1089+
else:
1090+
fn = lambda: _apply_epilogue(args, a @ b, bias, residual, dtype) # noqa: E731
10241091
stats = _bench_steady(
10251092
fn,
10261093
num_runs=args.num_runs,
@@ -1533,18 +1600,32 @@ def _helion_matmul_args(
15331600
if args.epilogue == "residual_add":
15341601
assert residual is not None
15351602
return (a, b, ResidualAddEpilogue(residual))
1603+
if args.epilogue == "scaled_mm":
1604+
scale_a, scale_b = _scaled_mm_scales(args)
1605+
# examples/fp8_matmul.fp8_matmul takes (x, y, sa2d, sb1d) directly and
1606+
# bakes the scale in itself (not via an epilogue callable): scale_a as a
1607+
# (M, N) stride-(1,0) colvec view, scale_b as a rank-1 row vector.
1608+
scale_a2d = scale_a.reshape(args.m, 1).expand(args.m, args.n)
1609+
scale_b1d = scale_b.reshape(args.n)
1610+
return (a, b, scale_a2d, scale_b1d)
15361611
raise AssertionError(f"unhandled epilogue {args.epilogue!r}")
15371612

15381613

15391614
def _prepare_helion(args: argparse.Namespace) -> _PreparedHelion:
15401615
backend = args.helion_backend
15411616
os.environ["HELION_BACKEND"] = backend
1542-
from examples.matmul import matmul
15431617

15441618
dtype, a, b, bias, residual = _make_matmul_problem(args)
15451619
expected = _matmul_expected(args, a, b, bias, residual, dtype)
15461620
kernel_args = _helion_matmul_args(args, a, b, bias, residual)
15471621

1622+
if args.epilogue == "scaled_mm":
1623+
# fp8 RowWise scaled_mm uses examples/fp8_matmul.py (hl.dot + fused
1624+
# rowwise scale); the example matmul's torch.addmm does not accept fp8.
1625+
from examples.fp8_matmul import fp8_matmul as matmul
1626+
else:
1627+
from examples.matmul import matmul
1628+
15481629
bound = matmul.bind(kernel_args)
15491630
config = _make_helion_config_from_args(args) if args.helion_force_config else None
15501631
if config is not None and any(key in config.config for key in _TCGEN05_CONFIG_KEYS):
@@ -5931,7 +6012,7 @@ def parse_args() -> argparse.Namespace:
59316012
)
59326013
parser.add_argument(
59336014
"--dtype",
5934-
choices=("float16", "bfloat16", "float32"),
6015+
choices=("float16", "bfloat16", "float32", "float8_e4m3fn"),
59356016
default="bfloat16",
59366017
)
59376018
parser.add_argument("--num-runs", type=int, default=5)
@@ -6387,6 +6468,27 @@ def _uses_invalid_output_diagnostic_mode(args: argparse.Namespace) -> bool:
63876468

63886469

63896470
def _validate_args(args: argparse.Namespace) -> None:
6471+
# fp8 + scaled_mm wiring. The scaled_mm epilogue is the fp8 RowWise path; it
6472+
# is only meaningful for fp8 inputs, and the only impls that implement it are
6473+
# ATen (torch._scaled_mm) and helion-cute. quack, quack-direct and helion-triton do not.
6474+
if args.epilogue == "scaled_mm" and args.dtype != "float8_e4m3fn":
6475+
raise SystemExit("--epilogue scaled_mm requires --dtype float8_e4m3fn")
6476+
if args.dtype == "float8_e4m3fn":
6477+
if args.epilogue != "scaled_mm":
6478+
raise SystemExit("--dtype float8_e4m3fn requires --epilogue scaled_mm")
6479+
impls = args.impls or list(DEFAULT_IMPLS)
6480+
requested = [args.impl] if args.impl != "all" else impls
6481+
bad = [i for i in requested if i in ("quack", "quack-direct")]
6482+
if bad:
6483+
raise SystemExit(
6484+
f"impl(s) {bad} do not support fp8 scaled_mm; use --impls with "
6485+
"aten and/or helion-cute (quack-direct has no fp8 rowwise GEMM here)"
6486+
)
6487+
if "helion-triton" in requested:
6488+
raise SystemExit(
6489+
"helion-triton does not support the fp8 tcgen05 path; "
6490+
"use --impls aten helion-cute"
6491+
)
63906492
special_modes = (
63916493
args.helion_two_cta_diagnostic_sweep,
63926494
args.helion_two_cta_codegen_report,

docs/api/settings.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,8 @@ def my_kernel(x: torch.Tensor) -> torch.Tensor:
143143
144144
.. autoattribute:: Settings.autotune_log
145145
146-
When set, Helion writes per-config autotuning telemetry (config index, generation, status, perf, compile time, timestamp, config JSON) to ``<value>.csv`` and mirrors the autotune log output to ``<value>.log`` for population-based autotuners (currently ``PatternSearch`` and ``DifferentialEvolution``).
146+
When set, Helion writes per-config autotuning telemetry (kernel id, sample id, config index, generation, status, perf, compile time, timestamp, config JSON) to ``<value>.csv`` and mirrors the autotune log output to ``<value>.log`` for population-based autotuners (currently ``PatternSearch`` and ``DifferentialEvolution``).
147+
The kernel identity (id, name, source, input shapes, dtypes, hardware) is written once per run to ``<value>.meta.json``. ``kernel_id`` is a stable content hash (of the kernel source and code-generation settings) that appears on every CSV row, acting as the foreign key to join rows back to the sidecar and group them by kernel across runs; ``sample_id`` additionally identifies each ``(kernel, config)`` pair so repeated benchmarks of the same config can be deduplicated.
147148
Controlled by ``HELION_AUTOTUNE_LOG``.
148149
149150
.. autoattribute:: Settings.autotune_compile_timeout

0 commit comments

Comments
 (0)