Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 144 additions & 29 deletions mojo_opset/backends/ttx/kernels/npu/silu.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
import torch
import triton
import triton.language as tl

from .utils import libentry

from mojo_opset.backends.ttx.kernels.npu.utils import VEC_ALIGN_BYTES
from mojo_opset.backends.ttx.kernels.utils import align
from .utils import get_num_cores, libentry

"""
This file contains the implementation of SiLU (Sigmoid Linear Unit) for NPU.
Expand All @@ -18,7 +14,7 @@
"""


COL_BLOCKING_THRESHOLD = 4096
MAX_BLOCK_SIZE_N = 2048


@triton.jit
Expand Down Expand Up @@ -80,6 +76,92 @@ def _silu_fwd_kernel(
tl.store(y_ptrs, y_chunk, mask=block_mask)


@triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE_M": 1}),
triton.Config({"BLOCK_SIZE_M": 2}),
triton.Config({"BLOCK_SIZE_M": 4}),
triton.Config({"BLOCK_SIZE_M": 8}),
triton.Config({"BLOCK_SIZE_M": 16}),
],
key=["n_rows", "n_cols"],
)
@libentry()
@triton.jit
def _silu_fwd_nomask_kernel(
x,
y,
stride_row,
n_rows,
n_cols,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
):
pid = tl.program_id(axis=0)
grid_size = tl.num_programs(axis=0)

num_row_tasks = (n_rows + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M

for row_task_id in range(pid, num_row_tasks, grid_size):
block_start_row = row_task_id * BLOCK_SIZE_M
rows_off = block_start_row + tl.arange(0, BLOCK_SIZE_M)

for col_offset in range(0, n_cols, BLOCK_SIZE_N):
cols_off = col_offset + tl.arange(0, BLOCK_SIZE_N)

x_ptrs = x + rows_off[:, None] * stride_row + cols_off[None, :]
y_ptrs = y + rows_off[:, None] * stride_row + cols_off[None, :]

x_chunk = tl.load(x_ptrs)
x_f32 = x_chunk.to(tl.float32)
y_f32 = silu_activation(x_f32)
y_chunk = y_f32.to(x_chunk.dtype)

tl.store(y_ptrs, y_chunk)


@triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE_M": 1}),
triton.Config({"BLOCK_SIZE_M": 2}),
triton.Config({"BLOCK_SIZE_M": 4}),
triton.Config({"BLOCK_SIZE_M": 8}),
triton.Config({"BLOCK_SIZE_M": 16}),
],
key=["n_rows", "n_cols"],
)
@libentry()
@triton.jit
def _silu_fwd_nomask_single_kernel(
x,
y,
stride_row,
n_rows,
n_cols,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
):
pid = tl.program_id(axis=0)
grid_size = tl.num_programs(axis=0)

num_row_tasks = (n_rows + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
cols_off = tl.arange(0, BLOCK_SIZE_N)

for row_task_id in range(pid, num_row_tasks, grid_size):
block_start_row = row_task_id * BLOCK_SIZE_M
rows_off = block_start_row + tl.arange(0, BLOCK_SIZE_M)

x_ptrs = x + rows_off[:, None] * stride_row + cols_off[None, :]
y_ptrs = y + rows_off[:, None] * stride_row + cols_off[None, :]

x_chunk = tl.load(x_ptrs)
x_f32 = x_chunk.to(tl.float32)
y_f32 = silu_activation(x_f32)
y_chunk = y_f32.to(x_chunk.dtype)

tl.store(y_ptrs, y_chunk)


@triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE_M": 1}),
Expand Down Expand Up @@ -139,6 +221,30 @@ def _silu_bwd_kernel(
tl.store(dx_ptrs, dx_chunk, mask=block_mask)


def _rowwise_block_size_n(n_cols):
return min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE_N)


def _rowwise_grid(n_rows, block_size_m):
num_row_tasks = (n_rows + block_size_m - 1) // block_size_m
return (max(1, min(get_num_cores("vector"), num_row_tasks)),)


def _rowwise_autotune_grid(n_rows):
return lambda META: _rowwise_grid(n_rows, META["BLOCK_SIZE_M"])


SILU_NOMASK_MAX_BLOCK_SIZE_M = 16


def _can_use_nomask_kernel(n_rows, n_cols, block_size_n):
return n_cols % block_size_n == 0 and n_rows % SILU_NOMASK_MAX_BLOCK_SIZE_M == 0


def _can_use_nomask_single_kernel(n_rows, n_cols, block_size_n):
return n_cols == block_size_n and n_rows % SILU_NOMASK_MAX_BLOCK_SIZE_M == 0


def silu_fwd_impl(
x: torch.Tensor,
) -> torch.Tensor:
Expand All @@ -159,22 +265,36 @@ def silu_fwd_impl(

y = torch.empty_like(x_2d)

if n_cols > COL_BLOCKING_THRESHOLD:
BLOCK_SIZE_N = 2048
block_size_n = _rowwise_block_size_n(n_cols)
grid = _rowwise_autotune_grid(n_rows)

if _can_use_nomask_single_kernel(n_rows, n_cols, block_size_n):
_silu_fwd_nomask_single_kernel[grid](
x_2d,
y,
x_2d.stride(0),
n_rows,
n_cols,
BLOCK_SIZE_N=block_size_n,
)
elif _can_use_nomask_kernel(n_rows, n_cols, block_size_n):
_silu_fwd_nomask_kernel[grid](
x_2d,
y,
x_2d.stride(0),
n_rows,
n_cols,
BLOCK_SIZE_N=block_size_n,
)
else:
BLOCK_SIZE_N = align(x, n_cols, VEC_ALIGN_BYTES)

num_programs = triton.runtime.driver.active.utils.get_device_properties("npu")["num_vectorcore"]
grid = (num_programs,)

_silu_fwd_kernel[grid](
x_2d,
y,
x_2d.stride(0),
n_rows,
n_cols,
BLOCK_SIZE_N=BLOCK_SIZE_N,
)
_silu_fwd_kernel[grid](
x_2d,
y,
x_2d.stride(0),
n_rows,
n_cols,
BLOCK_SIZE_N=block_size_n,
)

return y.reshape(*ori_shape)

Expand Down Expand Up @@ -202,13 +322,8 @@ def silu_bwd_impl(

dx = torch.empty_like(x_2d)

if n_cols > COL_BLOCKING_THRESHOLD:
BLOCK_SIZE_N = 2048
else:
BLOCK_SIZE_N = align(dy, n_cols, VEC_ALIGN_BYTES)

num_programs = triton.runtime.driver.active.utils.get_device_properties("npu")["num_vectorcore"]
grid = (num_programs,)
block_size_n = _rowwise_block_size_n(n_cols)
grid = _rowwise_autotune_grid(n_rows)

_silu_bwd_kernel[grid](
dy_2d,
Expand All @@ -217,7 +332,7 @@ def silu_bwd_impl(
dy_2d.stride(0),
n_rows,
n_cols,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_N=block_size_n,
)

return dx.reshape(*ori_shape)