Skip to content

Commit bfb68ed

Browse files
aabbccddwasdjasl
authored andcommitted
sm12x: overlap C128A prefill KV gather with indexer on aux stream
Run dequantize_and_gather_k_cache for the compressed + SWA caches on aux_stream[1] while the indexer forward runs on aux_stream[0], so the gather is hidden behind the indexer instead of serialising before _forward_prefill. The workspace allocation at the wrapper places kv_workspace at offset 0 of the same per-ubatch workspace buffer that _forward_prefill would otherwise allocate kv from; _reserve_prefill_workspace during warmup already grew the buffer to fit the full prefill spec list, so the kv-only request cannot trigger a resize that orphans kv_workspace mid-forward. A gather_done_event joins the aux stream back before mla_attn runs. CUDA-graph-safe: both aux streams join (event.wait()) before the attention boundary, and the gather is gated on num_prefills > 0 and num_prefills <= PREFILL_CHUNK_SIZE (single-chunk only). Multi-chunk prefill or non-C128A paths fall through to the existing per-chunk gather in _forward_prefill. Original implementation by aabbccddwasd in their dsv4-sm120-opt-v2 branch (commit 6ff395e). This re-applies only the gather-overlap half of that commit; the multi-head prefill kernel half is dropped because the canonical tip already has alex's HEAD_BLOCK=8 version (671958e / vllm-project#41834 PR #6) which was empirically tuned for this hardware. Signed-off-by: jasl <jasl9187@hotmail.com>
1 parent b2c21c7 commit bfb68ed

1 file changed

Lines changed: 164 additions & 26 deletions

File tree

vllm/models/deepseek_v4/attention.py

Lines changed: 164 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,8 @@ def __init__(
312312
# [1] doubles as post-GEMM event1. Reuse is safe: GEMM fully joins
313313
# before post-GEMM starts.
314314
self.ln_events = [torch.cuda.Event() for _ in range(4)]
315+
self.kv_done_event = torch.cuda.Event()
316+
self.gather_done_event = torch.cuda.Event()
315317

316318
assert cache_config is not None, "DeepseekV4 attention requires cache_config"
317319
self.swa_cache_layer = DeepseekV4SWACache(
@@ -521,7 +523,130 @@ def attention_impl(
521523
# on the default stream so q stays on its consumer stream (mla_attn
522524
# downstream reads q on default). Indexer/compressor go on aux for
523525
# overlap with default's GEMM + cache write.
524-
if self.indexer is not None:
526+
#
527+
# C128A prefill: launch the KV gather on aux_stream[1] so it overlaps
528+
# with the indexer forward (on aux_stream[0]). The pre-gathered
529+
# workspace is passed to mla_attn so _forward_prefill skips its own
530+
# gather phase.
531+
kv_workspace_for_prefill: torch.Tensor | None = None
532+
gather_overlap = (
533+
isinstance(attn_metadata, dict)
534+
and self.compress_ratio >= 128
535+
and self.indexer is not None
536+
and self.compressor is not None
537+
and self.aux_stream_list is not None
538+
and len(self.aux_stream_list) >= 2
539+
)
540+
if gather_overlap:
541+
_swa_m = cast(
542+
"DeepseekSparseSWAMetadata | None",
543+
attn_metadata.get(self.swa_cache_layer.prefix),
544+
)
545+
gather_overlap = (
546+
_swa_m is not None
547+
and _swa_m.num_prefills > 0
548+
and _swa_m.num_prefills <= PREFILL_CHUNK_SIZE
549+
)
550+
if gather_overlap:
551+
_swa_m = cast(
552+
"DeepseekSparseSWAMetadata",
553+
attn_metadata[self.swa_cache_layer.prefix],
554+
)
555+
_flashmla_m = cast(
556+
FlashMLASparseMetadata | None,
557+
attn_metadata.get(self.mla_attn.prefix),
558+
)
559+
assert _flashmla_m is not None, (
560+
"C128A prefill requires FlashMLASparseMetadata"
561+
)
562+
563+
_seq_lens_cpu = _swa_m.prefill_seq_lens_cpu
564+
_gather_lens_cpu = _swa_m.prefill_gather_lens_cpu
565+
assert _seq_lens_cpu is not None and _gather_lens_cpu is not None
566+
_n_bound, _m_bound = _sparse_mla_prefill_workspace_bounds(
567+
seq_lens_cpu=_seq_lens_cpu,
568+
gather_lens_cpu=_gather_lens_cpu,
569+
compress_ratio=self.compress_ratio,
570+
swa_only=False,
571+
)
572+
573+
# Workspace aliasing: this allocation places kv_workspace at
574+
# offset 0 of the same per-ubatch workspace buffer that
575+
# _forward_prefill will later allocate kv from. _reserve_prefill_workspace
576+
# during warmup already grew the buffer to fit the full spec list
577+
# (kv + combined_indices + ... + state_buffers), so a kv-only
578+
# request here cannot trigger a resize that would orphan
579+
# kv_workspace mid-forward.
580+
(kv_workspace_for_prefill,) = current_workspace_manager().get_simultaneous(
581+
((PREFILL_CHUNK_SIZE, _m_bound, self.mla_attn.head_dim), torch.bfloat16),
582+
)
583+
584+
_aux0 = self.aux_stream_list[0]
585+
_aux1 = self.aux_stream_list[1]
586+
_indexer = self.indexer
587+
assert self.compressor is not None
588+
_compressor = self.compressor
589+
590+
def _wq_b_kv_insert_and_compress() -> torch.Tensor:
591+
q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim)
592+
self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata)
593+
_compressor(kv_score, positions, self.rotary_emb)
594+
return q
595+
596+
self.ln_events[0].record()
597+
q = _wq_b_kv_insert_and_compress()
598+
self.kv_done_event.record()
599+
600+
with torch.cuda.stream(_aux0):
601+
self.ln_events[0].wait()
602+
_indexer(
603+
hidden_states,
604+
qr,
605+
indexer_kv_score,
606+
indexer_weights,
607+
positions,
608+
self.indexer_rotary_emb,
609+
)
610+
self.ln_events[1].record()
611+
612+
_num_p = _swa_m.num_prefills
613+
_num_d = _swa_m.num_decodes
614+
_comp_k_cache = self.mla_attn.kv_cache
615+
_swa_k_cache = self.swa_cache_layer.kv_cache
616+
_seq_lens_dev = _swa_m.prefill_seq_lens
617+
_gather_lens_dev = _swa_m.prefill_gather_lens
618+
assert _seq_lens_dev is not None and _gather_lens_dev is not None
619+
_block_table = _flashmla_m.block_table[_num_d:]
620+
_comp_block_size = _flashmla_m.block_size // self.compress_ratio
621+
_swa_block_table = _swa_m.block_table[_num_d:]
622+
_swa_block_size = _swa_m.block_size
623+
624+
with torch.cuda.stream(_aux1):
625+
self.kv_done_event.wait()
626+
if _comp_k_cache is not None and _comp_k_cache.numel() > 0:
627+
dequantize_and_gather_k_cache(
628+
kv_workspace_for_prefill[:_num_p],
629+
_comp_k_cache,
630+
seq_lens=_seq_lens_dev[:_num_p] // self.compress_ratio,
631+
gather_lens=None,
632+
block_table=_block_table[:_num_p],
633+
block_size=_comp_block_size,
634+
offset=0,
635+
)
636+
dequantize_and_gather_k_cache(
637+
kv_workspace_for_prefill[:_num_p],
638+
_swa_k_cache,
639+
seq_lens=_seq_lens_dev[:_num_p],
640+
gather_lens=_gather_lens_dev[:_num_p],
641+
block_table=_swa_block_table[:_num_p],
642+
block_size=_swa_block_size,
643+
offset=_n_bound,
644+
)
645+
self.gather_done_event.record()
646+
647+
self.ln_events[1].wait()
648+
self.gather_done_event.wait()
649+
elif self.indexer is not None:
525650
aux_stream = (
526651
self.aux_stream_list[0] if self.aux_stream_list is not None else None
527652
)
@@ -587,7 +712,9 @@ def wq_b_kv_insert() -> torch.Tensor:
587712

588713
# MLA attention writes into the pre-allocated `out` buffer
589714
# ([num_tokens, padded_heads, head_dim]).
590-
self.mla_attn(q, kv, positions, output=out)
715+
self.mla_attn(
716+
q, kv, positions, output=out, kv_workspace=kv_workspace_for_prefill
717+
)
591718

592719
def _fused_qnorm_rope_kv_insert(
593720
self,
@@ -1249,6 +1376,7 @@ def forward(
12491376
kv: torch.Tensor,
12501377
positions: torch.Tensor,
12511378
output: torch.Tensor,
1379+
kv_workspace: torch.Tensor | None = None,
12521380
) -> None:
12531381
assert output.shape == q.shape, (
12541382
f"output buffer shape {output.shape} must match q shape {q.shape}"
@@ -1298,6 +1426,7 @@ def forward(
12981426
output=output[num_decode_tokens:],
12991427
attn_metadata=flashmla_metadata,
13001428
swa_metadata=swa_metadata,
1429+
kv_workspace=kv_workspace,
13011430
)
13021431
if num_decodes > 0:
13031432
self._forward_decode(
@@ -1438,6 +1567,7 @@ def _forward_prefill(
14381567
output: torch.Tensor,
14391568
attn_metadata: FlashMLASparseMetadata | None,
14401569
swa_metadata: "DeepseekSparseSWAMetadata",
1570+
kv_workspace: torch.Tensor | None = None,
14411571
) -> None:
14421572
swa_only = attn_metadata is None
14431573

@@ -1536,36 +1666,44 @@ def _forward_prefill(
15361666
((max_query_chunk_tokens,), torch.int32),
15371667
)
15381668
prefill_state_buffers = None
1669+
# When the wrapper's attention_impl has pre-gathered KV into
1670+
# kv_workspace on an aux stream (overlapped with the indexer), use
1671+
# that buffer in place of the per-chunk gather below. The workspace
1672+
# allocation in attention_impl aliases offset 0 of the same per-ubatch
1673+
# workspace buffer as ``kv`` here, but we route through the explicit
1674+
# parameter so the contract stays visible at the call site.
1675+
_kv = kv_workspace if kv_workspace is not None else kv
15391676
for chunk_idx in range(num_chunks):
15401677
chunk_start = chunk_idx * PREFILL_CHUNK_SIZE
15411678
chunk_end = min(chunk_start + PREFILL_CHUNK_SIZE, num_prefills)
15421679
chunk_size = chunk_end - chunk_start
1543-
if not swa_only:
1544-
# Gather compressed KV
1545-
assert attn_metadata is not None
1546-
block_table = attn_metadata.block_table[num_decodes:]
1680+
if kv_workspace is None:
1681+
if not swa_only:
1682+
# Gather compressed KV
1683+
assert attn_metadata is not None
1684+
block_table = attn_metadata.block_table[num_decodes:]
1685+
dequantize_and_gather_k_cache(
1686+
kv[:chunk_size],
1687+
compressed_k_cache,
1688+
seq_lens=seq_lens[chunk_start:chunk_end] // self.compress_ratio,
1689+
gather_lens=None,
1690+
block_table=block_table[chunk_start:chunk_end],
1691+
block_size=attn_metadata.block_size // self.compress_ratio,
1692+
offset=0,
1693+
)
1694+
1695+
# Gather SWA KV
1696+
swa_block_table = swa_metadata.block_table[num_decodes:]
15471697
dequantize_and_gather_k_cache(
15481698
kv[:chunk_size],
1549-
compressed_k_cache,
1550-
seq_lens=seq_lens[chunk_start:chunk_end] // self.compress_ratio,
1551-
gather_lens=None,
1552-
block_table=block_table[chunk_start:chunk_end],
1553-
block_size=attn_metadata.block_size // self.compress_ratio,
1554-
offset=0,
1699+
swa_k_cache,
1700+
seq_lens=seq_lens[chunk_start:chunk_end],
1701+
gather_lens=gather_lens[chunk_start:chunk_end],
1702+
block_table=swa_block_table[chunk_start:chunk_end],
1703+
block_size=swa_metadata.block_size,
1704+
offset=N,
15551705
)
15561706

1557-
# Gather SWA KV
1558-
swa_block_table = swa_metadata.block_table[num_decodes:]
1559-
dequantize_and_gather_k_cache(
1560-
kv[:chunk_size],
1561-
swa_k_cache,
1562-
seq_lens=seq_lens[chunk_start:chunk_end],
1563-
gather_lens=gather_lens[chunk_start:chunk_end],
1564-
block_table=swa_block_table[chunk_start:chunk_end],
1565-
block_size=swa_metadata.block_size,
1566-
offset=N,
1567-
)
1568-
15691707
# Combine the topk indices and SWA indices for gathered KV cache
15701708
query_start = (
15711709
query_start_loc_cpu[num_decodes + chunk_start] - prefill_token_base
@@ -1594,7 +1732,7 @@ def _forward_prefill(
15941732
if triton_sparse_mla_enabled:
15951733
self._forward_sparse_mla_prefill_triton(
15961734
q=q[query_start:query_end],
1597-
kv=kv[:chunk_size],
1735+
kv=_kv[:chunk_size],
15981736
combined_indices=combined_indices,
15991737
combined_lens=combined_lens,
16001738
output=output[query_start:query_end],
@@ -1604,7 +1742,7 @@ def _forward_prefill(
16041742

16051743
flash_mla_sparse_fwd(
16061744
q=q[query_start:query_end],
1607-
kv=kv.view(-1, 1, q.shape[-1]),
1745+
kv=_kv.view(-1, 1, q.shape[-1]),
16081746
indices=combined_indices.unsqueeze(1),
16091747
sm_scale=self.scale,
16101748
attn_sink=self.attn_sink,

0 commit comments

Comments
 (0)