Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 112 additions & 44 deletions mojo_opset/backends/ttx/kernels/npu/rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,33 @@
}

SRAM_ALIGNMENT = 32

_MOJO_UB_LIMIT_BYTES = 192 * 1024
_MOJO_ROPE_UB_SAFETY_FACTOR = 4

# When the half RoPE dimension satisfies the SRAM byte-alignment requirement,
# we can leverage a more efficient extension API to perform the RoPE computation.
def _is_half_rope_dim_aligned(half_rope_dim: int, dtype_size: int = 2) -> bool:
return (half_rope_dim * dtype_size) % SRAM_ALIGNMENT == 0


def _mojo_estimate_rope_ub_bytes(
token_block_size: int,
n_qh: int,
n_kh: int,
rope_dim: int,
compute_dtype_size: int,
) -> int:
# The Ascend backend allocates extra UB buffers for slices and intermediates.
# A conservative multiplier prevents compile-time UB overflow on large rope_dim.
return (
token_block_size
* (n_qh + n_kh)
* rope_dim
* compute_dtype_size
* _MOJO_ROPE_UB_SAFETY_FACTOR
)


def _get_token_block_size(n_qh: int, n_kh: int) -> int:
assert n_qh <= 84 and n_kh <= 84, "don't support head_num > 84, please raise an issue."

Expand All @@ -41,6 +60,47 @@ def _get_token_block_size(n_qh: int, n_kh: int) -> int:

return 1

def _get_token_block_size_opt(
n_qh: int,
n_kh: int,
rope_dim: int,
compute_dtype_size: int,
) -> int:
assert n_qh <= 84 and n_kh <= 84, "don't support head_num > 84, please raise an issue."

if (n_qh, n_kh) in ROPE_TOKEN_BLOCK_SIZE_TABLE:
block_size = ROPE_TOKEN_BLOCK_SIZE_TABLE[(n_qh, n_kh)]
return _mojo_limit_token_block_size_by_ub(
block_size, n_qh, n_kh, rope_dim, compute_dtype_size
)

for (q_thresh, k_thresh), block_size in sorted(
ROPE_TOKEN_BLOCK_SIZE_TABLE.items(), key=lambda x: (x[0][0], x[0][1])
):
if n_qh <= q_thresh and n_kh <= k_thresh:
return _mojo_limit_token_block_size_by_ub(
block_size, n_qh, n_kh, rope_dim, compute_dtype_size
)

return 1


def _mojo_limit_token_block_size_by_ub(
token_block_size: int,
n_qh: int,
n_kh: int,
rope_dim: int,
compute_dtype_size: int,
) -> int:
while token_block_size > 1:
estimated_ub_bytes = _mojo_estimate_rope_ub_bytes(
token_block_size, n_qh, n_kh, rope_dim, compute_dtype_size
)
if estimated_ub_bytes <= _MOJO_UB_LIMIT_BYTES:
return token_block_size
token_block_size -= 1
return 1
Comment thread
YangLong114514 marked this conversation as resolved.
Outdated


@tensor_cache
def prepare_chunk_indices(
Expand Down Expand Up @@ -220,32 +280,28 @@ def _rope_inplace_kernel(

global_seq_offsets = seq_offsets

cos_token_ptr = cos_ptr + batch_idx * cos_batch_stride + seq_offsets[:, None] * cos_seq_stride
sin_token_ptr = sin_ptr + batch_idx * sin_batch_stride + seq_offsets[:, None] * sin_seq_stride

half_rope_dim_offsets = tl.arange(0, half_rope_dim)
half_rope_dim_mask = half_rope_dim_offsets < half_rope_dim

cos_block_2d = tl.load(
cos_token_ptr + half_rope_dim_offsets[None, :],
mask=seq_mask[:, None] & half_rope_dim_mask[None, :],
other=0,
cos_offsets = (
cos_ptr
+ batch_idx * cos_batch_stride
+ seq_offsets[:, None, None] * cos_seq_stride
+ half_rope_dim_offsets[None, None, :]
)
sin_block_2d = tl.load(
sin_token_ptr + half_rope_dim_offsets[None, :],
mask=seq_mask[:, None] & half_rope_dim_mask[None, :],
other=0,
sin_offsets = (
sin_ptr
+ batch_idx * sin_batch_stride
+ seq_offsets[:, None, None] * sin_seq_stride
+ half_rope_dim_offsets[None, None, :]
)
cos_tile = tl.load(cos_offsets, mask=seq_mask[:, None, None], other=0.0)
sin_tile = tl.load(sin_offsets, mask=seq_mask[:, None, None], other=0.0)

head_q_offsets = tl.arange(0, n_qh)
head_k_offsets = tl.arange(0, n_kh)

cos_tile = tl.reshape(cos_block_2d, (TOKEN_BLOCK_SIZE, 1, half_rope_dim), can_reorder=True)
sin_tile = tl.reshape(sin_block_2d, (TOKEN_BLOCK_SIZE, 1, half_rope_dim), can_reorder=True)

if ALIGNED:
rope_dim_offsets = tl.arange(0, rope_dim)
rope_dim_mask = rope_dim_offsets < rope_dim

q_offsets = (
batch_idx * q_batch_stride
Expand All @@ -254,11 +310,18 @@ def _rope_inplace_kernel(
+ nope_dim
+ rope_dim_offsets[None, None, :]
)
q_mask = seq_mask[:, None, None] & (head_q_offsets[None, :, None] < n_qh) & rope_dim_mask[None, None, :]

q_tile = tl.load(q_ptr + q_offsets, mask=q_mask, other=0.0).to(sin_block_2d.dtype)
q_tile = _compute_rope(q_tile, sin_tile, cos_tile, n_qh, half_rope_dim, TOKEN_BLOCK_SIZE, INVERSE)
tl.store(q_ptr + q_offsets, q_tile, mask=q_mask)
q_tile = tl.load(q_ptr + q_offsets, mask=seq_mask[:, None, None], other=0.0).to(tl.float32)
q_tile = _compute_rope(
q_tile,
sin_tile,
cos_tile,
n_qh,
half_rope_dim,
TOKEN_BLOCK_SIZE,
INVERSE,
).to(q_ptr.dtype.element_ty)
tl.store(q_ptr + q_offsets, q_tile, mask=seq_mask[:, None, None])

k_offsets = (
batch_idx * k_batch_stride
Expand All @@ -267,11 +330,18 @@ def _rope_inplace_kernel(
+ nope_dim
+ rope_dim_offsets[None, None, :]
)
k_mask = seq_mask[:, None, None] & (head_k_offsets[None, :, None] < n_kh) & rope_dim_mask[None, None, :]

k_tile = tl.load(k_ptr + k_offsets, mask=k_mask, other=0).to(sin_block_2d.dtype)
k_tile = _compute_rope(k_tile, sin_tile, cos_tile, n_kh, half_rope_dim, TOKEN_BLOCK_SIZE, INVERSE)
tl.store(k_ptr + k_offsets, k_tile, mask=k_mask)
k_tile = tl.load(k_ptr + k_offsets, mask=seq_mask[:, None, None], other=0.0).to(tl.float32)
k_tile = _compute_rope(
k_tile,
sin_tile,
cos_tile,
n_kh,
half_rope_dim,
TOKEN_BLOCK_SIZE,
INVERSE,
).to(k_ptr.dtype.element_ty)
tl.store(k_ptr + k_offsets, k_tile, mask=seq_mask[:, None, None])
else:
q_offsets_half1 = (
batch_idx * q_batch_stride
Expand All @@ -288,15 +358,13 @@ def _rope_inplace_kernel(
+ half_rope_dim
+ half_rope_dim_offsets[None, None, :]
)
q_half_mask = (
seq_mask[:, None, None] & (head_q_offsets[None, :, None] < n_qh) & half_rope_dim_mask[None, None, :]
)

q_tile_1 = tl.load(q_ptr + q_offsets_half1, mask=q_half_mask, other=0.0).to(sin_block_2d.dtype)
q_tile_2 = tl.load(q_ptr + q_offsets_half2, mask=q_half_mask, other=0.0).to(sin_block_2d.dtype)
q_tile_1 = tl.load(q_ptr + q_offsets_half1, mask=seq_mask[:, None, None], other=0.0).to(tl.float32)
q_tile_2 = tl.load(q_ptr + q_offsets_half2, mask=seq_mask[:, None, None], other=0.0).to(tl.float32)
new_q_1, new_q_2 = _compute_rope_separated(q_tile_1, q_tile_2, sin_tile, cos_tile, INVERSE)
tl.store(q_ptr + q_offsets_half1, new_q_1, mask=q_half_mask)
tl.store(q_ptr + q_offsets_half2, new_q_2, mask=q_half_mask)
new_q_1 = new_q_1.to(q_ptr.dtype.element_ty)
new_q_2 = new_q_2.to(q_ptr.dtype.element_ty)
tl.store(q_ptr + q_offsets_half1, new_q_1, mask=seq_mask[:, None, None])
tl.store(q_ptr + q_offsets_half2, new_q_2, mask=seq_mask[:, None, None])

k_offsets_half1 = (
batch_idx * k_batch_stride
Expand All @@ -313,15 +381,13 @@ def _rope_inplace_kernel(
+ half_rope_dim
+ half_rope_dim_offsets[None, None, :]
)
k_half_mask = (
seq_mask[:, None, None] & (head_k_offsets[None, :, None] < n_kh) & half_rope_dim_mask[None, None, :]
)

k_tile_1 = tl.load(k_ptr + k_offsets_half1, mask=k_half_mask, other=0.0).to(sin_block_2d.dtype)
k_tile_2 = tl.load(k_ptr + k_offsets_half2, mask=k_half_mask, other=0.0).to(sin_block_2d.dtype)
k_tile_1 = tl.load(k_ptr + k_offsets_half1, mask=seq_mask[:, None, None], other=0.0).to(tl.float32)
k_tile_2 = tl.load(k_ptr + k_offsets_half2, mask=seq_mask[:, None, None], other=0.0).to(tl.float32)
new_k_1, new_k_2 = _compute_rope_separated(k_tile_1, k_tile_2, sin_tile, cos_tile, INVERSE)
tl.store(k_ptr + k_offsets_half1, new_k_1, mask=k_half_mask)
tl.store(k_ptr + k_offsets_half2, new_k_2, mask=k_half_mask)
new_k_1 = new_k_1.to(k_ptr.dtype.element_ty)
new_k_2 = new_k_2.to(k_ptr.dtype.element_ty)
tl.store(k_ptr + k_offsets_half1, new_k_1, mask=seq_mask[:, None, None])
tl.store(k_ptr + k_offsets_half2, new_k_2, mask=seq_mask[:, None, None])


def _normalize_to_bsnd(
Expand Down Expand Up @@ -442,11 +508,13 @@ def rope_fwd_impl(
nope_dim = head_dim - rope_dim
half_rope_dim = rope_dim // 2

is_aligned = _is_half_rope_dim_aligned(half_rope_dim)
token_block_size = _get_token_block_size(n_q_head, n_kv_head)
num_seq_blocks = (seq_len + token_block_size - 1) // token_block_size
is_aligned = _is_half_rope_dim_aligned(half_rope_dim, q.element_size())
token_block_size = _get_token_block_size_opt(n_q_head, n_kv_head, rope_dim, 4)
# num_seq_blocks = (seq_len + token_block_size - 1) // token_block_size

num_seq_blocks = triton.cdiv(seq_len, token_block_size)
Comment thread
YangLong114514 marked this conversation as resolved.
num_programs = get_num_cores()
num_programs = min(num_programs, batch_size * num_seq_blocks)
grid = (num_programs,)

cos = cos.contiguous()
Expand Down