Skip to content

Commit ad68cb3

Browse files
committed
sm12x: per-token early-loop-exit on sparse MLA accumulate inner candidate loop
Redesigned suggestion #3 from PR vllm-project#41834 comment 4450901180. The first attempt (e34daef, reverted; later 72a5ff2, also reverted) tried to truncate ``topk_indices.shape[1]`` in Python so the captured launches iterated a narrower combined slice; that approach broke under cudagraph replay (shape baked at capture) and *also* mis-bounded — the combine kernel writes each token's combined buffer as ``[topk_len_t | swa_len_t | -1 padding]`` with SWA *immediately* following the per-token topk, so a fixed ``effective_topk`` cap cuts off the SWA portion (GSM8K dropped 25 pp on the prior attempt). The kernel already loads the per-token combined length (``valid_len = tl.load(lens_ptr + token_idx)`` for the four ``lens``- gated kernels, ``gather_len`` for the two paged kernels). The existing ``is_valid`` guard only short-circuits the *heavy* work past that length; the outer ``for candidate_idx in range(0, num_candidates)`` still pays one ``tl.load`` + branch per iter on the dead tail. Capping the loop at ``min(num_candidates, valid_len - candidate_offset)`` (clamped to 0) removes those wasted iterations while preserving the existing ``is_valid`` semantics: the iterations we now skip are exactly those the existing guard already discarded. Applied to six accumulate kernels in ``sparse_mla_kernels.py``: - ``_accumulate_gathered_attention_chunk_kernel`` - ``_accumulate_indexed_attention_chunk_kernel`` [autotuned in #1] - ``_accumulate_fp8ds_global_slots_attention_chunk_kernel`` [autotuned in #1] - ``_accumulate_fp8ds_global_slots_attention_chunk_multihead_kernel`` [decode] - ``_accumulate_fp8ds_paged_attention_chunk_kernel`` [autotuned in #1] - ``_accumulate_fp8ds_paged_attention_chunk_multihead_kernel`` [decode] CUDA-graph safety: ``lens_ptr`` / ``gather_lens_ptr`` are stable addresses; their values are refreshed per call by the metadata builder (outside the captured forward) and by ``combine_topk_swa_indices`` (inside the forward but writing only into the persistent buffers the accumulate kernels read from). The kernel inner-loop bound is a runtime-loaded scalar — Triton compiles a dynamic loop and the captured launch picks up the current value on each replay. Savings scale with ``combined_topk_buffer_width - actual valid length`` (i.e. mostly visible at long ``max_model_len`` with shorter actual contexts). At our test shape (``max_model_len=131072``, ISL=2048) the saved iterations come mostly from the decode multihead path; expected to be neutral / no-regression at short ``max_model_len`` where the bound equals ``num_candidates``. Signed-off-by: jasl <jasl9187@hotmail.com>
1 parent c2cdacd commit ad68cb3

1 file changed

Lines changed: 62 additions & 7 deletions

File tree

vllm/v1/attention/backends/mla/sparse_mla_kernels.py

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,8 +1128,13 @@ def _accumulate_gathered_attention_chunk_kernel(
11281128
running_denom = tl.load(denom_ptr + state_offset)
11291129
running_acc = tl.load(acc_ptr + acc_offset, mask=dim_mask, other=0.0).to(tl.float32)
11301130
valid_len = tl.load(lens_ptr + token_idx)
1131+
# Per-token early-loop-exit (see indexed kernel comment).
1132+
local_eff = tl.minimum(
1133+
num_candidates,
1134+
tl.maximum(valid_len - candidate_offset, 0),
1135+
)
11311136

1132-
for candidate_idx in range(0, num_candidates):
1137+
for candidate_idx in range(0, local_eff):
11331138
is_valid = (candidate_offset + candidate_idx) < valid_len
11341139
if HAS_SLOT_IDS:
11351140
slot_id = tl.load(
@@ -1289,8 +1294,21 @@ def _accumulate_indexed_attention_chunk_kernel(
12891294
running_denom = tl.load(denom_ptr + state_offset)
12901295
running_acc = tl.load(acc_ptr + acc_offset, mask=dim_mask, other=0.0).to(tl.float32)
12911296
valid_len = tl.load(lens_ptr + token_idx)
1297+
# Per-token early-loop-exit: the combine_topk_swa_indices kernel writes
1298+
# ``[topk_len_t | swa_len_t | -1 padding]`` and stores
1299+
# ``lens[t] = topk_len_t + swa_len_t``. The existing ``is_valid`` guard
1300+
# already gates the heavy work past ``valid_len``, but the outer loop
1301+
# still iterates the full ``num_candidates`` (= chunk width). Capping
1302+
# the loop at ``min(num_candidates, valid_len - candidate_offset)``
1303+
# saves the per-iteration index load + compare overhead on the dead
1304+
# tail. CUDA-graph-safe because ``lens_ptr`` is a stable address and
1305+
# the loaded value updates per call from the metadata builder.
1306+
local_eff = tl.minimum(
1307+
num_candidates,
1308+
tl.maximum(valid_len - candidate_offset, 0),
1309+
)
12921310

1293-
for candidate_idx in range(0, num_candidates):
1311+
for candidate_idx in range(0, local_eff):
12941312
kv_index = tl.load(
12951313
indices_ptr
12961314
+ token_idx * stride_indices_t
@@ -1445,12 +1463,17 @@ def _accumulate_fp8ds_global_slots_attention_chunk_kernel(
14451463
running_denom = tl.load(denom_ptr + state_offset)
14461464
running_acc = tl.load(acc_ptr + acc_offset, mask=dim_mask, other=0.0).to(tl.float32)
14471465
valid_len = tl.load(lens_ptr + token_idx)
1466+
# Per-token early-loop-exit (see indexed kernel comment).
1467+
local_eff = tl.minimum(
1468+
num_candidates,
1469+
tl.maximum(valid_len - candidate_offset, 0),
1470+
)
14481471

14491472
fp8_mask = offsets < fp8_dim
14501473
rope_mask = (offsets >= fp8_dim) & dim_mask
14511474
rope_offsets = tl.maximum(offsets - fp8_dim, 0)
14521475

1453-
for candidate_idx in range(0, num_candidates):
1476+
for candidate_idx in range(0, local_eff):
14541477
slot_id = tl.load(
14551478
slot_ids_ptr + token_idx * stride_slot_t + candidate_idx * stride_slot_c
14561479
)
@@ -1645,12 +1668,21 @@ def _accumulate_fp8ds_global_slots_attention_chunk_multihead_kernel(
16451668
tl.float32
16461669
)
16471670
valid_len = tl.load(lens_ptr + token_idx)
1671+
# Per-token early-loop-exit: ``lens[t] = topk_len_t + swa_len_t`` (set
1672+
# by combine_topk_swa_indices). Iterating past ``valid_len`` only
1673+
# incurs the per-iter index-load + compare cost on padding-tail; cap
1674+
# the outer loop at ``valid_len - candidate_offset`` to skip the dead
1675+
# tail. CUDA-graph-safe because ``lens_ptr`` is a stable address.
1676+
local_eff = tl.minimum(
1677+
num_candidates,
1678+
tl.maximum(valid_len - candidate_offset, 0),
1679+
)
16481680

16491681
fp8_mask = dim_offsets < fp8_dim
16501682
rope_mask = (dim_offsets >= fp8_dim) & dim_mask
16511683
rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0)
16521684

1653-
for candidate_idx in range(0, num_candidates):
1685+
for candidate_idx in range(0, local_eff):
16541686
slot_id = tl.load(
16551687
slot_ids_ptr + token_idx * stride_slot_t + candidate_idx * stride_slot_c
16561688
)
@@ -1851,8 +1883,13 @@ def _accumulate_fp8ds_paged_attention_chunk_kernel(
18511883
fp8_mask = offsets < fp8_dim
18521884
rope_mask = (offsets >= fp8_dim) & dim_mask
18531885
rope_offsets = tl.maximum(offsets - fp8_dim, 0)
1886+
# Per-token early-loop-exit (see indexed kernel comment).
1887+
local_eff = tl.minimum(
1888+
num_candidates,
1889+
tl.maximum(gather_len - candidate_offset, 0),
1890+
)
18541891

1855-
for candidate_idx in range(0, num_candidates):
1892+
for candidate_idx in range(0, local_eff):
18561893
gather_idx = candidate_offset + candidate_idx
18571894
is_valid = gather_idx < gather_len
18581895

@@ -2054,8 +2091,17 @@ def _accumulate_fp8ds_paged_attention_chunk_multihead_kernel(
20542091
fp8_mask = dim_offsets < fp8_dim
20552092
rope_mask = (dim_offsets >= fp8_dim) & dim_mask
20562093
rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0)
2094+
# Per-token early-loop-exit: ``gather_len`` is the per-token count of
2095+
# cached entries available for this paged read; the existing
2096+
# ``is_valid`` guard skips heavy work past that, but we can also skip
2097+
# the per-iter index load + branch by capping the loop. CUDA-graph-
2098+
# safe because ``gather_lens_ptr`` is a stable address.
2099+
local_eff = tl.minimum(
2100+
num_candidates,
2101+
tl.maximum(gather_len - candidate_offset, 0),
2102+
)
20572103

2058-
for candidate_idx in range(0, num_candidates):
2104+
for candidate_idx in range(0, local_eff):
20592105
gather_idx = candidate_offset + candidate_idx
20602106
is_valid = gather_idx < gather_len
20612107

@@ -2247,8 +2293,17 @@ def _fp8ds_paged_attention_with_sink_multihead_kernel(
22472293
fp8_mask = dim_offsets < fp8_dim
22482294
rope_mask = (dim_offsets >= fp8_dim) & dim_mask
22492295
rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0)
2296+
# Per-token early-loop-exit: ``gather_len`` is the per-token count of
2297+
# cached entries available for this paged read; the existing
2298+
# ``is_valid`` guard skips heavy work past that, but we can also skip
2299+
# the per-iter index load + branch by capping the loop. CUDA-graph-
2300+
# safe because ``gather_lens_ptr`` is a stable address.
2301+
local_eff = tl.minimum(
2302+
num_candidates,
2303+
tl.maximum(gather_len - candidate_offset, 0),
2304+
)
22502305

2251-
for candidate_idx in range(0, num_candidates):
2306+
for candidate_idx in range(0, local_eff):
22522307
gather_idx = candidate_offset + candidate_idx
22532308
is_valid = gather_idx < gather_len
22542309
if is_valid:

0 commit comments

Comments
 (0)