[cute] Fused-scale epilogue: scalar colvec read for per-row scale#2742
Open
yushangdi wants to merge 1 commit into
Open
[cute] Fused-scale epilogue: scalar colvec read for per-row scale#2742yushangdi wants to merge 1 commit into
yushangdi wants to merge 1 commit into
Conversation
yushangdi
added a commit
that referenced
this pull request
Jun 10, 2026
The fp8 scaled_mm epilogue loaded both rowwise scales per subtile AFTER the accumulator consumer_wait, exposing their GMEM latency. The standalone CUTLASS kernel (cute_scaled_mm) instead reads the whole rowvec scale into registers ONCE before the acc wait so its latency overlaps the MMA, and reads the per-row colvec scale as a single scalar. Backports the remaining fused-scale pieces from #2696: - memory_ops.py: register-hoist a rowvec epilogue aux (rowwise scale_b[n]) into a register tensor in per-tile setup (before the subtile loop / acc consumer_wait), fp8-gated; read per-subtile from registers instead of a fresh GMEM load. - cute_fx_walk.py: classify a stride-(1,0) (M,N) aux as a per-row column vector ("broadcast", 2) so scale_a is read as a scalar per subtile (tTR_gAux[(0,0,0,s)]) instead of a redundant N-wide vector load. Generated epilogue now matches the standalone: autovec_copy hoist before the acc consumer_wait, scalar colvec read. Benchmark (B200, CUDA 13.2, fp8 e4m3 scaled_mm, m=k=n=4096, col-major B, ab_stages=8, do_bench, 10s warmup): before (deep staging only) : 1616 TFLOP/s 86% of aten after (this change) : 1650 TFLOP/s 88% of aten torch._scaled_mm : 1878 TFLOP/s Full cute suite: 93 passed; rel_err 0.0000. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> stack-info: PR: #2742, branch: yushangdi/stack/27
b9dd708 to
12dda20
Compare
2e66347 to
ea8db74
Compare
This was referenced Jun 10, 2026
This was referenced Jun 11, 2026
12dda20 to
1573e3d
Compare
yushangdi
added a commit
that referenced
this pull request
Jun 11, 2026
The fp8 scaled_mm epilogue loaded both rowwise scales per subtile AFTER the accumulator consumer_wait, exposing their GMEM latency. The standalone CUTLASS kernel (cute_scaled_mm) instead reads the whole rowvec scale into registers ONCE before the acc wait so its latency overlaps the MMA, and reads the per-row colvec scale as a single scalar. Backports the remaining fused-scale pieces from #2696: - memory_ops.py: register-hoist a rowvec epilogue aux (rowwise scale_b[n]) into a register tensor in per-tile setup (before the subtile loop / acc consumer_wait), fp8-gated; read per-subtile from registers instead of a fresh GMEM load. - cute_fx_walk.py: classify a stride-(1,0) (M,N) aux as a per-row column vector ("broadcast", 2) so scale_a is read as a scalar per subtile (tTR_gAux[(0,0,0,s)]) instead of a redundant N-wide vector load. Generated epilogue now matches the standalone: autovec_copy hoist before the acc consumer_wait, scalar colvec read. Benchmark (B200, CUDA 13.2, fp8 e4m3 scaled_mm, m=k=n=4096, col-major B, ab_stages=8, do_bench, 10s warmup): before (deep staging only) : 1616 TFLOP/s 86% of aten after (this change) : 1650 TFLOP/s 88% of aten torch._scaled_mm : 1878 TFLOP/s Full cute suite: 93 passed; rel_err 0.0000. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> stack-info: PR: #2742, branch: yushangdi/stack/27
ea8db74 to
d3fee78
Compare
c5c24f5 to
be553dd
Compare
yushangdi
added a commit
that referenced
this pull request
Jun 11, 2026
The fp8 scaled_mm epilogue loaded both rowwise scales per subtile AFTER the accumulator consumer_wait, exposing their GMEM latency. The standalone CUTLASS kernel (cute_scaled_mm) instead reads the whole rowvec scale into registers ONCE before the acc wait so its latency overlaps the MMA, and reads the per-row colvec scale as a single scalar. Backports the remaining fused-scale pieces from #2696: - memory_ops.py: register-hoist a rowvec epilogue aux (rowwise scale_b[n]) into a register tensor in per-tile setup (before the subtile loop / acc consumer_wait), fp8-gated; read per-subtile from registers instead of a fresh GMEM load. - cute_fx_walk.py: classify a stride-(1,0) (M,N) aux as a per-row column vector ("broadcast", 2) so scale_a is read as a scalar per subtile (tTR_gAux[(0,0,0,s)]) instead of a redundant N-wide vector load. Generated epilogue now matches the standalone: autovec_copy hoist before the acc consumer_wait, scalar colvec read. Benchmark (B200, CUDA 13.2, fp8 e4m3 scaled_mm, m=k=n=4096, col-major B, ab_stages=8, do_bench, 10s warmup): before (deep staging only) : 1616 TFLOP/s 86% of aten after (this change) : 1650 TFLOP/s 88% of aten torch._scaled_mm : 1878 TFLOP/s Full cute suite: 93 passed; rel_err 0.0000. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> stack-info: PR: #2742, branch: yushangdi/stack/27
d3fee78 to
1510d86
Compare
1510d86 to
7fedb85
Compare
7fedb85 to
1bdb662
Compare
1bdb662 to
ac887ec
Compare
ac887ec to
60d5b4c
Compare
Contributor
Author
|
benchmark: P2375123513 |
yushangdi
added a commit
that referenced
this pull request
Jun 11, 2026
The fp8 scaled_mm epilogue read the per-row column-vector scale (``scale_a[m]``, an ``(M, N)`` view with trailing stride 0) as a full N-wide vector ``.load()`` per subtile, even though the value is uniform over each thread's N fragment. The standalone CUTLASS kernel (cute_scaled_mm) instead reads it as a single scalar (``sa = tTR_gSA[(0,0,0,subtile)]``). Backports the colvec-read piece from #2696: - cute_fx_walk.py: classify a stride-(1,0) (M,N) aux as a per-row column vector ("broadcast", 2) so it is distinguished from a genuine dense (M,N) residual (trailing stride 1 -> "exact") and from the (1,N) leading-broadcast rowvec ("broadcast", 0). Tried before _matches_exact since a stride-0-N aux still has the full (M,N) underlying shape. - memory_ops.py: read a broadcast_axis==2 aux as a scalar per subtile (``tTR_aux_grouped[(0,0,0,subtile)]``) instead of a vector ``.load()``; the scalar broadcasts in the ``acc * aux`` chain multiply. Generated epilogue now matches the standalone's scalar colvec read. Tests: - test_aux_load_kind_* (6 unit tests): drive aux_tensor_load_kind on synthetic load FX nodes, pinning the colvec classification and its boundaries (exact vs colvec vs leading-broadcast, index-order match, global-shape match, extra_mask rejection). - test_tcgen05_fused_colvec_scale_emits_scalar_read_and_is_correct (e2e): compiles a real matmul with scale_a.unsqueeze(1).expand(m, n), asserts the scalar-read marker in the generated code (and that it does NOT fall back to a vector load), and checks numerics. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> stack-info: PR: #2742, branch: yushangdi/stack/27
60d5b4c to
8d9123c
Compare
The fp8 scaled_mm epilogue read the per-row column-vector scale (``scale_a[m]``, an ``(M, N)`` view with trailing stride 0) as a full N-wide vector ``.load()`` per subtile, even though the value is uniform over each thread's N fragment. The standalone CUTLASS kernel (cute_scaled_mm) instead reads it as a single scalar (``sa = tTR_gSA[(0,0,0,subtile)]``). Backports the colvec-read piece from #2696: - cute_fx_walk.py: classify a stride-(1,0) (M,N) aux as a per-row column vector ("broadcast", 2) so it is distinguished from a genuine dense (M,N) residual (trailing stride 1 -> "exact") and from the (1,N) leading-broadcast rowvec ("broadcast", 0). Tried before _matches_exact since a stride-0-N aux still has the full (M,N) underlying shape. - memory_ops.py: read a broadcast_axis==2 aux as a scalar per subtile (``tTR_aux_grouped[(0,0,0,subtile)]``) instead of a vector ``.load()``; the scalar broadcasts in the ``acc * aux`` chain multiply. Generated epilogue now matches the standalone's scalar colvec read. Tests: - test_aux_load_kind_* (6 unit tests): drive aux_tensor_load_kind on synthetic load FX nodes, pinning the colvec classification and its boundaries (exact vs colvec vs leading-broadcast, index-order match, global-shape match, extra_mask rejection). - test_tcgen05_fused_colvec_scale_emits_scalar_read_and_is_correct (e2e): compiles a real matmul with scale_a.unsqueeze(1).expand(m, n), asserts the scalar-read marker in the generated code (and that it does NOT fall back to a vector load), and checks numerics. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> stack-info: PR: #2742, branch: yushangdi/stack/27
8d9123c to
bc46dbc
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Stacked PRs:
[cute] Fused-scale epilogue: scalar colvec read for per-row scale
The fp8 scaled_mm epilogue read the per-row column-vector scale
(
scale_a[m], an(M, N)view with trailing stride 0) as a fullN-wide vector
.load()per subtile, even though the value is uniformover each thread's N fragment. The standalone CUTLASS kernel
(cute_scaled_mm) instead reads it as a single scalar
(
sa = tTR_gSA[(0,0,0,subtile)]).Backports the colvec-read piece from #2696:
vector ("broadcast", 2) so it is distinguished from a genuine dense
(M,N) residual (trailing stride 1 -> "exact") and from the (1,N)
leading-broadcast rowvec ("broadcast", 0). Tried before _matches_exact
since a stride-0-N aux still has the full (M,N) underlying shape.
(
tTR_aux_grouped[(0,0,0,subtile)]) instead of a vector.load();the scalar broadcasts in the
acc * auxchain multiply.Generated epilogue now matches the standalone's scalar colvec read.
Tests:
synthetic load FX nodes, pinning the colvec classification and its
boundaries (exact vs colvec vs leading-broadcast, index-order match,
global-shape match, extra_mask rejection).
compiles a real matmul with scale_a.unsqueeze(1).expand(m, n), asserts
the scalar-read marker in the generated code (and that it does NOT fall
back to a vector load), and checks numerics.
Co-Authored-By: Claude Fable 5 noreply@anthropic.com