
[cuda backend] int4/8 matvec: vectorized activation load #20144

Merged
Gasoonjia merged 6 commits into main from g4-opt-int4-vecload
Jun 12, 2026

Conversation


@Gasoonjia Gasoonjia commented Jun 9, 2026

Contributor

The decode-only int4_plain_mm matvec was bound by activation load-instruction throughput, not DRAM bandwidth (already ~64% peak) or latency. Each inner iteration issued ~15 loads per 16-byte weight chunk: 8 scalar int32 activation loads + the same per-block scale d reloaded 4x. The same applies to int8_plain_mm.

Align Q8Block to 16 bytes (sizeof 36->48) so each block's qs_even/qs_odd 16B halves are 16B-aligned, then load a whole activation block with two vectorized uint4 loads + one d load (~4x fewer activation loads). dp4a math and accumulation order are bit-identical; the int8 activation values and scale are unchanged.
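
The alignment argument above can be checked with a little arithmetic. This is only a sketch: the field order (qs_even at offset 0, qs_odd at offset 16, the scale after them) and a 16B-aligned array base are assumptions made for illustration, not something this PR text confirms, and all names below are hypothetical.

```python
# Layout arithmetic for a Q8Block-like struct (assumed field order, see above).
QS_BYTES = 32      # 32 int8 activations = two 16-byte halves (qs_even, qs_odd)
OTHER_BYTES = 4    # remaining bytes of the 36-byte block (the scale d)
VEC_BYTES = 16     # width of one uint4 load


def align_up(n, alignment):
    """Round n up to the next multiple of alignment."""
    return (n + alignment - 1) // alignment * alignment


packed_size = QS_BYTES + OTHER_BYTES            # unaligned sizeof
aligned_size = align_up(packed_size, VEC_BYTES)  # sizeof after 16-byte alignment


def halves_aligned(stride, n_blocks=8):
    """True if qs_even and qs_odd of every block start on a 16-byte boundary."""
    for i in range(n_blocks):
        base = i * stride  # assumes the array itself starts 16B-aligned
        if base % VEC_BYTES or (base + 16) % VEC_BYTES:
            return False
    return True


# Activation-side loads per block, using the counts from the description.
scalar_loads = 8 + 4   # 8 int32 activation loads + scale d reloaded 4x
vector_loads = 2 + 1   # two uint4 loads + one d load

print(packed_size, aligned_size)
print(halves_aligned(packed_size), halves_aligned(aligned_size))
print(scalar_loads / vector_loads)
```

With a 36-byte stride the second block starts at byte 36, which is not a multiple of 16, so a 16-byte vector load on its halves would be misaligned; padding the stride to 48 removes that problem for every block.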

gemma4_31b decode (long-ctx harness, stacked on optimize_1):
decode 43.98 -> 46.79 tok/s (+6.4%), which is +12.7% compared with llama.cpp (41.5 tok/s)

profile result: int4 matvec avg 38.4 -> 34.75 us (-9.5%); quant kernel unchanged.

Gasoonjia added 4 commits June 8, 2026 22:15
…decode

Coalesce int4 W4A8 decode-matvec scale/zero loads by baking the
[N, n_groups] layout into the weight constant at pack time. Introduces
CudaCoalescedInt4Tensor (an ExecuTorch-internal subclass) that owns the
[n_groups, N] -> [N, n_groups] transpose, registers the int4_plain_mm
dispatch on it by type, and adds the coalesced dp4a matvec kernel that
reads scale/zero row-for-row with qdata (single coalesced load vs 32
stride-N cache lines). ~29.2 -> 37.4 tok/s on gemma group_size=32.
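
As a rough illustration of why the [n_groups, N] -> [N, n_groups] transpose helps, the sketch below compares the flat offsets one output row touches in each layout. The helper names and the tiny dimensions are made up for the example; the real CudaCoalescedInt4Tensor code is not shown in this PR text.

```python
def offset_groups_major(n, g, N):
    """Flat offset of scale (g, n) in the original [n_groups, N] layout."""
    return g * N + n


def offset_row_major(n, g, n_groups):
    """Flat offset of scale (n, g) in the transposed [N, n_groups] layout."""
    return n * n_groups + g


def transpose_scales(flat, n_groups, N):
    """Repack a flat [n_groups, N] buffer into a flat [N, n_groups] buffer."""
    return [flat[g * N + n] for n in range(N) for g in range(n_groups)]


N, n_groups = 4, 3                    # toy sizes, not the model's
flat = list(range(N * n_groups))      # stand-in scale values
packed = transpose_scales(flat, n_groups, N)

row = 1
before = [offset_groups_major(row, g, N) for g in range(n_groups)]
after = [offset_row_major(row, g, n_groups) for g in range(n_groups)]
print(before)  # offsets for one row are N elements apart
print(after)   # offsets for one row are contiguous
```

In the original layout, consecutive groups of one output row are N elements apart, so a thread walking a row touches a different cache line for each group. After the pack-time transpose, the same values sit next to each other and follow the row order of qdata.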

Rebased onto main; INT8 dp4a decode op and the floor_div pass from this
branch landed separately and now live in quantize_op_dispatch/.
…ied) + benchmark rework

Summary:
At decode (L_q==1) the standard pack-GQA SDPA kernel's grid collapses to
CTA = batch * n_kv_heads, which under-occupies the SMs; split-K flash-decoding
partitions the KV sequence across many more CTAs to fill the GPU. In
ReplaceEdgeOpWithTritonOpPass._pick_sdpa_kernel, route decode to split-K when
L_q==1 and L_kv >= 256 (power-of-2 head dim required; prefill and non-pow2 head
dims keep the standard kernel).
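
The routing rule described above can be sketched as a small predicate. This is not the body of ReplaceEdgeOpWithTritonOpPass._pick_sdpa_kernel; the function name, argument names and return strings are hypothetical, and only the conditions (L_q == 1, L_kv >= 256, power-of-2 head dim) come from the summary.

```python
SPLITK_MIN_L_KV = 256  # crossover stated in the summary


def is_power_of_two(x):
    """True for 1, 2, 4, 8, ...; False for zero, negatives and other values."""
    return x > 0 and (x & (x - 1)) == 0


def pick_sdpa_kernel(l_q, l_kv, head_dim):
    """Return "split_k" for long-KV decode, otherwise "standard"."""
    is_decode = l_q == 1
    if is_decode and l_kv >= SPLITK_MIN_L_KV and is_power_of_two(head_dim):
        return "split_k"
    # Prefill, short KV, and non-power-of-2 head dims keep the standard kernel.
    return "standard"


for l_q, l_kv, d in [(1, 128, 256), (1, 256, 256), (1, 4096, 256), (1, 4096, 96)]:
    print(l_q, l_kv, d, pick_sdpa_kernel(l_q, l_kv, d))
```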

The 256 crossover was measured under CUDA-graph timing (capture+replay, faithful
to the deployed --cuda_graph runtime). The earlier 2048 boundary was overfit to
a plain (non-cuda-graph) microbenchmark, which charged split-K a ~140us per-call
partial-buffer alloc + extra-launch overhead that the graph runtime eliminates;
under faithful timing split-K wins ~1.2-20x from L_kv ~= 256 upward.

benchmark_sdpa.py reworked: deleted run_sweep and all CSV/sentinel machinery;
run_benchmark now compares all six backends (ET-standard, ET-split-K, PyTorch,
Flash, Efficient, Math) with the PyTorch correctness check, across several
decode configs (gemma D256/CTA16, qwen D256/CTA2, D128/CTA16) over the L_kv
range, with a cuda-graph on/off toggle (--mode {cudagraph,plain,both}) timing
every backend through a small self-contained cuda-graph primitive; terminal-only
output. Each reported cell is the mean+/-std over the last 6 of 10 runs (first 4
discarded as warmup; N_RUNS=10, N_WARMUP=4).
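
For reference, the per-cell statistic described above (drop the first 4 of 10 runs, report mean+/-std of the remaining 6) might look like the following. The function name is hypothetical, the sample timings are synthetic, and the use of the population standard deviation is an assumption, since the text does not say which estimator benchmark_sdpa.py uses.

```python
import statistics

N_RUNS = 10
N_WARMUP = 4


def summarize(samples_us):
    """Mean and population std over the runs left after discarding warmup."""
    if len(samples_us) != N_RUNS:
        raise ValueError(f"expected {N_RUNS} samples, got {len(samples_us)}")
    kept = samples_us[N_WARMUP:]  # last 6 of 10
    return statistics.fmean(kept), statistics.pstdev(kept)


# Synthetic timings in microseconds: four slow warmup runs, then six steady ones.
samples = [200.0, 150.0, 120.0, 110.0, 100.0, 102.0, 98.0, 100.0, 101.0, 99.0]
mean_us, std_us = summarize(samples)
print(f"{mean_us:.1f}+/-{std_us:.1f} us")
```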

Test Plan:
Exercised against the repo (PYTHONPATH) since the conda env's installed
executorch is stale; a lib reinstall is required for the routing to take effect
in a real export.

backends/cuda/tests/test_sdpa_splitk_replacement.py
  - L_kv=128 -> standard; L_kv=256 -> split-K; L_kv=4096 -> split-K;
    non-pow2 D=96 -> standard.
backends/cuda/tests/test_triton_sdpa_splitk.py (14) and
backends/cuda/tests/test_triton_sdpa_nan.py (3) pass. 21 tests total.

gemma4_31b long-context decode (2401-tok prompt, 256 new tokens, temp 0,
--cuda_graph, 10 runs middle-6) with split-K routing: decode 37.91 -> 43.98
tok/s (+16.0%); prefill within noise.

python backends/cuda/benchmarks/benchmark_sdpa.py --mode cudagraph (gemma
D256/CTA16, mean+/-std us): L_kv=2048 ET-std 102.4+/-0.0 / ET-split-K 24.6+/-0.2 /
PyTorch 475.1+/-0.3 / Flash 56.5+/-0.0; L_kv=16384 ET-std 785.5+/-0.0 /
ET-split-K 179.8+/-0.1 / PyTorch 3447+/-2.6. Plain-timing mode shows split-K's
per-call overhead (the artifact behind the old 2048).
…ock)

The decode-only int4_plain_mm matvec was bound by activation load-instruction
throughput, not DRAM bandwidth (already ~64% peak) or latency. Each inner
iteration issued ~15 loads per 16-byte weight chunk: 8 scalar int32 activation
loads + the same per-block scale d reloaded 4x.

Align Q8Block to 16 bytes (sizeof 36->48) so each block's qs_even/qs_odd 16B
halves are 16B-aligned, then load a whole activation block with two vectorized
uint4 loads + one d load (~4x fewer activation loads). dp4a math and
accumulation order are bit-identical; the int8 activation values and scale are
unchanged.
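
The bit-identical claim follows from the fact that a wider load changes how the bytes arrive, not which bytes feed each dp4a. The host-side emulation below is a sketch of that reasoning only: it is not the CUDA kernel, it uses int8 weights directly instead of unpacking int4, and every function name is hypothetical.

```python
import struct


def dp4a(a_word, b_word, acc):
    """Emulate dp4a: dot product of four signed int8 lanes, added to acc."""
    a = struct.unpack("<4b", struct.pack("<I", a_word))
    b = struct.unpack("<4b", struct.pack("<I", b_word))
    return acc + sum(x * y for x, y in zip(a, b))


def load_scalar(buf, n_words):
    """n_words separate 4-byte loads, like scalar int32 loads."""
    return [struct.unpack_from("<I", buf, 4 * i)[0] for i in range(n_words)]


def load_vectorized(buf, n_vecs):
    """n_vecs 16-byte loads, each yielding four 32-bit lanes like a uint4."""
    words = []
    for v in range(n_vecs):
        words.extend(struct.unpack_from("<4I", buf, 16 * v))
    return words


def block_dot(act_words, weight_words):
    """Accumulate dp4a over the words in a fixed order."""
    acc = 0
    for a, w in zip(act_words, weight_words):
        acc = dp4a(a, w, acc)
    return acc


act_vals = list(range(-16, 16))             # 32 stand-in int8 activations
wgt_vals = [3, -2] * 16                     # 32 stand-in int8 weights
acts = struct.pack("<32b", *act_vals)
wgts = struct.pack("<32b", *wgt_vals)
w_words = load_scalar(wgts, 8)

scalar_result = block_dot(load_scalar(acts, 8), w_words)
vector_result = block_dot(load_vectorized(acts, 2), w_words)
print(scalar_result == vector_result)
```

Both load paths produce the same eight 32-bit words in the same order, so the sequence of dp4a calls and the accumulator values are identical.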

gemma4_31b decode (long-ctx harness, stacked on optimize_1):
  decode  43.98 -> 46.79 tok/s (+6.4%)
  prefill 1193  -> 1186     (noise; int4_plain_mm is decode-only)
nsys: int4 matvec avg 38.4 -> 34.75 us (-9.5%); quant kernel unchanged.
Unit tests test_aoti_torch_cuda_int4_plain_mm: 6/6 pass (M=1/8, gs=16/32/128).

pytorch-bot Bot commented Jun 9, 2026


🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20144

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Pending, 2 Unrelated Failures

As of commit dc5de74 with merge base 635a884:

NEW FAILURE - The following job has failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed label Jun 9, 2026
@Gasoonjia Gasoonjia changed the title from "[cuda backend] int4 W4A8 matvec: vectorized activation load" to "[cuda backend] int4/8 matvec: vectorized activation load" Jun 9, 2026
@Gasoonjia Gasoonjia marked this pull request as ready for review June 9, 2026 17:08
Base automatically changed from g4-opt-sliding-splitk to main June 12, 2026 04:51
@github-actions


This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@Gasoonjia Gasoonjia merged commit 630ddba into main Jun 12, 2026
335 of 346 checks passed
@Gasoonjia Gasoonjia deleted the g4-opt-int4-vecload branch June 12, 2026 07:37

Labels

ciflow/cuda, CLA Signed


2 participants