[cute] Fused-scale epilogue: scalar colvec read for per-row scale#2742

Open
yushangdi wants to merge 1 commit into yushangdi/stack/26 from yushangdi/stack/27

@yushangdi yushangdi commented Jun 10, 2026

Contributor

Stacked PRs:


[cute] Fused-scale epilogue: scalar colvec read for per-row scale

The fp8 scaled_mm epilogue read the per-row column-vector scale
(scale_a[m], an (M, N) view with trailing stride 0) as a full
N-wide vector .load() per subtile, even though the value is uniform
over each thread's N fragment. The standalone CUTLASS kernel
(cute_scaled_mm) instead reads it as a single scalar
(sa = tTR_gSA[(0,0,0,subtile)]).
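
To make the "trailing stride 0" point concrete, here is a minimal sketch of strided addressing in plain Python. It is not the CuTe tensor API; the `offset` helper and the sizes are assumptions chosen for illustration. It shows that every element along N of a given row resolves to the same storage location, which is why a single scalar read carries the same information as the N-wide load.

```python
# Illustrative model of strided addressing; not the CuTe tensor API.

def offset(index, stride):
    """Storage offset of a logical index under the given strides."""
    return sum(i * s for i, s in zip(index, stride))

M, N = 4, 8
scale_a = [0.5, 1.0, 2.0, 4.0]   # one scale per row (length M)
colvec_stride = (1, 0)           # (M, N) view: stepping along N does not move in storage

row = 2
# An N-wide "vector load" of this row touches the same storage element N times.
fragment = [scale_a[offset((row, n), colvec_stride)] for n in range(N)]

# A single scalar read returns the value the whole fragment shares.
sa = scale_a[offset((row, 0), colvec_stride)]

assert fragment == [sa] * N
```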

Backports the colvec-read piece from #2696:

  • cute_fx_walk.py: classify a stride-(1,0) (M,N) aux as a per-row column
    vector ("broadcast", 2) so it is distinguished from a genuine dense
    (M,N) residual (trailing stride 1 -> "exact") and from the (1,N)
    leading-broadcast rowvec ("broadcast", 0). The colvec check runs before
    _matches_exact because an aux with stride 0 along N still has the full
    (M,N) shape and would otherwise match as exact.
  • memory_ops.py: read a broadcast_axis==2 aux as a scalar per subtile
    (tTR_aux_grouped[(0,0,0,subtile)]) instead of a vector .load();
    the scalar broadcasts in the acc * aux chain multiply.
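
The classification order described above can be sketched as follows. This is a hypothetical stand-in, not the code in cute_fx_walk.py: the function name `classify_aux`, its signature, and the exact stride tests are assumptions. Only the three result kinds and the "colvec before exact" ordering come from the description.

```python
# Hypothetical sketch; name, signature and stride tests are assumptions.

def classify_aux(shape, stride, tile_shape):
    """Return a load kind for a 2-D epilogue aux tensor, or None if unrecognized."""
    m, n = tile_shape
    shape, stride = tuple(shape), tuple(stride)

    # Per-row column vector: full (M, N) shape, stride 1 along M, stride 0 along N.
    # Checked first because its shape alone is indistinguishable from a dense residual.
    if shape == (m, n) and stride == (1, 0):
        return ("broadcast", 2)

    # Genuine dense (M, N) residual: trailing stride 1, rows at least N apart.
    if shape == (m, n) and stride[1] == 1 and stride[0] >= n:
        return "exact"

    # Leading-broadcast row vector: shape (1, N), contiguous along N.
    if shape == (1, n) and stride[1] == 1:
        return ("broadcast", 0)

    return None


tile = (128, 64)
print(classify_aux((128, 64), (1, 0), tile))   # colvec scale such as scale_a[m]
print(classify_aux((128, 64), (64, 1), tile))  # dense residual
print(classify_aux((1, 64), (64, 1), tile))    # rowvec
```

If the dense test ran first and looked only at shape, the colvec case would be swallowed by it, which is the reason for the ordering.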

Generated epilogue now matches the standalone's scalar colvec read.

Tests:

  • test_aux_load_kind_* (6 unit tests): drive aux_tensor_load_kind on
    synthetic load FX nodes, pinning the colvec classification and its
    boundaries (exact vs colvec vs leading-broadcast, index-order match,
    global-shape match, extra_mask rejection).
  • test_tcgen05_fused_colvec_scale_emits_scalar_read_and_is_correct (e2e):
    compiles a real matmul with scale_a.unsqueeze(1).expand(m, n), asserts
    the scalar-read marker in the generated code (and that it does NOT fall
    back to a vector load), and checks numerics.
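
As a rough picture of what the numerics check compares, the sketch below uses plain Python lists in place of tensors. The matrices and scale values are made up, and the `matmul` helper is only a reference implementation, not the test's code. It multiplies the accumulator by an expanded (M, N) scale, mirroring what scale_a.unsqueeze(1).expand(m, n) presents, and by one scalar per row, and checks that the two agree.

```python
# Pure-Python reference; sizes and values are invented for illustration.

def matmul(a, b):
    """Naive (M, K) x (K, N) matrix product on nested lists."""
    k, n = len(b), len(b[0])
    return [[sum(row[p] * b[p][j] for p in range(k)) for j in range(n)] for row in a]

a = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]            # (M=3, K=2)
b = [[1.0, 0.0, 2.0, 1.0], [0.0, 1.0, 1.0, 2.0]]    # (K=2, N=4)
scale_a = [0.5, 2.0, 4.0]                           # one scale per row

acc = matmul(a, b)
m, n = len(acc), len(acc[0])

# Reference path: elementwise multiply by the expanded (M, N) scale.
expanded = [[scale_a[i]] * n for i in range(m)]
reference = [[acc[i][j] * expanded[i][j] for j in range(n)] for i in range(m)]

# Scalar path: read one scale per row and broadcast it across the row.
scalar = [[acc[i][j] * scale_a[i] for j in range(n)] for i in range(m)]

assert scalar == reference
```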

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

yushangdi added a commit that referenced this pull request Jun 10, 2026
The fp8 scaled_mm epilogue loaded both rowwise scales per subtile AFTER
the accumulator consumer_wait, exposing their GMEM latency. The standalone
CUTLASS kernel (cute_scaled_mm) instead reads the whole rowvec scale into
registers ONCE before the acc wait so its latency overlaps the MMA, and
reads the per-row colvec scale as a single scalar.

Backports the remaining fused-scale pieces from #2696:
- memory_ops.py: register-hoist a rowvec epilogue aux (rowwise scale_b[n])
  into a register tensor in per-tile setup (before the subtile loop / acc
  consumer_wait), fp8-gated; read per-subtile from registers instead of a
  fresh GMEM load.
- cute_fx_walk.py: classify a stride-(1,0) (M,N) aux as a per-row column
  vector ("broadcast", 2) so scale_a is read as a scalar per subtile
  (tTR_gAux[(0,0,0,s)]) instead of a redundant N-wide vector load.

Generated epilogue now matches the standalone: autovec_copy hoist before
the acc consumer_wait, scalar colvec read.

Benchmark (B200, CUDA 13.2, fp8 e4m3 scaled_mm, m=k=n=4096, col-major B,
ab_stages=8, do_bench, 10s warmup):

  before (deep staging only) : 1616 TFLOP/s   86% of aten
  after (this change)        : 1650 TFLOP/s   88% of aten
  torch._scaled_mm           : 1878 TFLOP/s
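
The percentages in the table can be reproduced from the throughput figures. The sketch below also converts throughput to per-call time, assuming the usual 2*m*n*k FLOP count for a matmul. That convention is an assumption here; the commit message does not state how TFLOP/s was derived.

```python
# Arithmetic check on the table; the 2*m*n*k FLOP convention is an assumption.

m = n = k = 4096
flops = 2 * m * n * k                     # one multiply and one add per inner-product term

before, after, aten = 1616, 1650, 1878    # TFLOP/s, taken from the table


def pct_of(x, baseline):
    """Throughput as a rounded percentage of the baseline."""
    return round(100 * x / baseline)


def time_us(tflops):
    """Implied time per call in microseconds at the given throughput."""
    return flops / (tflops * 1e12) * 1e6


print(pct_of(before, aten), pct_of(after, aten))            # percentages of aten
print(round(time_us(before), 1), round(time_us(after), 1))  # implied microseconds per call
print(round(100 * (after / before - 1), 1))                 # relative gain in percent
```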

Full cute suite: 93 passed; rel_err 0.0000.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

stack-info: PR: #2742, branch: yushangdi/stack/27
@yushangdi yushangdi force-pushed the yushangdi/stack/26 branch from b9dd708 to 12dda20 Compare June 10, 2026 18:12
@yushangdi yushangdi force-pushed the yushangdi/stack/27 branch from 2e66347 to ea8db74 Compare June 10, 2026 18:12
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jun 10, 2026
@yushangdi yushangdi changed the base branch from yushangdi/stack/26 to main June 11, 2026 02:29
@yushangdi yushangdi changed the base branch from main to yushangdi/stack/26 June 11, 2026 02:29
@yushangdi yushangdi changed the base branch from yushangdi/stack/26 to main June 11, 2026 02:35
@yushangdi yushangdi changed the base branch from main to yushangdi/stack/26 June 11, 2026 02:36
@yushangdi yushangdi force-pushed the yushangdi/stack/26 branch from 12dda20 to 1573e3d Compare June 11, 2026 17:07
yushangdi added a commit that referenced this pull request Jun 11, 2026
@yushangdi yushangdi force-pushed the yushangdi/stack/27 branch from ea8db74 to d3fee78 Compare June 11, 2026 17:08
@yushangdi yushangdi force-pushed the yushangdi/stack/26 branch 5 times, most recently from c5c24f5 to be553dd Compare June 11, 2026 20:32
yushangdi added a commit that referenced this pull request Jun 11, 2026
@yushangdi yushangdi force-pushed the yushangdi/stack/27 branch from d3fee78 to 1510d86 Compare June 11, 2026 20:49
@yushangdi yushangdi changed the base branch from yushangdi/stack/26 to main June 11, 2026 21:17
@yushangdi yushangdi force-pushed the yushangdi/stack/27 branch from 1510d86 to 7fedb85 Compare June 11, 2026 21:17
@yushangdi yushangdi changed the base branch from main to yushangdi/stack/26 June 11, 2026 21:18
@yushangdi yushangdi changed the base branch from yushangdi/stack/26 to main June 11, 2026 22:00
@yushangdi yushangdi force-pushed the yushangdi/stack/27 branch from 7fedb85 to 1bdb662 Compare June 11, 2026 22:00
@yushangdi yushangdi changed the title from "[cute] Fused-scale epilogue opts: hoist rowvec load, scalar colvec read" to "[cute] Fused-scale epilogue: scalar colvec read for per-row scale" Jun 11, 2026
@yushangdi yushangdi changed the base branch from main to yushangdi/stack/26 June 11, 2026 22:00
@yushangdi yushangdi changed the base branch from yushangdi/stack/26 to main June 11, 2026 22:16
@yushangdi yushangdi force-pushed the yushangdi/stack/27 branch from 1bdb662 to ac887ec Compare June 11, 2026 22:16
@yushangdi yushangdi changed the base branch from main to yushangdi/stack/26 June 11, 2026 22:16
@yushangdi yushangdi changed the base branch from yushangdi/stack/26 to main June 11, 2026 22:58
@yushangdi yushangdi force-pushed the yushangdi/stack/27 branch from ac887ec to 60d5b4c Compare June 11, 2026 22:58
@yushangdi yushangdi changed the base branch from main to yushangdi/stack/26 June 11, 2026 22:58
@yushangdi

Contributor Author

benchmark: P2375123513

@yushangdi yushangdi changed the base branch from yushangdi/stack/26 to main June 11, 2026 23:12
yushangdi added a commit that referenced this pull request Jun 11, 2026
The fp8 scaled_mm epilogue read the per-row column-vector scale
(``scale_a[m]``, an ``(M, N)`` view with trailing stride 0) as a full
N-wide vector ``.load()`` per subtile, even though the value is uniform
over each thread's N fragment. The standalone CUTLASS kernel
(cute_scaled_mm) instead reads it as a single scalar
(``sa = tTR_gSA[(0,0,0,subtile)]``).

Backports the colvec-read piece from #2696:
- cute_fx_walk.py: classify a stride-(1,0) (M,N) aux as a per-row column
  vector ("broadcast", 2) so it is distinguished from a genuine dense
  (M,N) residual (trailing stride 1 -> "exact") and from the (1,N)
  leading-broadcast rowvec ("broadcast", 0). The colvec check runs before
  _matches_exact because an aux with stride 0 along N still has the full
  (M,N) shape and would otherwise match as exact.
- memory_ops.py: read a broadcast_axis==2 aux as a scalar per subtile
  (``tTR_aux_grouped[(0,0,0,subtile)]``) instead of a vector ``.load()``;
  the scalar broadcasts in the ``acc * aux`` chain multiply.

Generated epilogue now matches the standalone's scalar colvec read.

Tests:
- test_aux_load_kind_* (6 unit tests): drive aux_tensor_load_kind on
  synthetic load FX nodes, pinning the colvec classification and its
  boundaries (exact vs colvec vs leading-broadcast, index-order match,
  global-shape match, extra_mask rejection).
- test_tcgen05_fused_colvec_scale_emits_scalar_read_and_is_correct (e2e):
  compiles a real matmul with scale_a.unsqueeze(1).expand(m, n), asserts
  the scalar-read marker in the generated code (and that it does NOT fall
  back to a vector load), and checks numerics.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

stack-info: PR: #2742, branch: yushangdi/stack/27
@yushangdi yushangdi force-pushed the yushangdi/stack/27 branch from 60d5b4c to 8d9123c Compare June 11, 2026 23:12
@yushangdi yushangdi changed the base branch from main to yushangdi/stack/26 June 11, 2026 23:12
@yushangdi yushangdi marked this pull request as ready for review June 11, 2026 23:13
@yushangdi yushangdi requested review from jansel and oulgen June 11, 2026 23:13
@yushangdi yushangdi marked this pull request as draft June 11, 2026 23:17
@yushangdi yushangdi changed the base branch from yushangdi/stack/26 to main June 11, 2026 23:17
@yushangdi yushangdi changed the base branch from main to yushangdi/stack/26 June 11, 2026 23:17
@yushangdi yushangdi marked this pull request as ready for review June 11, 2026 23:17
@yushangdi yushangdi removed request for jansel and oulgen June 11, 2026 23:29
@yushangdi yushangdi marked this pull request as draft June 11, 2026 23:29
@yushangdi yushangdi changed the base branch from yushangdi/stack/26 to main June 11, 2026 23:41
@yushangdi yushangdi force-pushed the yushangdi/stack/27 branch from 8d9123c to bc46dbc Compare June 11, 2026 23:41
@yushangdi yushangdi changed the base branch from main to yushangdi/stack/26 June 11, 2026 23:42
@yushangdi yushangdi marked this pull request as ready for review June 12, 2026 00:02
@yushangdi yushangdi requested review from jansel and oulgen June 12, 2026 00:02