-
Notifications
You must be signed in to change notification settings - Fork 53
MojoStorePagedSingleCache for single K/V paged store #372
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -169,3 +169,69 @@ def forward( | |||||||
| ) | ||||||||
|
|
||||||||
| return key_cache, value_cache | ||||||||
|
|
||||||||
|
|
||||||||
| class MojoStorePagedSingleCache(MojoOperator): | ||||||||
| def __init__( | ||||||||
| self, | ||||||||
| ): | ||||||||
| super().__init__() | ||||||||
|
|
||||||||
| def forward( | ||||||||
| self, | ||||||||
| states: torch.Tensor, | ||||||||
| cache: torch.Tensor, | ||||||||
| block_table: Optional[torch.Tensor] = None, | ||||||||
| cu_q_lens: Optional[torch.Tensor] = None, | ||||||||
| context_kv_lens: Optional[torch.Tensor] = None, | ||||||||
| *, | ||||||||
| chunk_metadata: Optional[torch.Tensor] = None, | ||||||||
| ) -> torch.Tensor: | ||||||||
| """ | ||||||||
| Copy new tokens of a single attribute (key OR value) into one paged cache. | ||||||||
|
|
||||||||
| Mirrors :class:`MojoStorePagedKVCache` but operates on a single tensor/cache | ||||||||
| pair, for cases where only one of K/V needs to be written (e.g. SAGE prefill | ||||||||
| only stores V into the static cache while K lives in a separate cache). | ||||||||
|
|
||||||||
| Args: | ||||||||
| states (torch.Tensor): Shape (token_num, kv_head_num, head_dim) — new tokens. | ||||||||
| cache (torch.Tensor): Shape (total_phys_blocks, kv_heads, block_size, head_dim) — paged cache. | ||||||||
| block_table (torch.Tensor | None): Logical-to-physical block mapping. | ||||||||
| cu_q_lens (torch.Tensor | None): Cumulative query lengths. ``None`` indicates decode mode. | ||||||||
| context_kv_lens (torch.Tensor | None): KV lengths before storing current tokens. | ||||||||
| chunk_metadata (torch.Tensor | None): Optimized precomputed store plan with shape ``(num_chunks, 4)`` | ||||||||
| and per-row ``(src_token_start, dst_block_id, dst_block_offset, chunk_len)``. | ||||||||
|
|
||||||||
| Returns: | ||||||||
| torch.Tensor: Updated ``cache`` after in-place writes. | ||||||||
| """ | ||||||||
| assert len(states.shape) == 3, "states must be (token_num, kv_head_num, head_dim), please check." | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||||
|
|
||||||||
| if chunk_metadata is None: | ||||||||
| assert block_table is not None, "block_table is required when chunk_metadata is not provided." | ||||||||
| assert context_kv_lens is not None, "context_kv_lens is required when chunk_metadata is not provided." | ||||||||
| chunk_metadata = build_paged_kv_chunk_metadata( | ||||||||
| block_table, | ||||||||
| cu_q_lens, | ||||||||
| context_kv_lens, | ||||||||
| cache.shape[2], | ||||||||
| ) | ||||||||
| else: | ||||||||
| assert block_table is None and cu_q_lens is None and context_kv_lens is None, ( | ||||||||
| "chunk_metadata path should not be mixed with block_table/cu_q_lens/context_kv_lens." | ||||||||
| ) | ||||||||
|
|
||||||||
| assert_paged_kv_store_contract(chunk_metadata) | ||||||||
|
|
||||||||
| if chunk_metadata.shape[0] == 0: | ||||||||
| return cache | ||||||||
|
|
||||||||
| for src_token_start, dst_block_id, dst_block_offset, chunk_len in chunk_metadata.tolist(): | ||||||||
| src_end = src_token_start + chunk_len | ||||||||
| dst_end = dst_block_offset + chunk_len | ||||||||
| cache[dst_block_id, :, dst_block_offset:dst_end, :] = states[src_token_start:src_end].permute( | ||||||||
| 1, 0, 2 | ||||||||
| ) | ||||||||
|
|
||||||||
| return cache | ||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is highly recommended to verify that
statesandcacheare on the same device. A device mismatch between these tensors will cause runtime failures or silent errors during execution.