[cute] Deep AB staging for fp8 to close the compute-bound gap #2741
Open
yushangdi wants to merge 1 commit into
yushangdi added a commit that referenced this pull request Jun 11, 2026

With K-major B in place, the fp8 tcgen05 kernel was still capped at
ab_stages=3 (the bf16-tuned limit), so its software pipeline was too
shallow to hide K-loop TMA latency on compute-bound shapes (4096^3:
1143 TFLOP/s, 62% of torch._scaled_mm). 1-byte fp8 operands fit a much
deeper AB pipeline than 2-byte bf16 in the same SMEM budget.

Backports the tcgen05_config.py deep-staging logic from #2696:
- max_ab_stages_that_fit(): largest ab_stages whose AB SMEM fits the
  per-CTA budget (mirrors CUTLASS _compute_stages).
- _validate_target1_ab_stage_envelope(): admit ab_stages>3 for fp8 as
  long as the AB SMEM fits, instead of hard-failing at 3.
- optional_fragments(): widen the ab_stages search cap to 12 for fp8
  (1-byte operands) on BOTH the search and validation surfaces, so the
  autotuner can sample deep pipelines and a frozen deep-staged fp8 config
  also passes normalize().

Benchmark (B200, CUDA 13.2, fp8 e4m3 scaled_mm, m=k=n=4096, col-major B,
CtaGroup.TWO cluster_m=2 role-local, do_bench, 10s warmup):
  ab_stages= 3 (prev cap) : 1143 TFLOP/s   62% of aten
  ab_stages= 6            : 1494 TFLOP/s   80%
  ab_stages= 8            : 1613 TFLOP/s   86%  <- sweet spot
  ab_stages=10/12         : ~1600 TFLOP/s  (SMEM-bound plateau)
  torch._scaled_mm        : 1867 TFLOP/s

So K-major B (prior commit) + deep AB staging together take fp8 scaled_mm
from ~31% to ~86% of torch._scaled_mm on the 4096^3 compute-bound shape.

Correctness rel_err 0.0000 throughout; full cute suite: 93 passed.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
stack-info: PR: #2741, branch: yushangdi/stack/26
Stacked PRs:
[cute] Deep AB staging for fp8 to close the compute-bound gap
With K-major B in place, the fp8 tcgen05 kernel was still capped at
ab_stages=3 (the bf16-tuned limit), so its software pipeline was too
shallow to hide K-loop TMA latency on compute-bound shapes (4096^3:
1143 TFLOP/s, 62% of torch._scaled_mm). 1-byte fp8 operands fit a much
deeper AB pipeline than 2-byte bf16 in the same SMEM budget.
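Backports the tcgen05_config.py deep-staging logic from #2696:
- max_ab_stages_that_fit(): largest ab_stages whose AB SMEM fits the
  per-CTA budget (mirrors CUTLASS _compute_stages).
- _validate_target1_ab_stage_envelope(): admit ab_stages>3 for fp8 as
  long as the AB SMEM fits, instead of hard-failing at 3.
- optional_fragments(): widen the ab_stages search cap to 12 for fp8
  (1-byte operands) on BOTH the search and validation surfaces, so the
  autotuner can sample deep pipelines and a frozen deep-staged fp8 config
  also passes normalize().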
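
The stage-fitting rule in the list above can be sketched roughly as follows. This is an illustration only, not the code in tcgen05_config.py: the function names (sketch_max_ab_stages, sketch_validate_ab_stages), the Tile type, the 227 KiB per-CTA budget, the 32 KiB reserved for non-AB buffers, and the 128x256x64 tile are all assumptions chosen to make the arithmetic concrete. Only the caps of 3 and 12 come from the description.

```python
from dataclasses import dataclass

# Illustrative assumptions, not values taken from the PR's code.
SMEM_BUDGET_BYTES = 227 * 1024  # assumed per-CTA shared-memory limit
RESERVED_BYTES = 32 * 1024      # assumed space for epilogue tiles, barriers, etc.

# Caps quoted in the PR description.
FP8_STAGE_CAP = 12   # widened search cap for 1-byte operands
BF16_STAGE_CAP = 3   # previous bf16-tuned limit


@dataclass(frozen=True)
class Tile:
    """Hypothetical per-CTA tile shape plus operand width."""
    block_m: int
    block_n: int
    block_k: int
    elem_bytes: int  # 1 for fp8, 2 for bf16

    def ab_bytes_per_stage(self) -> int:
        # One pipeline stage holds an A tile (M x K) and a B tile (N x K).
        return (self.block_m + self.block_n) * self.block_k * self.elem_bytes


def sketch_max_ab_stages(tile: Tile,
                         budget: int = SMEM_BUDGET_BYTES,
                         reserved: int = RESERVED_BYTES) -> int:
    """Largest stage count whose A+B buffers fit in the remaining budget."""
    usable = budget - reserved
    return max(usable // tile.ab_bytes_per_stage(), 0)


def sketch_validate_ab_stages(tile: Tile, ab_stages: int) -> bool:
    """Accept a stage count if it is within the dtype cap and fits in SMEM."""
    cap = FP8_STAGE_CAP if tile.elem_bytes == 1 else BF16_STAGE_CAP
    return 1 <= ab_stages <= min(cap, sketch_max_ab_stages(tile))


fp8_tile = Tile(block_m=128, block_n=256, block_k=64, elem_bytes=1)
bf16_tile = Tile(block_m=128, block_n=256, block_k=64, elem_bytes=2)
print(sketch_max_ab_stages(fp8_tile))   # 8
print(sketch_max_ab_stages(bf16_tile))  # 4
```

Under these assumed numbers, halving the operand width doubles the number of stages that fit, which is the effect the description relies on. The real budget and tile shapes may differ.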
Benchmark (B200, CUDA 13.2, fp8 e4m3 scaled_mm, m=k=n=4096, col-major B,
CtaGroup.TWO cluster_m=2 role-local, do_bench, 10s warmup):
  ab_stages= 3 (prev cap) : 1143 TFLOP/s   62% of aten
  ab_stages= 6            : 1494 TFLOP/s   80%
  ab_stages= 8            : 1613 TFLOP/s   86%  <- sweet spot
  ab_stages=10/12         : ~1600 TFLOP/s  (SMEM-bound plateau)
  torch._scaled_mm        : 1867 TFLOP/s
So K-major B (prior commit) + deep AB staging together take fp8 scaled_mm
from ~31% to ~86% of torch._scaled_mm on the 4096^3 compute-bound shape.
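
As a quick arithmetic check on the table, the sketch below recomputes each row's fraction of the torch._scaled_mm figure and the kernel runtime each throughput implies, assuming the usual 2*m*n*k FLOP count for a GEMM. The helper names are hypothetical. The 6- and 8-stage rows reproduce 80% and 86%; the 3-stage row comes out nearer 61% than the quoted 62%, which may reflect rounding or a different baseline run.

```python
# Throughput figures copied from the benchmark table (TFLOP/s).
ATEN_TFLOPS = 1867.0
MEASURED = {3: 1143.0, 6: 1494.0, 8: 1613.0}

# Assumed FLOP count for m = n = k = 4096: one multiply-add counts as 2 FLOPs.
FLOPS = 2 * 4096**3


def fraction_of_aten(tflops: float) -> float:
    """Throughput relative to the torch._scaled_mm baseline."""
    return tflops / ATEN_TFLOPS


def implied_runtime_us(tflops: float) -> float:
    """Kernel time in microseconds implied by a TFLOP/s figure."""
    return FLOPS / (tflops * 1e12) * 1e6


for stages, tflops in MEASURED.items():
    print(f"ab_stages={stages}: "
          f"{100 * fraction_of_aten(tflops):.1f}% of aten, "
          f"{implied_runtime_us(tflops):.1f} us, "
          f"{tflops / MEASURED[3]:.2f}x vs ab_stages=3")
```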
Correctness rel_err 0.0000 throughout; full cute suite: 93 passed.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>