Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 62 additions & 31 deletions mojo_opset/backends/ttx/kernels/mlu/swa.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def _swa_split_blocks(

return num_global_window_blocks, non_global_window_start_block, num_total_blocks


@triton.jit
def _swa_transposed_range_blocks(
kv_block_start_id,
Expand Down Expand Up @@ -719,6 +720,7 @@ def _swa_paged_decode_kernel(
NUM_TOTAL_BLOCKS,
MAX_NUM_BLOCKS_PER_SEQ,
stride_qb,
stride_qs,
stride_qh,
stride_qd,
stride_k_block,
Expand All @@ -730,11 +732,13 @@ def _swa_paged_decode_kernel(
stride_v_blksz,
stride_v_dim,
stride_ob,
stride_os,
stride_oh,
stride_od,
stride_bt_batch,
stride_bt_block,
softmax_scale,
Q_SEQLEN: tl.constexpr,
GLOBAL_WINDOW: tl.constexpr,
LOCAL_WINDOW: tl.constexpr,
NUM_Q_HEADS: tl.constexpr,
Expand All @@ -754,7 +758,6 @@ def _swa_paged_decode_kernel(
GQA_HEAD_STRIDE: tl.constexpr = NUM_KV_HEADS if GQA_INTERLEAVE else 1
NUM_Q_HEAD_BLOCKS_PER_KV_HEAD: tl.constexpr = tl.cdiv(GQA_GROUP_SIZE, BLOCK_SIZE_Q_HEADS)


pid = tl.program_id(0)
n_progs = tl.num_programs(0)

Expand All @@ -768,26 +771,26 @@ def _swa_paged_decode_kernel(
kv_seq_len = tl.load(seqlens_ptr + b_id)

offs_d = tl.arange(0, BLOCK_SIZE_D)
offs_s = tl.arange(0, Q_SEQLEN)
offs_gqa_block = q_head_block_id * BLOCK_SIZE_Q_HEADS + tl.arange(0, BLOCK_SIZE_Q_HEADS)
offs_head_block = kv_head_id * GQA_GROUP_STRIDE + offs_gqa_block * GQA_HEAD_STRIDE
q_ptrs = q_ptr + b_id * stride_qb + offs_head_block[:, None] * stride_qh + offs_d[None, :] * stride_qd

q = tl.load(q_ptrs, mask = (offs_d[None, :] < HEAD_DIM) & (offs_gqa_block[:, None] < GQA_GROUP_SIZE), other = 0.0)
q_ptrs = q_ptr + b_id * stride_qb + offs_s[:, None, None] * stride_qs + offs_head_block[None, :, None] * stride_qh + offs_d[None, None, :] * stride_qd

m_i = tl.zeros((BLOCK_SIZE_Q_HEADS,), dtype=tl.float32) - float("inf")
l_i = tl.zeros((BLOCK_SIZE_Q_HEADS,), dtype=tl.float32)
acc = tl.zeros((BLOCK_SIZE_Q_HEADS, BLOCK_SIZE_D), dtype=tl.float32)
q = tl.load(q_ptrs, mask = (offs_d[None, None, :] < HEAD_DIM) & (offs_gqa_block[None, :, None] < GQA_GROUP_SIZE), other = 0.0)
q = tl.reshape(q, [Q_SEQLEN * BLOCK_SIZE_Q_HEADS, BLOCK_SIZE_D])
m_i = tl.zeros((BLOCK_SIZE_Q_HEADS * Q_SEQLEN,), dtype=tl.float32) - float("inf")
l_i = tl.zeros((BLOCK_SIZE_Q_HEADS * Q_SEQLEN,), dtype=tl.float32)
acc = tl.zeros((BLOCK_SIZE_Q_HEADS * Q_SEQLEN, BLOCK_SIZE_D), dtype=tl.float32)

num_global_window_blocks, non_global_window_start_block, num_total_blocks = _swa_split_blocks(
kv_seq_len - 1,
1,
kv_seq_len - Q_SEQLEN,
Q_SEQLEN,
kv_seq_len,
BLOCK_SIZE_N,
True,
GLOBAL_WINDOW,
LOCAL_WINDOW,
)


for kv_block_id in range(num_global_window_blocks):
kv_block_start = kv_block_id * BLOCK_SIZE_N
Expand All @@ -814,13 +817,18 @@ def _swa_paged_decode_kernel(
block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D),
order=(1, 0),
)

gw_mask = (kv_block_start + tl.arange(0, BLOCK_SIZE_N)) < GLOBAL_WINDOW
gw_mask = gw_mask[None, :]
kv_mask = tl.arange(0, BLOCK_SIZE_N) < kv_block_len
seq_offsets = tl.arange(0, BLOCK_SIZE_Q_HEADS * Q_SEQLEN) // BLOCK_SIZE_Q_HEADS
base = kv_block_start + tl.arange(0, BLOCK_SIZE_N)
if LOCAL_WINDOW is not None:
sw_mask = (kv_block_start + tl.arange(0, BLOCK_SIZE_N) + LOCAL_WINDOW) >= (kv_seq_len - 1)
sw_mask = base[None, :] + LOCAL_WINDOW >= kv_seq_len - Q_SEQLEN + seq_offsets[:, None]
gw_mask = gw_mask | sw_mask
kv_mask = tl.arange(0, BLOCK_SIZE_N) < kv_block_len
mask = gw_mask & kv_mask
casul_mask = base[None, :] <= kv_seq_len - Q_SEQLEN + seq_offsets[:, None]
mask = gw_mask & kv_mask[None, :] & casul_mask

acc, l_i, m_i = _decode_acc_fwd_MxN(
acc,
l_i,
Expand Down Expand Up @@ -862,14 +870,16 @@ def _swa_paged_decode_kernel(
block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D),
order=(1, 0),
)

kv_mask = tl.arange(0, BLOCK_SIZE_N) < kv_block_len
seq_offsets = tl.arange(0, BLOCK_SIZE_Q_HEADS * Q_SEQLEN) // BLOCK_SIZE_Q_HEADS
base = kv_block_start + tl.arange(0, BLOCK_SIZE_N)
casul_mask = base[None, :] <= kv_seq_len - Q_SEQLEN + seq_offsets[:, None]
if LOCAL_WINDOW is not None:
sw_mask = (kv_block_start + tl.arange(0, BLOCK_SIZE_N) + LOCAL_WINDOW) >= (kv_seq_len - 1)
mask = kv_mask & sw_mask
sw_mask = base[None, :] + LOCAL_WINDOW >= kv_seq_len - Q_SEQLEN + seq_offsets[:, None]
mask = sw_mask & kv_mask[None, :] & casul_mask
else:
mask = kv_mask

mask = kv_mask[None, :] & casul_mask
acc, l_i, m_i = _decode_acc_fwd_MxN(
acc,
l_i,
Expand All @@ -888,10 +898,10 @@ def _swa_paged_decode_kernel(

if kv_seq_len > 0:
# avoid division by zero
acc = acc / l_i[:, None]

o_ptrs = o_ptr + b_id * stride_ob + offs_head_block[:, None] * stride_oh + offs_d[None, :] * stride_od
tl.store(o_ptrs, acc.to(o_ptr.dtype.element_ty), mask=(offs_d[None, :] < HEAD_DIM) & (offs_gqa_block[:, None] < GQA_GROUP_SIZE))
acc = tl.where(l_i[:, None] > 0, acc / tl.where(l_i[:, None] > 0, l_i[:, None], 1.0), 0.0)
acc = tl.reshape(acc, [Q_SEQLEN, BLOCK_SIZE_Q_HEADS, BLOCK_SIZE_D])
o_ptrs = o_ptr + b_id * stride_ob + offs_s[:, None, None] * stride_os + offs_head_block[None, :, None] * stride_oh + offs_d[None, None, :] * stride_od
tl.store(o_ptrs, acc.to(o_ptr.dtype.element_ty), mask = (offs_d[None, None, :] < HEAD_DIM) & (offs_gqa_block[None, :, None] < GQA_GROUP_SIZE))


def swa_paged_decode_impl(
Expand All @@ -905,12 +915,31 @@ def swa_paged_decode_impl(
gqa_interleave: bool = False,
softmax_scale: Optional[float] = None,
) -> torch.Tensor:
batch_size, num_q_heads, head_dim = q.shape
if q.ndim == 4:
batch_size, seq_lens, num_q_heads, head_dim = q.shape
stride_qb, stride_qs, stride_qh, stride_qd = q.stride()
stride_ob = seq_lens * num_q_heads * head_dim
stride_os = num_q_heads * head_dim
stride_oh = head_dim
stride_od = 1
assert torch.all(seqlens >= seq_lens), f"the seqlens of kv cache must larger than seq_lens({seq_lens}) of q, \
but: {seqlens[torch.where(seqlens < seq_lens)]}"
else:
batch_size, num_q_heads, head_dim = q.shape
stride_qb, stride_qh, stride_qd = q.stride()
seq_lens = 1
stride_qs = 1
Comment thread
seainair marked this conversation as resolved.
stride_ob = num_q_heads * head_dim
stride_os = 1
stride_oh = head_dim
stride_od = 1

num_total_blocks, num_kv_heads, page_size, head_dim_cache = key_cache.shape

max_num_blocks_per_seq = block_tables.shape[1]

assert head_dim == head_dim_cache
assert seq_lens <= 4
if softmax_scale is None:
softmax_scale = 1.0 / (head_dim**0.5)

Expand All @@ -933,9 +962,10 @@ def swa_paged_decode_impl(
batch_size,
num_total_blocks,
max_num_blocks_per_seq,
q.stride(0),
q.stride(1),
q.stride(2),
stride_qb,
stride_qs,
stride_qh,
stride_qd,
key_cache.stride(0),
key_cache.stride(1),
key_cache.stride(2),
Expand All @@ -944,12 +974,14 @@ def swa_paged_decode_impl(
value_cache.stride(1),
value_cache.stride(2),
value_cache.stride(3),
o.stride(0),
o.stride(1),
o.stride(2),
stride_ob,
stride_os,
stride_oh,
stride_od,
block_tables.stride(0),
block_tables.stride(1),
softmax_scale,
seq_lens,
global_window_size,
local_window_size,
num_q_heads,
Expand All @@ -963,5 +995,4 @@ def swa_paged_decode_impl(
num_warps=1, num_stages=3,
pipeline_strategies=["reduce_delay"],
)

return o
31 changes: 31 additions & 0 deletions mojo_opset/backends/ttx/operators/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from mojo_opset.experimental import MojoPagedDecodeSWAWithKVDequant
from mojo_opset.experimental import MojoPagedPrefillGQAWithKVDequant
from mojo_opset.experimental import MojoPagedPrefillSWAWithKVDequant
from mojo_opset.experimental import MojoPagedDecodeNstepSWA


class TTXPagedPrefillGQA(MojoPagedPrefillGQA):
Expand Down Expand Up @@ -430,3 +431,33 @@ def forward(
self.gqa_interleave,
)
return o


class TTXPagedDecodeNstepSWA(MojoPagedDecodeNstepSWA):
    """TTX backend for the n-step paged-decode SWA operator.

    Declares "mlu" as its only supported platform and delegates the actual
    attention computation to ``swa_paged_decode``.
    """

    supported_platforms_list = ["mlu"]

    def forward(
        self,
        q: torch.Tensor,  # [bsz, seq_len, n_q_heads, head_dim]
        k_cache: torch.Tensor,  # [n_pages, n_kv_heads, page_size, head_dim]
        v_cache: torch.Tensor,  # [n_pages, n_kv_heads, page_size, head_dim]
        total_seq_lens: torch.Tensor,  # [bsz]
        block_table: torch.Tensor,  # [bsz, max_num_blocks]
        softmax_scale: Optional[float] = None,
        *,
        max_total_seq_len: Optional[int] = None,
    ) -> torch.Tensor:
        """Validate the paged-decode inputs and run the SWA decode kernel wrapper.

        Args:
            q: Multi-step query tensor.
            k_cache: Paged key cache.
            v_cache: Paged value cache.
            total_seq_lens: Per-batch KV lengths.
            block_table: Per-batch page indices into the caches.
            softmax_scale: Optional scaling factor, forwarded unchanged.
            max_total_seq_len: Accepted for interface compatibility; not used
                by this method.

        Returns:
            Whatever ``swa_paged_decode`` returns for these inputs.
        """
        # Reject malformed block tables / sequence lengths before launching anything.
        assert_paged_decode_contract(block_table, total_seq_lens)

        # Window sizes are passed local-first, then global, followed by the
        # GQA layout flag and the (possibly None) softmax scale.
        return swa_paged_decode(
            q,
            k_cache,
            v_cache,
            total_seq_lens,
            block_table,
            self.local_window_size,
            self.global_window_size,
            self.gqa_interleave,
            softmax_scale,
        )
2 changes: 2 additions & 0 deletions mojo_opset/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .operators.attention import MojoPagedDecodeNSA
from .operators.attention import MojoPagedDecodeGQAWithKVDequant
from .operators.attention import MojoPagedDecodeSWAWithKVDequant
from .operators.attention import MojoPagedDecodeNstepSWA
from .operators.attention import MojoPagedPrefillGQAWithKVDequant
from .operators.attention import MojoPagedPrefillMLA
from .operators.attention import MojoPagedPrefillNSA
Expand Down Expand Up @@ -50,6 +51,7 @@
"MojoPagedDecodeGQAWithKVDequant",
"MojoPagedPrefillSWAWithKVDequant",
"MojoPagedDecodeSWAWithKVDequant",
"MojoPagedDecodeNstepSWA",
"MojoFusedAttnOutputGate",
"MojoPagedPrefillSageGQA",
"MojoStorePagedMLAKVCache",
Expand Down
112 changes: 112 additions & 0 deletions mojo_opset/experimental/operators/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -1151,6 +1151,117 @@ def extra_repr(self):
return f"{self.is_causal=}, {self.gqa_layout=}, {self.global_window_size=}, {self.local_window_size=}, {self.query_dtype=}, {self.context_dtype=}, {self.compute_dtype=}".replace("self.", "")


class MojoPagedDecodeNstepSWA(MojoOperator):
    def __init__(
        self,
        is_causal: bool = True,
        gqa_layout: str = "AABB",
        global_window_size: Optional[int] = None,
        local_window_size: Optional[int] = None,
    ):
        """
        Reference paged-decode sliding-window attention for an n-step query of
        shape [bsz, seq_len, n_q_heads, head_dim].

        Args:
            is_causal: When True (default), the window mask is applied to the
                attention scores; when False no mask is applied at all.
            gqa_layout: How query heads map onto KV heads, either "ABAB"
                (interleaved) or "AABB" (grouped, default).
            global_window_size: Length of the global attention window, or None
                for no global window. Only consulted when is_causal is True.
            local_window_size: Length of the local attention window, or None
                for no local window. Only consulted when is_causal is True.

        Raises:
            ValueError: If gqa_layout is not "ABAB" or "AABB".
        """
        super().__init__()

        if gqa_layout not in ["ABAB", "AABB"]:
            raise ValueError(f"gqa_layout must be one of ['ABAB', 'AABB'], got {gqa_layout}")

        self.is_causal = is_causal
        self.gqa_layout = gqa_layout
        # "ABAB" means KV heads are tiled across query heads rather than repeated in place.
        self.gqa_interleave = gqa_layout == "ABAB"
        self.global_window_size = global_window_size
        self.local_window_size = local_window_size

    def forward(
        self,
        query: torch.Tensor,  # [bsz, seq_len, n_q_heads, head_dim]
        key_cache: torch.Tensor,  # [n_pages, n_kv_heads, page_size, head_dim]
        value_cache: torch.Tensor,  # [n_pages, n_kv_heads, page_size, head_dim]
        total_seq_lens: torch.Tensor,  # [bsz]
        block_table: torch.Tensor,  # [bsz, max_num_blocks]
        softmax_scale: Optional[float] = None,
        *,
        max_total_seq_len: Optional[int] = None,
    ) -> torch.Tensor:
        """Compute attention per batch row by gathering that row's KV pages.

        Rows whose entry in ``total_seq_lens`` is <= 0 are skipped and keep the
        zero-initialised output. ``max_total_seq_len`` is accepted but not used
        by this implementation.

        Returns:
            A tensor with the same shape and dtype as ``query``.

        Raises:
            ValueError: If a row with a positive KV length has a negative first
                block-table entry.
        """
        # Note: a decode kernel is never expected to run with is_causal = False.

        assert_paged_decode_contract(block_table, total_seq_lens)
        assert query.ndim == 4, (
            f"MojoPagedDecodeNstepSWA expects 4D query [bsz, seq_len, n_q_heads, head_dim], got ndim={query.ndim}"
        )
        bsz, seq_len, n_q_heads, head_dim = query.shape
        _, n_kv_heads, page_size, _ = key_cache.shape

        if softmax_scale is None:
            softmax_scale = 1.0 / (head_dim**0.5)

        needs_gqa_expand = n_q_heads != n_kv_heads
        heads_per_kv = n_q_heads // n_kv_heads

        out = torch.zeros_like(query)
        for batch_idx in range(bsz):
            kv_len = total_seq_lens[batch_idx].item()
            if kv_len <= 0:
                # Padded row: leave the zeros in place.
                continue
            if block_table[batch_idx, 0].item() < 0:
                raise ValueError("Paged decode requires a valid block table for rows with kv lens > 0.")

            # [seq_len, n_q_heads, head_dim] -> [n_q_heads, seq_len, head_dim]
            q_heads = query[batch_idx].permute(1, 0, 2)

            # Number of pages needed to hold kv_len tokens (ceil division).
            n_pages_used = (kv_len + page_size - 1) // page_size
            flat_len = n_pages_used * page_size

            # Gather pages, flatten them along the token axis, trim page padding.
            keys = key_cache[block_table[batch_idx, :n_pages_used]]  # [pages, n_kv_heads, page_size, head_dim]
            keys = keys.permute(1, 0, 2, 3).reshape(n_kv_heads, flat_len, head_dim)[:, :kv_len]
            keys_t = keys.permute(0, 2, 1)  # [n_kv_heads, head_dim, kv_len]
            if needs_gqa_expand:
                if self.gqa_interleave:
                    keys_t = keys_t.repeat((heads_per_kv, 1, 1))
                else:
                    keys_t = keys_t.repeat_interleave(heads_per_kv, dim=0)  # [n_q_heads, head_dim, kv_len]

            scores = torch.bmm(q_heads, keys_t).float() * softmax_scale  # [n_q_heads, seq_len, kv_len]

            if self.is_causal:
                window_mask = _generate_window_mask(
                    seq_len,
                    kv_len,
                    self.local_window_size,
                    self.global_window_size,
                ).to(scores.device)
                scores = torch.where(window_mask, scores, float("-inf"))

            # Softmax with the normalisation deferred until after the PV product.
            row_max = torch.max(scores, dim=-1, keepdim=True).values  # [n_q_heads, seq_len, 1]
            probs = torch.exp(scores - row_max)
            row_sum = torch.sum(probs, dim=-1, keepdim=True)  # [n_q_heads, seq_len, 1]
            probs = probs.to(query.dtype)

            values = value_cache[block_table[batch_idx, :n_pages_used]]
            values = values.permute(1, 0, 2, 3).reshape(n_kv_heads, flat_len, head_dim)[
                :, :kv_len
            ]  # [n_kv_heads, kv_len, head_dim]
            if needs_gqa_expand:
                if self.gqa_interleave:
                    values = values.repeat((heads_per_kv, 1, 1))
                else:
                    values = values.repeat_interleave(heads_per_kv, dim=0)  # [n_q_heads, kv_len, head_dim]

            attn = torch.bmm(probs, values).float()  # [n_q_heads, seq_len, head_dim]
            attn = attn / row_sum
            # Back to [seq_len, n_q_heads, head_dim] to match the query layout.
            out[batch_idx] = attn.permute(1, 0, 2).to(out.dtype)
        return out

    def extra_repr(self) -> str:
        """Return the operator's configuration as a comma-separated string."""
        return (
            f"is_causal={self.is_causal}, gqa_layout={self.gqa_layout}, "
            f"global_window_size={self.global_window_size}, local_window_size={self.local_window_size}"
        )


# ---------------------------------------------------------------------------
# NSA (Native Sparse Attention) helpers - module-level to avoid triggering
# MojoOperator.__init_subclass__ registration.
Expand Down Expand Up @@ -1742,5 +1853,6 @@ def extra_repr(self) -> str:
"MojoPagedDecodeGQAWithKVDequant",
"MojoPagedPrefillSWAWithKVDequant",
"MojoPagedDecodeSWAWithKVDequant",
"MojoPagedDecodeNstepSWA",
"MojoPagedPrefillSageGQA",
]
Loading
Loading