Skip to content

[cute] Deep AB staging for fp8 to close the compute-bound gap#2741

Open
yushangdi wants to merge 1 commit into
yushangdi/stack/25from
yushangdi/stack/26
Open

[cute] Deep AB staging for fp8 to close the compute-bound gap#2741
yushangdi wants to merge 1 commit into
yushangdi/stack/25from
yushangdi/stack/26

Conversation

@yushangdi

@yushangdi yushangdi commented Jun 10, 2026

Copy link
Copy Markdown
Contributor

Stacked PRs:


[cute] Deep AB staging for fp8 to close the compute-bound gap

With K-major B in place, the fp8 tcgen05 kernel was still capped at
ab_stages=3 (the bf16-tuned limit), so its software pipeline was too
shallow to hide K-loop TMA latency on compute-bound shapes (4096^3:
1143 TFLOP/s, 62% of torch._scaled_mm). 1-byte fp8 operands fit a much
deeper AB pipeline than 2-byte bf16 in the same SMEM budget.

Backports the tcgen05_config.py deep-staging logic from
#2696:

  • max_ab_stages_that_fit(): largest ab_stages whose AB SMEM fits the
    per-CTA budget (mirrors CUTLASS _compute_stages).
  • _validate_target1_ab_stage_envelope(): admit ab_stages>3 for fp8 as
    long as the AB SMEM fits, instead of hard-failing at 3.
  • optional_fragments(): widen the ab_stages search cap to 12 for fp8
    (1-byte operands) on BOTH the search and validation surfaces, so the
    autotuner can sample deep pipelines and a frozen deep-staged fp8 config
    also passes normalize().

Benchmark (B200, CUDA 13.2, fp8 e4m3 scaled_mm, m=k=n=4096, col-major B,
CtaGroup.TWO cluster_m=2 role-local, do_bench, 10s warmup):

ab_stages= 3 (prev cap) : 1143 TFLOP/s 62% of aten
ab_stages= 6 : 1494 TFLOP/s 80%
ab_stages= 8 : 1613 TFLOP/s 86% <- sweet spot
ab_stages=10/12 : ~1600 TFLOP/s (SMEM-bound plateau)
torch._scaled_mm : 1867 TFLOP/s

So K-major B (prior commit) + deep AB staging together take fp8 scaled_mm
from ~31% to ~86% of torch._scaled_mm on the 4096^3 compute-bound shape.
Correctness rel_err 0.0000 throughout; full cute suite: 93 passed.

Co-Authored-By: Claude Fable 5 noreply@anthropic.com

@yushangdi yushangdi force-pushed the yushangdi/stack/25 branch from a8b1364 to 77192c0 Compare June 10, 2026 18:12
@yushangdi yushangdi force-pushed the yushangdi/stack/26 branch from b9dd708 to 12dda20 Compare June 10, 2026 18:12
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jun 10, 2026
@yushangdi yushangdi changed the base branch from yushangdi/stack/25 to main June 11, 2026 02:29
@yushangdi yushangdi changed the base branch from main to yushangdi/stack/25 June 11, 2026 02:29
@yushangdi yushangdi changed the base branch from yushangdi/stack/25 to main June 11, 2026 02:35
@yushangdi yushangdi changed the base branch from main to yushangdi/stack/25 June 11, 2026 02:36
@yushangdi yushangdi force-pushed the yushangdi/stack/25 branch from 77192c0 to 5908dc7 Compare June 11, 2026 17:05
yushangdi added a commit that referenced this pull request Jun 11, 2026
With K-major B in place, the fp8 tcgen05 kernel was still capped at
ab_stages=3 (the bf16-tuned limit), so its software pipeline was too
shallow to hide K-loop TMA latency on compute-bound shapes (4096^3:
1143 TFLOP/s, 62% of torch._scaled_mm). 1-byte fp8 operands fit a much
deeper AB pipeline than 2-byte bf16 in the same SMEM budget.

Backports the tcgen05_config.py deep-staging logic from
#2696:
- max_ab_stages_that_fit(): largest ab_stages whose AB SMEM fits the
  per-CTA budget (mirrors CUTLASS _compute_stages).
- _validate_target1_ab_stage_envelope(): admit ab_stages>3 for fp8 as
  long as the AB SMEM fits, instead of hard-failing at 3.
- optional_fragments(): widen the ab_stages search cap to 12 for fp8
  (1-byte operands) on BOTH the search and validation surfaces, so the
  autotuner can sample deep pipelines and a frozen deep-staged fp8 config
  also passes normalize().

Benchmark (B200, CUDA 13.2, fp8 e4m3 scaled_mm, m=k=n=4096, col-major B,
CtaGroup.TWO cluster_m=2 role-local, do_bench, 10s warmup):

  ab_stages= 3 (prev cap) : 1143 TFLOP/s   62% of aten
  ab_stages= 6            : 1494 TFLOP/s   80%
  ab_stages= 8            : 1613 TFLOP/s   86%   <- sweet spot
  ab_stages=10/12         : ~1600 TFLOP/s  (SMEM-bound plateau)
  torch._scaled_mm        : 1867 TFLOP/s

So K-major B (prior commit) + deep AB staging together take fp8 scaled_mm
from ~31% to ~86% of torch._scaled_mm on the 4096^3 compute-bound shape.
Correctness rel_err 0.0000 throughout; full cute suite: 93 passed.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

stack-info: PR: #2741, branch: yushangdi/stack/26
@yushangdi yushangdi force-pushed the yushangdi/stack/26 branch from 12dda20 to 1573e3d Compare June 11, 2026 17:07
yushangdi added a commit that referenced this pull request Jun 11, 2026
With K-major B in place, the fp8 tcgen05 kernel was still capped at
ab_stages=3 (the bf16-tuned limit), so its software pipeline was too
shallow to hide K-loop TMA latency on compute-bound shapes (4096^3:
1143 TFLOP/s, 62% of torch._scaled_mm). 1-byte fp8 operands fit a much
deeper AB pipeline than 2-byte bf16 in the same SMEM budget.

Backports the tcgen05_config.py deep-staging logic from
#2696:
- max_ab_stages_that_fit(): largest ab_stages whose AB SMEM fits the
  per-CTA budget (mirrors CUTLASS _compute_stages).
- _validate_target1_ab_stage_envelope(): admit ab_stages>3 for fp8 as
  long as the AB SMEM fits, instead of hard-failing at 3.
- optional_fragments(): widen the ab_stages search cap to 12 for fp8
  (1-byte operands) on BOTH the search and validation surfaces, so the
  autotuner can sample deep pipelines and a frozen deep-staged fp8 config
  also passes normalize().

Benchmark (B200, CUDA 13.2, fp8 e4m3 scaled_mm, m=k=n=4096, col-major B,
CtaGroup.TWO cluster_m=2 role-local, do_bench, 10s warmup):

  ab_stages= 3 (prev cap) : 1143 TFLOP/s   62% of aten
  ab_stages= 6            : 1494 TFLOP/s   80%
  ab_stages= 8            : 1613 TFLOP/s   86%   <- sweet spot
  ab_stages=10/12         : ~1600 TFLOP/s  (SMEM-bound plateau)
  torch._scaled_mm        : 1867 TFLOP/s

So K-major B (prior commit) + deep AB staging together take fp8 scaled_mm
from ~31% to ~86% of torch._scaled_mm on the 4096^3 compute-bound shape.
Correctness rel_err 0.0000 throughout; full cute suite: 93 passed.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

stack-info: PR: #2741, branch: yushangdi/stack/26
@yushangdi yushangdi force-pushed the yushangdi/stack/26 branch from 1573e3d to 615f6ab Compare June 11, 2026 17:41
yushangdi added a commit that referenced this pull request Jun 11, 2026
With K-major B in place, the fp8 tcgen05 kernel was still capped at
ab_stages=3 (the bf16-tuned limit), so its software pipeline was too
shallow to hide K-loop TMA latency on compute-bound shapes (4096^3:
1143 TFLOP/s, 62% of torch._scaled_mm). 1-byte fp8 operands fit a much
deeper AB pipeline than 2-byte bf16 in the same SMEM budget.

Backports the tcgen05_config.py deep-staging logic from
#2696:
- max_ab_stages_that_fit(): largest ab_stages whose AB SMEM fits the
  per-CTA budget (mirrors CUTLASS _compute_stages).
- _validate_target1_ab_stage_envelope(): admit ab_stages>3 for fp8 as
  long as the AB SMEM fits, instead of hard-failing at 3.
- optional_fragments(): widen the ab_stages search cap to 12 for fp8
  (1-byte operands) on BOTH the search and validation surfaces, so the
  autotuner can sample deep pipelines and a frozen deep-staged fp8 config
  also passes normalize().

Benchmark (B200, CUDA 13.2, fp8 e4m3 scaled_mm, m=k=n=4096, col-major B,
CtaGroup.TWO cluster_m=2 role-local, do_bench, 10s warmup):

  ab_stages= 3 (prev cap) : 1143 TFLOP/s   62% of aten
  ab_stages= 6            : 1494 TFLOP/s   80%
  ab_stages= 8            : 1613 TFLOP/s   86%   <- sweet spot
  ab_stages=10/12         : ~1600 TFLOP/s  (SMEM-bound plateau)
  torch._scaled_mm        : 1867 TFLOP/s

So K-major B (prior commit) + deep AB staging together take fp8 scaled_mm
from ~31% to ~86% of torch._scaled_mm on the 4096^3 compute-bound shape.
Correctness rel_err 0.0000 throughout; full cute suite: 93 passed.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

stack-info: PR: #2741, branch: yushangdi/stack/26
@yushangdi yushangdi force-pushed the yushangdi/stack/26 branch from 615f6ab to 3221d91 Compare June 11, 2026 18:01
yushangdi added a commit that referenced this pull request Jun 11, 2026
With K-major B in place, the fp8 tcgen05 kernel was still capped at
ab_stages=3 (the bf16-tuned limit), so its software pipeline was too
shallow to hide K-loop TMA latency on compute-bound shapes (4096^3:
1143 TFLOP/s, 62% of torch._scaled_mm). 1-byte fp8 operands fit a much
deeper AB pipeline than 2-byte bf16 in the same SMEM budget.

Backports the tcgen05_config.py deep-staging logic from
#2696:
- max_ab_stages_that_fit(): largest ab_stages whose AB SMEM fits the
  per-CTA budget (mirrors CUTLASS _compute_stages).
- _validate_target1_ab_stage_envelope(): admit ab_stages>3 for fp8 as
  long as the AB SMEM fits, instead of hard-failing at 3.
- optional_fragments(): widen the ab_stages search cap to 12 for fp8
  (1-byte operands) on BOTH the search and validation surfaces, so the
  autotuner can sample deep pipelines and a frozen deep-staged fp8 config
  also passes normalize().

Benchmark (B200, CUDA 13.2, fp8 e4m3 scaled_mm, m=k=n=4096, col-major B,
CtaGroup.TWO cluster_m=2 role-local, do_bench, 10s warmup):

  ab_stages= 3 (prev cap) : 1143 TFLOP/s   62% of aten
  ab_stages= 6            : 1494 TFLOP/s   80%
  ab_stages= 8            : 1613 TFLOP/s   86%   <- sweet spot
  ab_stages=10/12         : ~1600 TFLOP/s  (SMEM-bound plateau)
  torch._scaled_mm        : 1867 TFLOP/s

So K-major B (prior commit) + deep AB staging together take fp8 scaled_mm
from ~31% to ~86% of torch._scaled_mm on the 4096^3 compute-bound shape.
Correctness rel_err 0.0000 throughout; full cute suite: 93 passed.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

stack-info: PR: #2741, branch: yushangdi/stack/26
@yushangdi yushangdi force-pushed the yushangdi/stack/26 branch from 3221d91 to c7ffd76 Compare June 11, 2026 18:12
yushangdi added a commit that referenced this pull request Jun 11, 2026
With K-major B in place, the fp8 tcgen05 kernel was still capped at
ab_stages=3 (the bf16-tuned limit), so its software pipeline was too
shallow to hide K-loop TMA latency on compute-bound shapes (4096^3:
1143 TFLOP/s, 62% of torch._scaled_mm). 1-byte fp8 operands fit a much
deeper AB pipeline than 2-byte bf16 in the same SMEM budget.

Backports the tcgen05_config.py deep-staging logic from
#2696:
- max_ab_stages_that_fit(): largest ab_stages whose AB SMEM fits the
  per-CTA budget (mirrors CUTLASS _compute_stages).
- _validate_target1_ab_stage_envelope(): admit ab_stages>3 for fp8 as
  long as the AB SMEM fits, instead of hard-failing at 3.
- optional_fragments(): widen the ab_stages search cap to 12 for fp8
  (1-byte operands) on BOTH the search and validation surfaces, so the
  autotuner can sample deep pipelines and a frozen deep-staged fp8 config
  also passes normalize().

Benchmark (B200, CUDA 13.2, fp8 e4m3 scaled_mm, m=k=n=4096, col-major B,
CtaGroup.TWO cluster_m=2 role-local, do_bench, 10s warmup):

  ab_stages= 3 (prev cap) : 1143 TFLOP/s   62% of aten
  ab_stages= 6            : 1494 TFLOP/s   80%
  ab_stages= 8            : 1613 TFLOP/s   86%   <- sweet spot
  ab_stages=10/12         : ~1600 TFLOP/s  (SMEM-bound plateau)
  torch._scaled_mm        : 1867 TFLOP/s

So K-major B (prior commit) + deep AB staging together take fp8 scaled_mm
from ~31% to ~86% of torch._scaled_mm on the 4096^3 compute-bound shape.
Correctness rel_err 0.0000 throughout; full cute suite: 93 passed.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

stack-info: PR: #2741, branch: yushangdi/stack/26
@yushangdi yushangdi force-pushed the yushangdi/stack/26 branch from c7ffd76 to c5c24f5 Compare June 11, 2026 18:23
@yushangdi yushangdi force-pushed the yushangdi/stack/25 branch from 5908dc7 to 62f125b Compare June 11, 2026 20:29
yushangdi added a commit that referenced this pull request Jun 11, 2026
With K-major B in place, the fp8 tcgen05 kernel was still capped at
ab_stages=3 (the bf16-tuned limit), so its software pipeline was too
shallow to hide K-loop TMA latency on compute-bound shapes (4096^3:
1143 TFLOP/s, 62% of torch._scaled_mm). 1-byte fp8 operands fit a much
deeper AB pipeline than 2-byte bf16 in the same SMEM budget.

Backports the tcgen05_config.py deep-staging logic from
#2696:
- max_ab_stages_that_fit(): largest ab_stages whose AB SMEM fits the
  per-CTA budget (mirrors CUTLASS _compute_stages).
- _validate_target1_ab_stage_envelope(): admit ab_stages>3 for fp8 as
  long as the AB SMEM fits, instead of hard-failing at 3.
- optional_fragments(): widen the ab_stages search cap to 12 for fp8
  (1-byte operands) on BOTH the search and validation surfaces, so the
  autotuner can sample deep pipelines and a frozen deep-staged fp8 config
  also passes normalize().

Benchmark (B200, CUDA 13.2, fp8 e4m3 scaled_mm, m=k=n=4096, col-major B,
CtaGroup.TWO cluster_m=2 role-local, do_bench, 10s warmup):

  ab_stages= 3 (prev cap) : 1143 TFLOP/s   62% of aten
  ab_stages= 6            : 1494 TFLOP/s   80%
  ab_stages= 8            : 1613 TFLOP/s   86%   <- sweet spot
  ab_stages=10/12         : ~1600 TFLOP/s  (SMEM-bound plateau)
  torch._scaled_mm        : 1867 TFLOP/s

So K-major B (prior commit) + deep AB staging together take fp8 scaled_mm
from ~31% to ~86% of torch._scaled_mm on the 4096^3 compute-bound shape.
Correctness rel_err 0.0000 throughout; full cute suite: 93 passed.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

stack-info: PR: #2741, branch: yushangdi/stack/26
@yushangdi yushangdi force-pushed the yushangdi/stack/26 branch from c5c24f5 to be553dd Compare June 11, 2026 20:32
@yushangdi yushangdi marked this pull request as ready for review June 11, 2026 20:32
@yushangdi yushangdi requested a review from jansel June 11, 2026 20:33
@yushangdi yushangdi changed the base branch from yushangdi/stack/25 to main June 11, 2026 22:00
@yushangdi yushangdi changed the base branch from main to yushangdi/stack/25 June 11, 2026 22:00
@yushangdi yushangdi marked this pull request as ready for review June 11, 2026 22:00
@yushangdi yushangdi marked this pull request as draft June 11, 2026 22:15
@yushangdi yushangdi changed the base branch from yushangdi/stack/25 to main June 11, 2026 22:15
@yushangdi yushangdi changed the base branch from main to yushangdi/stack/25 June 11, 2026 22:16
@yushangdi yushangdi marked this pull request as ready for review June 11, 2026 22:16
@yushangdi yushangdi marked this pull request as draft June 11, 2026 22:58
@yushangdi yushangdi changed the base branch from yushangdi/stack/25 to main June 11, 2026 22:58
@yushangdi yushangdi changed the base branch from main to yushangdi/stack/25 June 11, 2026 22:58
@yushangdi yushangdi marked this pull request as ready for review June 11, 2026 22:58
@yushangdi yushangdi marked this pull request as draft June 11, 2026 23:12
@yushangdi yushangdi changed the base branch from yushangdi/stack/25 to main June 11, 2026 23:12
@yushangdi yushangdi changed the base branch from main to yushangdi/stack/25 June 11, 2026 23:12
@yushangdi yushangdi marked this pull request as ready for review June 11, 2026 23:12
@yushangdi yushangdi marked this pull request as draft June 11, 2026 23:17
@yushangdi yushangdi changed the base branch from yushangdi/stack/25 to main June 11, 2026 23:17
@yushangdi yushangdi changed the base branch from main to yushangdi/stack/25 June 11, 2026 23:17
@yushangdi yushangdi marked this pull request as ready for review June 11, 2026 23:17
@yushangdi yushangdi marked this pull request as draft June 11, 2026 23:31
With K-major B in place, the fp8 tcgen05 kernel was still capped at
ab_stages=3 (the bf16-tuned limit), so its software pipeline was too
shallow to hide K-loop TMA latency on compute-bound shapes (4096^3:
1143 TFLOP/s, 62% of torch._scaled_mm). 1-byte fp8 operands fit a much
deeper AB pipeline than 2-byte bf16 in the same SMEM budget.

Backports the tcgen05_config.py deep-staging logic from
#2696:
- max_ab_stages_that_fit(): largest ab_stages whose AB SMEM fits the
  per-CTA budget (mirrors CUTLASS _compute_stages).
- _validate_target1_ab_stage_envelope(): admit ab_stages>3 for fp8 as
  long as the AB SMEM fits, instead of hard-failing at 3.
- optional_fragments(): widen the ab_stages search cap to 12 for fp8
  (1-byte operands) on BOTH the search and validation surfaces, so the
  autotuner can sample deep pipelines and a frozen deep-staged fp8 config
  also passes normalize().

Benchmark (B200, CUDA 13.2, fp8 e4m3 scaled_mm, m=k=n=4096, col-major B,
CtaGroup.TWO cluster_m=2 role-local, do_bench, 10s warmup):

  ab_stages= 3 (prev cap) : 1143 TFLOP/s   62% of aten
  ab_stages= 6            : 1494 TFLOP/s   80%
  ab_stages= 8            : 1613 TFLOP/s   86%   <- sweet spot
  ab_stages=10/12         : ~1600 TFLOP/s  (SMEM-bound plateau)
  torch._scaled_mm        : 1867 TFLOP/s

So K-major B (prior commit) + deep AB staging together take fp8 scaled_mm
from ~31% to ~86% of torch._scaled_mm on the 4096^3 compute-bound shape.
Correctness rel_err 0.0000 throughout; full cute suite: 93 passed.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

stack-info: PR: #2741, branch: yushangdi/stack/26
@yushangdi yushangdi changed the base branch from yushangdi/stack/25 to main June 11, 2026 23:41
@yushangdi yushangdi force-pushed the yushangdi/stack/26 branch from f878663 to d01248b Compare June 11, 2026 23:41
@yushangdi yushangdi changed the base branch from main to yushangdi/stack/25 June 11, 2026 23:42
@yushangdi yushangdi marked this pull request as ready for review June 12, 2026 00:01
@yushangdi yushangdi requested review from jansel and oulgen June 12, 2026 00:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant