@@ -312,6 +312,8 @@ def __init__(
312312 # [1] doubles as post-GEMM event1. Reuse is safe: GEMM fully joins
313313 # before post-GEMM starts.
314314 self .ln_events = [torch .cuda .Event () for _ in range (4 )]
315+ self .kv_done_event = torch .cuda .Event ()
316+ self .gather_done_event = torch .cuda .Event ()
315317
316318 assert cache_config is not None , "DeepseekV4 attention requires cache_config"
317319 self .swa_cache_layer = DeepseekV4SWACache (
@@ -521,7 +523,130 @@ def attention_impl(
521523 # on the default stream so q stays on its consumer stream (mla_attn
522524 # downstream reads q on default). Indexer/compressor go on aux for
523525 # overlap with default's GEMM + cache write.
524- if self .indexer is not None :
526+ #
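527+ # C128A prefill, single chunk only (num_prefills <= PREFILL_CHUNK_SIZE):
528+ # wq_b, the KV insert and the compressor stay on the default stream; the
529+ # KV gather is launched on aux_stream[1] after kv_done_event so it overlaps with the indexer forward (on aux_stream[0]).
530+ # The pre-gathered workspace is passed to mla_attn so _forward_prefill skips its own gather phase.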
531+ kv_workspace_for_prefill : torch .Tensor | None = None
532+ gather_overlap = (
533+ isinstance (attn_metadata , dict )
534+ and self .compress_ratio >= 128
535+ and self .indexer is not None
536+ and self .compressor is not None
537+ and self .aux_stream_list is not None
538+ and len (self .aux_stream_list ) >= 2
539+ )
540+ if gather_overlap :
541+ _swa_m = cast (
542+ "DeepseekSparseSWAMetadata | None" ,
543+ attn_metadata .get (self .swa_cache_layer .prefix ),
544+ )
545+ gather_overlap = (
546+ _swa_m is not None
547+ and _swa_m .num_prefills > 0
548+ and _swa_m .num_prefills <= PREFILL_CHUNK_SIZE
549+ )
550+ if gather_overlap :
551+ _swa_m = cast (
552+ "DeepseekSparseSWAMetadata" ,
553+ attn_metadata [self .swa_cache_layer .prefix ],
554+ )
555+ _flashmla_m = cast (
556+ FlashMLASparseMetadata | None ,
557+ attn_metadata .get (self .mla_attn .prefix ),
558+ )
559+ assert _flashmla_m is not None , (
560+ "C128A prefill requires FlashMLASparseMetadata"
561+ )
562+
563+ _seq_lens_cpu = _swa_m .prefill_seq_lens_cpu
564+ _gather_lens_cpu = _swa_m .prefill_gather_lens_cpu
565+ assert _seq_lens_cpu is not None and _gather_lens_cpu is not None
566+ _n_bound , _m_bound = _sparse_mla_prefill_workspace_bounds (
567+ seq_lens_cpu = _seq_lens_cpu ,
568+ gather_lens_cpu = _gather_lens_cpu ,
569+ compress_ratio = self .compress_ratio ,
570+ swa_only = False ,
571+ )
572+
573+ # Workspace aliasing: this allocation places kv_workspace at
574+ # offset 0 of the same per-ubatch workspace buffer that
575+ # _forward_prefill will later allocate kv from. _reserve_prefill_workspace
576+ # during warmup already grew the buffer to fit the full spec list
577+ # (kv + combined_indices + ... + state_buffers), so a kv-only
578+ # request here cannot trigger a resize that would orphan
579+ # kv_workspace mid-forward.
580+ (kv_workspace_for_prefill ,) = current_workspace_manager ().get_simultaneous (
581+ ((PREFILL_CHUNK_SIZE , _m_bound , self .mla_attn .head_dim ), torch .bfloat16 ),
582+ )
583+
584+ _aux0 = self .aux_stream_list [0 ]
585+ _aux1 = self .aux_stream_list [1 ]
586+ _indexer = self .indexer
587+ assert self .compressor is not None
588+ _compressor = self .compressor
589+
590+ def _wq_b_kv_insert_and_compress () -> torch .Tensor :
591+ q = self .wq_b (qr ).view (- 1 , self .n_local_heads , self .head_dim )
592+ self ._fused_qnorm_rope_kv_insert (q , kv , positions , attn_metadata )
593+ _compressor (kv_score , positions , self .rotary_emb )
594+ return q
595+
596+ self .ln_events [0 ].record ()
597+ q = _wq_b_kv_insert_and_compress ()
598+ self .kv_done_event .record ()
599+
600+ with torch .cuda .stream (_aux0 ):
601+ self .ln_events [0 ].wait ()
602+ _indexer (
603+ hidden_states ,
604+ qr ,
605+ indexer_kv_score ,
606+ indexer_weights ,
607+ positions ,
608+ self .indexer_rotary_emb ,
609+ )
610+ self .ln_events [1 ].record ()
611+
612+ _num_p = _swa_m .num_prefills
613+ _num_d = _swa_m .num_decodes
614+ _comp_k_cache = self .mla_attn .kv_cache
615+ _swa_k_cache = self .swa_cache_layer .kv_cache
616+ _seq_lens_dev = _swa_m .prefill_seq_lens
617+ _gather_lens_dev = _swa_m .prefill_gather_lens
618+ assert _seq_lens_dev is not None and _gather_lens_dev is not None
619+ _block_table = _flashmla_m .block_table [_num_d :]
620+ _comp_block_size = _flashmla_m .block_size // self .compress_ratio
621+ _swa_block_table = _swa_m .block_table [_num_d :]
622+ _swa_block_size = _swa_m .block_size
623+
624+ with torch .cuda .stream (_aux1 ):
625+ self .kv_done_event .wait ()
626+ if _comp_k_cache is not None and _comp_k_cache .numel () > 0 :
627+ dequantize_and_gather_k_cache (
628+ kv_workspace_for_prefill [:_num_p ],
629+ _comp_k_cache ,
630+ seq_lens = _seq_lens_dev [:_num_p ] // self .compress_ratio ,
631+ gather_lens = None ,
632+ block_table = _block_table [:_num_p ],
633+ block_size = _comp_block_size ,
634+ offset = 0 ,
635+ )
636+ dequantize_and_gather_k_cache (
637+ kv_workspace_for_prefill [:_num_p ],
638+ _swa_k_cache ,
639+ seq_lens = _seq_lens_dev [:_num_p ],
640+ gather_lens = _gather_lens_dev [:_num_p ],
641+ block_table = _swa_block_table [:_num_p ],
642+ block_size = _swa_block_size ,
643+ offset = _n_bound ,
644+ )
645+ self .gather_done_event .record ()
646+
647+ self .ln_events [1 ].wait ()
648+ self .gather_done_event .wait ()
649+ elif self .indexer is not None :
525650 aux_stream = (
526651 self .aux_stream_list [0 ] if self .aux_stream_list is not None else None
527652 )
@@ -587,7 +712,9 @@ def wq_b_kv_insert() -> torch.Tensor:
587712
588713 # MLA attention writes into the pre-allocated `out` buffer
589714 # ([num_tokens, padded_heads, head_dim]).
590- self .mla_attn (q , kv , positions , output = out )
715+ self .mla_attn (
716+ q , kv , positions , output = out , kv_workspace = kv_workspace_for_prefill
717+ )
591718
592719 def _fused_qnorm_rope_kv_insert (
593720 self ,
@@ -1249,6 +1376,7 @@ def forward(
12491376 kv : torch .Tensor ,
12501377 positions : torch .Tensor ,
12511378 output : torch .Tensor ,
1379+ kv_workspace : torch .Tensor | None = None ,
12521380 ) -> None :
12531381 assert output .shape == q .shape , (
12541382 f"output buffer shape { output .shape } must match q shape { q .shape } "
@@ -1298,6 +1426,7 @@ def forward(
12981426 output = output [num_decode_tokens :],
12991427 attn_metadata = flashmla_metadata ,
13001428 swa_metadata = swa_metadata ,
1429+ kv_workspace = kv_workspace ,
13011430 )
13021431 if num_decodes > 0 :
13031432 self ._forward_decode (
@@ -1438,6 +1567,7 @@ def _forward_prefill(
14381567 output : torch .Tensor ,
14391568 attn_metadata : FlashMLASparseMetadata | None ,
14401569 swa_metadata : "DeepseekSparseSWAMetadata" ,
1570+ kv_workspace : torch .Tensor | None = None ,
14411571 ) -> None :
14421572 swa_only = attn_metadata is None
14431573
@@ -1536,36 +1666,44 @@ def _forward_prefill(
15361666 ((max_query_chunk_tokens ,), torch .int32 ),
15371667 )
15381668 prefill_state_buffers = None
1669+ # When the wrapper's attention_impl has pre-gathered KV into kv_workspace
1670+ # on an aux stream (overlapped with the indexer), use that buffer in place
1671+ # of the per-chunk gather below. attention_impl only pre-gathers when
1672+ # num_prefills <= PREFILL_CHUNK_SIZE, so the loop is expected to run a single chunk in that case.
1673+ # The workspace allocation in attention_impl aliases offset 0 of the same per-ubatch workspace buffer as ``kv`` here,
1674+ # but we route through the explicit parameter so the contract stays visible at the call site.
1675+ _kv = kv_workspace if kv_workspace is not None else kv
15391676 for chunk_idx in range (num_chunks ):
15401677 chunk_start = chunk_idx * PREFILL_CHUNK_SIZE
15411678 chunk_end = min (chunk_start + PREFILL_CHUNK_SIZE , num_prefills )
15421679 chunk_size = chunk_end - chunk_start
1543- if not swa_only :
1544- # Gather compressed KV
1545- assert attn_metadata is not None
1546- block_table = attn_metadata .block_table [num_decodes :]
1680+ if kv_workspace is None :
1681+ if not swa_only :
1682+ # Gather compressed KV
1683+ assert attn_metadata is not None
1684+ block_table = attn_metadata .block_table [num_decodes :]
1685+ dequantize_and_gather_k_cache (
1686+ kv [:chunk_size ],
1687+ compressed_k_cache ,
1688+ seq_lens = seq_lens [chunk_start :chunk_end ] // self .compress_ratio ,
1689+ gather_lens = None ,
1690+ block_table = block_table [chunk_start :chunk_end ],
1691+ block_size = attn_metadata .block_size // self .compress_ratio ,
1692+ offset = 0 ,
1693+ )
1694+
1695+ # Gather SWA KV
1696+ swa_block_table = swa_metadata .block_table [num_decodes :]
15471697 dequantize_and_gather_k_cache (
15481698 kv [:chunk_size ],
1549- compressed_k_cache ,
1550- seq_lens = seq_lens [chunk_start :chunk_end ] // self . compress_ratio ,
1551- gather_lens = None ,
1552- block_table = block_table [chunk_start :chunk_end ],
1553- block_size = attn_metadata .block_size // self . compress_ratio ,
1554- offset = 0 ,
1699+ swa_k_cache ,
1700+ seq_lens = seq_lens [chunk_start :chunk_end ],
1701+ gather_lens = gather_lens [ chunk_start : chunk_end ] ,
1702+ block_table = swa_block_table [chunk_start :chunk_end ],
1703+ block_size = swa_metadata .block_size ,
1704+ offset = N ,
15551705 )
15561706
1557- # Gather SWA KV
1558- swa_block_table = swa_metadata .block_table [num_decodes :]
1559- dequantize_and_gather_k_cache (
1560- kv [:chunk_size ],
1561- swa_k_cache ,
1562- seq_lens = seq_lens [chunk_start :chunk_end ],
1563- gather_lens = gather_lens [chunk_start :chunk_end ],
1564- block_table = swa_block_table [chunk_start :chunk_end ],
1565- block_size = swa_metadata .block_size ,
1566- offset = N ,
1567- )
1568-
15691707 # Combine the topk indices and SWA indices for gathered KV cache
15701708 query_start = (
15711709 query_start_loc_cpu [num_decodes + chunk_start ] - prefill_token_base
@@ -1594,7 +1732,7 @@ def _forward_prefill(
15941732 if triton_sparse_mla_enabled :
15951733 self ._forward_sparse_mla_prefill_triton (
15961734 q = q [query_start :query_end ],
1597- kv = kv [:chunk_size ],
1735+ kv = _kv [:chunk_size ],
15981736 combined_indices = combined_indices ,
15991737 combined_lens = combined_lens ,
16001738 output = output [query_start :query_end ],
@@ -1604,7 +1742,7 @@ def _forward_prefill(
16041742
16051743 flash_mla_sparse_fwd (
16061744 q = q [query_start :query_end ],
1607- kv = kv .view (- 1 , 1 , q .shape [- 1 ]),
1745+ kv = _kv .view (- 1 , 1 , q .shape [- 1 ]),
16081746 indices = combined_indices .unsqueeze (1 ),
16091747 sm_scale = self .scale ,
16101748 attn_sink = self .attn_sink ,
0 commit comments