192192 "silu" ,
193193 "gelu" ,
194194 "residual_add" ,
195+ # FP8 RowWise scaled_mm: out = scale_a[m] * scale_b[n] * (a_fp8 @ b_fp8).
196+ # The rowwise scale is fused into the epilogue. Intended for --dtype
197+ # float8_e4m3fn; the reference is torch._scaled_mm.
198+ "scaled_mm" ,
195199)
196200QUACK_TUNE_CHOICES = ("off" , "brief" )
197201# Brief tuning covers the documented default, larger cluster/swizzle variants,
@@ -725,9 +729,14 @@ def _dtype_from_name(name: str) -> torch.dtype:
725729 "float16" : torch .float16 ,
726730 "bfloat16" : torch .bfloat16 ,
727731 "float32" : torch .float32 ,
732+ "float8_e4m3fn" : torch .float8_e4m3fn ,
728733 }[name ]
729734
730735
736+ def _is_fp8 (dtype : torch .dtype ) -> bool :
737+ return dtype == torch .float8_e4m3fn
738+
739+
def _tflops(m: int, n: int, k: int, ms: float) -> float:
    """Convert the runtime of an ``(m, k) @ (k, n)`` GEMM into TFLOP/s.

    The GEMM is counted as ``2 * m * n * k`` floating-point operations.
    Assumes ``ms`` is the elapsed time in milliseconds, as the name and the
    ``1e9`` divisor imply: flops / (ms * 1e-3 s) / 1e12 == flops / (ms * 1e9).

    Raises:
        ZeroDivisionError: if ``ms`` is zero.
    """
    flop_count = 2.0 * m * n * k
    scaled_time = ms * 1e9
    return flop_count / scaled_time
733742
@@ -832,20 +841,41 @@ def _make_inputs(
832841 seed : int ,
833842) -> tuple [torch .Tensor , torch .Tensor ]:
834843 torch .manual_seed (seed )
844+ if _is_fp8 (dtype ):
845+ # fp8 has a tiny dynamic range, so build the operands in f32 and cast.
846+ # b is laid out column-major (K-contiguous), the layout the tcgen05
847+ # fp8 path and torch._scaled_mm expect for the second operand.
848+ a = (torch .randn ((m , k ), device = "cuda" ) * 0.4 ).to (dtype )
849+ b = (torch .randn ((k , n ), device = "cuda" ) * 0.4 ).to (dtype ).T .contiguous ().T
850+ return a , b
835851 a = torch .randn ((m , k ), device = "cuda" , dtype = dtype )
836852 b = torch .randn ((k , n ), device = "cuda" , dtype = dtype ) / math .sqrt (k )
837853 return a , b
838854
839855
def _make_scales(
    args: argparse.Namespace,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build the random rowwise scales consumed by the ``scaled_mm`` epilogue.

    Returns ``(scale_a, scale_b)`` where ``scale_a`` has shape ``(args.m, 1)``
    (one scale per output row) and ``scale_b`` has shape ``(1, args.n)`` (one
    scale per output column). Both are float32 tensors on the CUDA device with
    values of the form ``torch.rand(...) + 0.5``; they are deliberately
    non-trivial so that a broadcast bug actually surfaces in the correctness
    check.
    """

    def _random_scale(shape: tuple[int, int]) -> torch.Tensor:
        # Shift the random draw by 0.5 so no scale is close to zero.
        shifted = torch.rand(shape, device="cuda") + 0.5
        return shifted.to(torch.float32)

    # Draw order (row scales first, then column scales) matters for
    # reproducibility under a fixed seed.
    row_scale = _random_scale((args.m, 1))
    col_scale = _random_scale((1, args.n))
    return row_scale, col_scale
865+
866+
def _make_epilogue_inputs(
    args: argparse.Namespace, dtype: torch.dtype
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
    """Create the auxiliary tensors required by the selected epilogue.

    Returns ``(bias, residual)``. ``bias`` is a random ``(args.n,)`` CUDA
    tensor for the ``bias``, ``bias_relu`` and ``bias_residual_gelu``
    epilogues; ``residual`` is a random ``(args.m, args.n)`` CUDA tensor for
    the ``bias_residual_gelu`` and ``residual_add`` epilogues. Whichever one
    the epilogue does not use is ``None``.
    """
    # Only the matmul operands are fp8; for fp8 runs the epilogue aux tensors
    # are kept in the output dtype (bf16) instead.
    if _is_fp8(dtype):
        aux_dtype = torch.bfloat16
    else:
        aux_dtype = dtype

    wants_bias = args.epilogue in ("bias", "bias_relu", "bias_residual_gelu")
    wants_residual = args.epilogue in ("bias_residual_gelu", "residual_add")

    # Bias is drawn before residual so the RNG stream order stays stable.
    bias = (
        torch.randn((args.n,), device="cuda", dtype=aux_dtype)
        if wants_bias
        else None
    )
    residual = (
        torch.randn((args.m, args.n), device="cuda", dtype=aux_dtype)
        if wants_residual
        else None
    )
    return bias, residual
850880
851881
@@ -857,9 +887,20 @@ def _make_matmul_problem(
857887 dtype = _dtype_from_name (args .dtype )
858888 a , b = _make_inputs (args .m , args .n , args .k , dtype , seed = args .seed )
859889 bias , residual = _make_epilogue_inputs (args , dtype )
890+ # Stash the rowwise scales on args so the (a, b, bias, residual) tuple
891+ # threaded through every impl stays unchanged. Only the scaled_mm path reads
892+ # them, via _scaled_mm_scales().
893+ if args .epilogue == "scaled_mm" :
894+ args ._scale_a , args ._scale_b = _make_scales (args )
860895 return dtype , a , b , bias , residual
861896
862897
def _scaled_mm_scales(
    args: argparse.Namespace,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return the ``(scale_a, scale_b)`` rowwise scales stashed on ``args``.

    The scales are read from the private ``_scale_a`` / ``_scale_b``
    attributes, which ``_make_matmul_problem`` sets only when
    ``args.epilogue == "scaled_mm"``.

    Raises:
        AttributeError: if the scales have not been stashed on ``args`` yet.
            The exception type matches a plain attribute lookup failure, but
            the message says what has to happen first.
    """
    try:
        return args._scale_a, args._scale_b
    except AttributeError as err:
        raise AttributeError(
            "scaled_mm scales are not set on args; call _make_matmul_problem() "
            "with --epilogue scaled_mm before requesting them"
        ) from err
902+
903+
863904def _apply_epilogue (
864905 args : argparse .Namespace ,
865906 acc : torch .Tensor ,
@@ -889,6 +930,11 @@ def _apply_epilogue(
889930 if args .epilogue == "residual_add" :
890931 assert residual is not None
891932 return acc + residual
933+ if args .epilogue == "scaled_mm" :
934+ # out = scale_a[m] * scale_b[n] * acc, cast to bf16. The scale is folded
935+ # on the f32 accumulator before the cast (matches the fused epilogue).
936+ scale_a , scale_b = _scaled_mm_scales (args )
937+ return (acc .float () * scale_a * scale_b ).to (torch .bfloat16 )
892938 raise AssertionError (f"unhandled epilogue { args .epilogue !r} " )
893939
894940
@@ -900,14 +946,25 @@ def _matmul_expected(
900946 residual : torch .Tensor | None ,
901947 dtype : torch .dtype ,
902948) -> torch .Tensor :
903- return _apply_epilogue (args , a @ b , bias , residual , dtype )
949+ # acc is f32; for fp8 inputs the product is computed in f32 to mirror the
950+ # tensor-core accumulate before any epilogue (scaled_mm, activation, ...).
951+ acc = (a .float () @ b .float ()) if _is_fp8 (dtype ) else (a @ b )
952+ return _apply_epilogue (args , acc , bias , residual , dtype )
904953
905954
906955def _check_close (
907956 actual : torch .Tensor , expected : torch .Tensor , dtype : torch .dtype
908957) -> None :
909958 if dtype == torch .float32 :
910959 torch .testing .assert_close (actual , expected , atol = 1e-4 , rtol = 1e-4 )
960+ elif _is_fp8 (dtype ):
961+ # fp8 (e4m3) operands carry ~2 decimal digits, so the GEMM accumulates
962+ # substantial quantization error; use a relative-error tolerance like
963+ # the scaled_mm unit checks.
964+ ref_max = expected .float ().abs ().max ().item () + 1e-12
965+ rel = (actual .float () - expected .float ()).abs ().max ().item () / ref_max
966+ if rel > 0.1 :
967+ raise AssertionError (f"fp8 mismatch: rel_err={ rel :.4f} > 0.1" )
911968 else :
912969 # bf16/fp16 GEMMs accumulate enough rounding noise that benchmark
913970 # smoke tests need a looser threshold than unit tests.
@@ -1020,7 +1077,17 @@ def _result(
10201077
10211078def _benchmark_aten (args : argparse .Namespace ) -> dict [str , Any ]:
10221079 dtype , a , b , bias , residual = _make_matmul_problem (args )
1023- fn = lambda : _apply_epilogue (args , a @ b , bias , residual , dtype ) # noqa: E731
1080+ if args .epilogue == "scaled_mm" :
1081+ # ATen's fp8 rowwise GEMM is torch._scaled_mm — the SOTA baseline to
1082+ # time against (a dequantized f32 matmul would be a misleadingly slow
1083+ # reference). _apply_epilogue still owns the scaled_mm *semantics* for
1084+ # the correctness reference in _matmul_expected.
1085+ scale_a , scale_b = _scaled_mm_scales (args )
1086+ fn = lambda : torch ._scaled_mm ( # noqa: E731
1087+ a , b , scale_a , scale_b , use_fast_accum = False , out_dtype = torch .bfloat16
1088+ )
1089+ else :
1090+ fn = lambda : _apply_epilogue (args , a @ b , bias , residual , dtype ) # noqa: E731
10241091 stats = _bench_steady (
10251092 fn ,
10261093 num_runs = args .num_runs ,
@@ -1533,18 +1600,32 @@ def _helion_matmul_args(
15331600 if args .epilogue == "residual_add" :
15341601 assert residual is not None
15351602 return (a , b , ResidualAddEpilogue (residual ))
1603+ if args .epilogue == "scaled_mm" :
1604+ scale_a , scale_b = _scaled_mm_scales (args )
1605+ # examples/fp8_matmul.fp8_matmul takes (x, y, sa2d, sb1d) directly and
1606+ # bakes the scale in itself (not via an epilogue callable): scale_a as a
1607+ # (M, N) stride-(1,0) colvec view, scale_b as a rank-1 row vector.
1608+ scale_a2d = scale_a .reshape (args .m , 1 ).expand (args .m , args .n )
1609+ scale_b1d = scale_b .reshape (args .n )
1610+ return (a , b , scale_a2d , scale_b1d )
15361611 raise AssertionError (f"unhandled epilogue { args .epilogue !r} " )
15371612
15381613
15391614def _prepare_helion (args : argparse .Namespace ) -> _PreparedHelion :
15401615 backend = args .helion_backend
15411616 os .environ ["HELION_BACKEND" ] = backend
1542- from examples .matmul import matmul
15431617
15441618 dtype , a , b , bias , residual = _make_matmul_problem (args )
15451619 expected = _matmul_expected (args , a , b , bias , residual , dtype )
15461620 kernel_args = _helion_matmul_args (args , a , b , bias , residual )
15471621
1622+ if args .epilogue == "scaled_mm" :
1623+ # fp8 RowWise scaled_mm uses examples/fp8_matmul.py (hl.dot + fused
1624+ # rowwise scale); the example matmul's torch.addmm does not accept fp8.
1625+ from examples .fp8_matmul import fp8_matmul as matmul
1626+ else :
1627+ from examples .matmul import matmul
1628+
15481629 bound = matmul .bind (kernel_args )
15491630 config = _make_helion_config_from_args (args ) if args .helion_force_config else None
15501631 if config is not None and any (key in config .config for key in _TCGEN05_CONFIG_KEYS ):
@@ -5931,7 +6012,7 @@ def parse_args() -> argparse.Namespace:
59316012 )
59326013 parser .add_argument (
59336014 "--dtype" ,
5934- choices = ("float16" , "bfloat16" , "float32" ),
6015+ choices = ("float16" , "bfloat16" , "float32" , "float8_e4m3fn" ),
59356016 default = "bfloat16" ,
59366017 )
59376018 parser .add_argument ("--num-runs" , type = int , default = 5 )
@@ -6387,6 +6468,27 @@ def _uses_invalid_output_diagnostic_mode(args: argparse.Namespace) -> bool:
63876468
63886469
63896470def _validate_args (args : argparse .Namespace ) -> None :
6471+ # fp8 + scaled_mm wiring. The scaled_mm epilogue is the fp8 RowWise path; it
6472+ # is only meaningful for fp8 inputs, and the only impls that implement it are
6473+ # ATen (torch._scaled_mm) and Helion. quack-direct/quack do not.
6474+ if args .epilogue == "scaled_mm" and args .dtype != "float8_e4m3fn" :
6475+ raise SystemExit ("--epilogue scaled_mm requires --dtype float8_e4m3fn" )
6476+ if args .dtype == "float8_e4m3fn" :
6477+ if args .epilogue != "scaled_mm" :
6478+ raise SystemExit ("--dtype float8_e4m3fn requires --epilogue scaled_mm" )
6479+ impls = args .impls or list (DEFAULT_IMPLS )
6480+ requested = [args .impl ] if args .impl != "all" else impls
6481+ bad = [i for i in requested if i in ("quack" , "quack-direct" )]
6482+ if bad :
6483+ raise SystemExit (
6484+ f"impl(s) { bad } do not support fp8 scaled_mm; use --impls with "
6485+ "aten and/or helion-cute (quack-direct has no fp8 rowwise GEMM here)"
6486+ )
6487+ if "helion-triton" in requested :
6488+ raise SystemExit (
6489+ "helion-triton does not support the fp8 tcgen05 path; "
6490+ "use --impls aten helion-cute"
6491+ )
63906492 special_modes = (
63916493 args .helion_two_cta_diagnostic_sweep ,
63926494 args .helion_two_cta_codegen_report ,
0 commit comments