Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions mojo_opset/backends/ixformer/operators/kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ixformer import functions as ixf_f

from mojo_opset.core import MojoStorePagedKVCache
from mojo_opset.core import MojoStorePagedSingleCache
from mojo_opset.core.operators.kv_cache import assert_paged_kv_store_contract

class IxformerStorePagedKVCache(MojoStorePagedKVCache):
Expand Down Expand Up @@ -78,3 +79,61 @@ def forward(
context_kv_lens,
)
return key_cache, value_cache


class IxformerStorePagedSingleCache(MojoStorePagedSingleCache):
    """Ixformer implementation of the single-cache paged store operator.

    Writes one tensor of new tokens into one paged cache by calling
    ``ixf_f.paged_store_kv_cache_with_block_table``. Only the
    ``block_table`` / ``context_kv_lens`` path is implemented; a precomputed
    ``chunk_metadata`` plan is rejected with ``NotImplementedError``.
    """

    supported_platforms_list = ["ilu"]

    def forward(
        self,
        states: torch.Tensor,
        cache: torch.Tensor,
        block_table: Optional[torch.Tensor] = None,
        cu_q_lens: Optional[torch.Tensor] = None,
        context_kv_lens: Optional[torch.Tensor] = None,
        *,
        chunk_metadata: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Store new tokens of a single attribute (key OR value) into one ixformer
        block-based paged cache.

        Args:
            states (torch.Tensor): New tokens with shape (token_num, kv_head_num, head_dim).
            cache (torch.Tensor): Paged cache with shape
                (num_blocks, kv_head_num, block_size, head_dim), updated in-place.
            block_table (torch.Tensor | None): Logical-to-physical block mapping with
                shape (batch_size, max_blocks_per_sequence).
            cu_q_lens (torch.Tensor | None): Cumulative query lengths for prefill with
                shape (batch_size + 1,). None indicates decode mode.
            context_kv_lens (torch.Tensor | None): Existing KV lengths before storing the
                current tokens, shape (batch_size,). Padding entries use -1.
            chunk_metadata (torch.Tensor | None): Optional precomputed store plan with shape
                (num_chunks, 4) and per-row (src_token_start, dst_block_id, dst_block_offset, chunk_len).
                Not supported by this backend.

        Returns:
            torch.Tensor: Updated cache (the same tensor object that was passed in).

        Raises:
            ValueError: If ``states`` is not 3-D, ``cache`` is not 4-D, the two tensors
                differ in dtype or device, or ``block_table`` / ``context_kv_lens`` is missing.
            NotImplementedError: If ``chunk_metadata`` is provided.
        """
        if states.dim() != 3:
            raise ValueError("states must be (token_num, kv_head_num, head_dim).")
        if cache.dim() != 4:
            raise ValueError("cache must be (num_blocks, kv_head_num, block_size, head_dim).")
        if cache.dtype != states.dtype:
            raise ValueError("IxformerStorePagedSingleCache requires states and cache to have the same dtype.")
        # Both tensors are handed straight to the kernel call below, so reject a
        # device mismatch here with a clear message instead of relying on the kernel.
        if cache.device != states.device:
            raise ValueError("IxformerStorePagedSingleCache requires states and cache to be on the same device.")

        if chunk_metadata is not None:
            raise NotImplementedError("IxformerStorePagedSingleCache does not support the chunk_metadata path.")
        if block_table is None or context_kv_lens is None:
            raise ValueError("block_table and context_kv_lens are required when chunk_metadata is not provided.")

        # The ixformer entry point takes a key/value pair of sources and a
        # key/value pair of caches; the single tensor and single cache are passed
        # in both slots.
        # NOTE(review): the meaning of store_mode=1 is defined by ixformer — confirm
        # against its documentation.
        ixf_f.paged_store_kv_cache_with_block_table(
            states,
            states,
            cache,
            cache,
            block_table,
            cu_q_lens,
            context_kv_lens,
            store_mode=1,
        )
        return cache
2 changes: 2 additions & 0 deletions mojo_opset/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

""" kvcache """
from .operators.kv_cache import MojoStorePagedKVCache
from .operators.kv_cache import MojoStorePagedSingleCache

""" gemm """
from .operators.gemm import MojoGemm
Expand Down Expand Up @@ -125,6 +126,7 @@
"MojoSWA",

"MojoStorePagedKVCache",
"MojoStorePagedSingleCache",

"MojoGemm",
"MojoQuantGemm",
Expand Down
66 changes: 66 additions & 0 deletions mojo_opset/core/operators/kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,69 @@ def forward(
)

return key_cache, value_cache


class MojoStorePagedSingleCache(MojoOperator):
    """Operator that copies new tokens of one attribute (K or V) into one paged cache."""

    def __init__(
        self,
    ):
        super().__init__()

    def forward(
        self,
        states: torch.Tensor,
        cache: torch.Tensor,
        block_table: Optional[torch.Tensor] = None,
        cu_q_lens: Optional[torch.Tensor] = None,
        context_kv_lens: Optional[torch.Tensor] = None,
        *,
        chunk_metadata: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Copy new tokens of a single attribute (key OR value) into one paged cache.

        Mirrors :class:`MojoStorePagedKVCache` but operates on a single tensor/cache
        pair, for cases where only one of K/V needs to be written (e.g. SAGE prefill
        only stores V into the static cache while K lives in a separate cache).

        Exactly one of two input styles is accepted: either ``chunk_metadata`` alone,
        or ``block_table`` + ``context_kv_lens`` (with optional ``cu_q_lens``), from
        which the store plan is built via ``build_paged_kv_chunk_metadata``.

        Args:
            states (torch.Tensor): Shape (token_num, kv_head_num, head_dim) — new tokens.
            cache (torch.Tensor): Shape (total_phys_blocks, kv_heads, block_size, head_dim) — paged cache.
            block_table (torch.Tensor | None): Logical-to-physical block mapping.
            cu_q_lens (torch.Tensor | None): Cumulative query lengths. ``None`` indicates decode mode.
            context_kv_lens (torch.Tensor | None): KV lengths before storing current tokens.
            chunk_metadata (torch.Tensor | None): Optimized precomputed store plan with shape ``(num_chunks, 4)``
                and per-row ``(src_token_start, dst_block_id, dst_block_offset, chunk_len)``.

        Returns:
            torch.Tensor: Updated ``cache`` after in-place writes.

        Raises:
            AssertionError: If ``states`` is not 3-D, ``cache`` is not 4-D, the required
                inputs for the chosen path are missing, or the two paths are mixed.
        """
        assert len(states.shape) == 3, "states must be (token_num, kv_head_num, head_dim), please check."
        # cache.shape[2] and the 4-subscript indexing below both rely on a 4-D cache;
        # check it here so a wrong layout is reported with a clear message.
        assert len(cache.shape) == 4, "cache must be (total_phys_blocks, kv_heads, block_size, head_dim), please check."

        if chunk_metadata is None:
            assert block_table is not None, "block_table is required when chunk_metadata is not provided."
            assert context_kv_lens is not None, "context_kv_lens is required when chunk_metadata is not provided."
            # cache.shape[2] is the block_size axis of the documented cache layout.
            chunk_metadata = build_paged_kv_chunk_metadata(
                block_table,
                cu_q_lens,
                context_kv_lens,
                cache.shape[2],
            )
        else:
            assert block_table is None and cu_q_lens is None and context_kv_lens is None, (
                "chunk_metadata path should not be mixed with block_table/cu_q_lens/context_kv_lens."
            )

        # Plan validation is delegated to a helper defined elsewhere in the project.
        assert_paged_kv_store_contract(chunk_metadata)

        # An empty plan means there is nothing to write.
        if chunk_metadata.shape[0] == 0:
            return cache

        for src_token_start, dst_block_id, dst_block_offset, chunk_len in chunk_metadata.tolist():
            src_end = src_token_start + chunk_len
            dst_end = dst_block_offset + chunk_len
            # The source slice is token-major (chunk_len, kv_head_num, head_dim); the
            # destination slot is head-major (kv_heads, chunk_len, head_dim), hence the
            # swap of the first two axes.
            cache[dst_block_id, :, dst_block_offset:dst_end, :] = states[src_token_start:src_end].permute(
                1, 0, 2
            )

        return cache
147 changes: 147 additions & 0 deletions mojo_opset/tests/accuracy/operators/test_kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch

from mojo_opset import MojoStorePagedKVCache
from mojo_opset import MojoStorePagedSingleCache
from mojo_opset.experimental import MojoStorePagedMLAKVCache
from mojo_opset.tests.utils import assert_close
from mojo_opset.tests.utils import auto_switch_platform
Expand Down Expand Up @@ -420,6 +421,152 @@ def test_store_paged_kv_chunk_metadata_perf_and_accuracy():
)


# ===========================================================================
# MojoStorePagedSingleCache
# ===========================================================================

@pytest.mark.parametrize(
    "batch_size, kv_heads, head_dim, block_size, context_kv_lens_val, q_lens_val",
    [
        (2, 2, 128, 128, [0, 0], [130, 33]),
        (2, 2, 128, 128, [32, 35], [1, 1]),
        (2, 2, 128, 128, [15, 40], [788, 126]),
        (2, 2, 128, 512, [255, 511], [300, 257]),
        (2, 2, 128, 2048, [1023, 2047], [900, 1025]),
        (1, 1, 128, 128, [0], [5]),
        (1, 1, 128, 128, [5], [1]),
        (1, 1, 128, 2048, [2046], [2]),
        (3, 2, 128, 128, [32, -1, 35], [1, 1, 1]),
        (3, 2, 128, 128, [0, -1, 5], [4, 0, 2]),
        (3, 2, 128, 512, [510, -1, 700], [4, 1, 300]),
        (8, 2, 128, 128, [224, 542, 34, 41, 54, 57, 65, 0], [432, 84, 977, 93, 23, 89, 31, 555]),
        (8, 2, 128, 128, [772, 974, 3232, 43, 77, 7633, 888, 1], [1, 1, 1, 1, 1, 1, 1, 1]),
    ],
)
@bypass_not_implemented
def test_store_paged_single_cache(batch_size, kv_heads, head_dim, block_size, context_kv_lens_val, q_lens_val):
    """Check the dispatched single-cache store against the "torch" one on the chunk_metadata path.

    Both operators receive the same new tokens, an independent clone of the same
    initial cache, and the same precomputed store plan; the resulting caches must
    match.
    """
    fixture = _build_store_paged_kv_case(
        batch_size,
        kv_heads,
        head_dim,
        block_size,
        context_kv_lens_val,
        q_lens_val,
        device=get_torch_device(),
    )

    reference_op = MojoStorePagedSingleCache._registry.get("torch")()
    backend_op = MojoStorePagedSingleCache()
    # Comparing an implementation with itself proves nothing.
    # NOTE(review): presumably @bypass_not_implemented turns this into a skip — confirm.
    if type(reference_op) is type(backend_op):
        raise NotImplementedError("both operands resolve to the same implementation, skipping comparison.")

    new_tokens = fixture["key_states"]
    initial_cache = fixture["k_cache"]
    store_plan = fixture["chunk_metadata"]

    # Each operator writes in place, so each gets its own copy of the cache.
    expected = reference_op(new_tokens, initial_cache.clone(), chunk_metadata=store_plan)
    actual = backend_op(new_tokens, initial_cache.clone(), chunk_metadata=store_plan)

    assert_close(actual, expected)


@pytest.mark.parametrize(
    "batch_size, kv_heads, head_dim, block_size, context_kv_lens_val, q_lens_val",
    [
        (1, 2, 128, 16, [0], [3]),
        (1, 2, 128, 128, [127], [1]),
        (2, 4, 128, 32, [5, 33], [7, 19]),
        (2, 4, 128, 256, [255, 511], [1, 1]),
        (3, 8, 128, 128, [17, -1, 63], [1, 1, 1]),
        (4, 16, 128, 128, [0, 3, 127, 255], [9, 17, 33, 65]),
        (4, 16, 128, 512, [511, 1025, 7, 63], [1, 1, 1, 1]),
        (6, 24, 128, 128, [31, 511, 1023, 7, 95, 1535], [129, 257, 513, 5, 17, 65]),
    ],
)
@bypass_not_implemented
def test_store_paged_single_cache_without_chunk_metadata(
    batch_size,
    kv_heads,
    head_dim,
    block_size,
    context_kv_lens_val,
    q_lens_val,
):
    """Check the dispatched single-cache store against the "torch" one on the block_table path.

    No precomputed plan is supplied; both operators are driven with
    ``block_table``, ``cu_q_lens`` and ``context_kv_lens`` passed positionally.
    """
    fixture = _build_store_paged_kv_case(
        batch_size,
        kv_heads,
        head_dim,
        block_size,
        context_kv_lens_val,
        q_lens_val,
        device=get_torch_device(),
    )

    reference_op = MojoStorePagedSingleCache._registry.get("torch")()
    backend_op = MojoStorePagedSingleCache()
    # Comparing an implementation with itself proves nothing.
    # NOTE(review): presumably @bypass_not_implemented turns this into a skip — confirm.
    if type(reference_op) is type(backend_op):
        raise NotImplementedError("both operands resolve to the same implementation, skipping comparison.")

    new_tokens = fixture["key_states"]
    initial_cache = fixture["k_cache"]
    # Positional tail shared by both calls: (block_table, cu_q_lens, context_kv_lens).
    layout_args = (
        fixture["block_table"],
        fixture["cu_q_lens"],
        fixture["context_kv_lens"],
    )

    # Each operator writes in place, so each gets its own copy of the cache.
    expected = reference_op(new_tokens, initial_cache.clone(), *layout_args)
    actual = backend_op(new_tokens, initial_cache.clone(), *layout_args)

    assert_close(actual, expected)


@bypass_not_implemented
def test_store_paged_single_cache_matches_full_kv_store():
    """Check that the "torch" single-cache store writes the same K cache as the "torch" K/V store.

    Both reference operators get the same key tokens, clones of the same initial
    key cache, and the same precomputed store plan; only the key-cache output of
    the K/V store is compared.
    """
    batch_size, kv_heads, head_dim, block_size = 4, 2, 128, 128
    context_kv_lens_val = [224, 0, 34, 41]
    q_lens_val = [432, 84, 977, 93]

    fixture = _build_store_paged_kv_case(
        batch_size,
        kv_heads,
        head_dim,
        block_size,
        context_kv_lens_val,
        q_lens_val,
        device=get_torch_device(),
    )

    paired_store = MojoStorePagedKVCache._registry.get("torch")()
    single_store = MojoStorePagedSingleCache._registry.get("torch")()

    store_plan = fixture["chunk_metadata"]

    # The K/V store returns (key_cache, value_cache); only the key cache matters here.
    expected_k_cache, _ = paired_store(
        fixture["key_states"],
        fixture["value_states"],
        fixture["k_cache"].clone(),
        fixture["v_cache"].clone(),
        chunk_metadata=store_plan,
    )
    actual_k_cache = single_store(
        fixture["key_states"],
        fixture["k_cache"].clone(),
        chunk_metadata=store_plan,
    )

    assert_close(actual_k_cache, expected_k_cache)


# ===========================================================================
# MojoStorePagedMLAKVCache
# ===========================================================================
Expand Down
Loading