Commit d176f8e
committed
sm12x: per-token early-loop-exit on sparse MLA accumulate inner candidate loop
Redesigned suggestion #3 from PR vllm-project#41834 comment 4450901180. The first
attempt (e34daef, reverted; later 72a5ff2, also reverted) tried to
truncate ``topk_indices.shape[1]`` in Python so the captured launches
iterated a narrower combined slice; that approach broke under cudagraph
replay (shape baked at capture) and *also* mis-bounded — the combine
kernel writes each token's combined buffer as ``[topk_len_t |
swa_len_t | -1 padding]`` with SWA *immediately* following the
per-token topk, so a fixed ``effective_topk`` cap cuts off the SWA
portion (GSM8K dropped 25 pp on the prior attempt).
The kernel already loads the per-token combined length
(``valid_len = tl.load(lens_ptr + token_idx)`` for the four ``lens``-
gated kernels, ``gather_len`` for the two paged kernels). The existing
``is_valid`` guard only short-circuits the *heavy* work past that
length; the outer ``for candidate_idx in range(0, num_candidates)``
still pays one ``tl.load`` + branch per iter on the dead tail.
Capping the loop at
``min(num_candidates, valid_len - candidate_offset)`` (clamped to 0)
removes those wasted iterations while preserving the existing
``is_valid`` semantics: the iterations we now skip are exactly those
the existing guard already discarded.
Applied to six accumulate kernels in ``sparse_mla_kernels.py``:
- ``_accumulate_gathered_attention_chunk_kernel``
- ``_accumulate_indexed_attention_chunk_kernel`` [autotuned in #1]
- ``_accumulate_fp8ds_global_slots_attention_chunk_kernel`` [autotuned in #1]
- ``_accumulate_fp8ds_global_slots_attention_chunk_multihead_kernel`` [decode]
- ``_accumulate_fp8ds_paged_attention_chunk_kernel`` [autotuned in #1]
- ``_accumulate_fp8ds_paged_attention_chunk_multihead_kernel`` [decode]
CUDA-graph safety: ``lens_ptr`` / ``gather_lens_ptr`` are stable
addresses; their values are refreshed per call by the metadata builder
(outside the captured forward) and by ``combine_topk_swa_indices``
(inside the forward but writing only into the persistent buffers the
accumulate kernels read from). The kernel inner-loop bound is a
runtime-loaded scalar — Triton compiles a dynamic loop and the captured
launch picks up the current value on each replay.
Savings scale with ``combined_topk_buffer_width - actual valid length``
(i.e. mostly visible at long ``max_model_len`` with shorter actual
contexts). At our test shape (``max_model_len=131072``, ISL=2048) the
saved iterations come mostly from the decode multihead path; expected
to be neutral / no-regression at short ``max_model_len`` where the
bound equals ``num_candidates``.
Signed-off-by: jasl <jasl9187@hotmail.com>1 parent 84dd877 commit d176f8e
1 file changed
Lines changed: 62 additions & 7 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1128 | 1128 | | |
1129 | 1129 | | |
1130 | 1130 | | |
| 1131 | + | |
| 1132 | + | |
| 1133 | + | |
| 1134 | + | |
| 1135 | + | |
1131 | 1136 | | |
1132 | | - | |
| 1137 | + | |
1133 | 1138 | | |
1134 | 1139 | | |
1135 | 1140 | | |
| |||
1289 | 1294 | | |
1290 | 1295 | | |
1291 | 1296 | | |
| 1297 | + | |
| 1298 | + | |
| 1299 | + | |
| 1300 | + | |
| 1301 | + | |
| 1302 | + | |
| 1303 | + | |
| 1304 | + | |
| 1305 | + | |
| 1306 | + | |
| 1307 | + | |
| 1308 | + | |
| 1309 | + | |
1292 | 1310 | | |
1293 | | - | |
| 1311 | + | |
1294 | 1312 | | |
1295 | 1313 | | |
1296 | 1314 | | |
| |||
1445 | 1463 | | |
1446 | 1464 | | |
1447 | 1465 | | |
| 1466 | + | |
| 1467 | + | |
| 1468 | + | |
| 1469 | + | |
| 1470 | + | |
1448 | 1471 | | |
1449 | 1472 | | |
1450 | 1473 | | |
1451 | 1474 | | |
1452 | 1475 | | |
1453 | | - | |
| 1476 | + | |
1454 | 1477 | | |
1455 | 1478 | | |
1456 | 1479 | | |
| |||
1645 | 1668 | | |
1646 | 1669 | | |
1647 | 1670 | | |
| 1671 | + | |
| 1672 | + | |
| 1673 | + | |
| 1674 | + | |
| 1675 | + | |
| 1676 | + | |
| 1677 | + | |
| 1678 | + | |
| 1679 | + | |
1648 | 1680 | | |
1649 | 1681 | | |
1650 | 1682 | | |
1651 | 1683 | | |
1652 | 1684 | | |
1653 | | - | |
| 1685 | + | |
1654 | 1686 | | |
1655 | 1687 | | |
1656 | 1688 | | |
| |||
1851 | 1883 | | |
1852 | 1884 | | |
1853 | 1885 | | |
| 1886 | + | |
| 1887 | + | |
| 1888 | + | |
| 1889 | + | |
| 1890 | + | |
1854 | 1891 | | |
1855 | | - | |
| 1892 | + | |
1856 | 1893 | | |
1857 | 1894 | | |
1858 | 1895 | | |
| |||
2054 | 2091 | | |
2055 | 2092 | | |
2056 | 2093 | | |
| 2094 | + | |
| 2095 | + | |
| 2096 | + | |
| 2097 | + | |
| 2098 | + | |
| 2099 | + | |
| 2100 | + | |
| 2101 | + | |
| 2102 | + | |
2057 | 2103 | | |
2058 | | - | |
| 2104 | + | |
2059 | 2105 | | |
2060 | 2106 | | |
2061 | 2107 | | |
| |||
2247 | 2293 | | |
2248 | 2294 | | |
2249 | 2295 | | |
| 2296 | + | |
| 2297 | + | |
| 2298 | + | |
| 2299 | + | |
| 2300 | + | |
| 2301 | + | |
| 2302 | + | |
| 2303 | + | |
| 2304 | + | |
2250 | 2305 | | |
2251 | | - | |
| 2306 | + | |
2252 | 2307 | | |
2253 | 2308 | | |
2254 | 2309 | | |
| |||
0 commit comments