Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 155 additions & 57 deletions mojo_opset/backends/ttx/kernels/npu/gelu.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
import torch
import triton
import triton.language as tl

from .utils import libentry

from mojo_opset.backends.ttx.kernels.npu.utils import VEC_ALIGN_BYTES
from mojo_opset.backends.ttx.kernels.utils import align
import triton.language.extra.cann.libdevice as libdevice
from .utils import get_num_cores, libentry

"""
This file contains the implementation of GELU (Gaussian Error Linear Unit) for NPU.

GELU formula: gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))
GELU formula: gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))

Based on Liger Kernel implementation:
https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/geglu.py
Expand All @@ -19,30 +16,32 @@
"""


# Column count above which the older launch path switched to a fixed
# 2048-wide column block.
COL_BLOCKING_THRESHOLD = 4096
# Upper bound on the column block width picked by _rowwise_block_size_n.
MAX_BLOCK_SIZE_N = 1024


# Candidate row-block heights handed to triton.autotune for the GELU kernels.
GELU_TANH_BLOCK_SIZE_M_CONFIGS = [
    triton.Config({"BLOCK_SIZE_M": rows_per_block})
    for rows_per_block in (1, 2, 4, 8)
]

# Largest candidate row-block height; the unmasked kernels are only selected
# when n_rows is a multiple of this value.
GELU_TANH_MAX_BLOCK_SIZE_M = max(
    cfg.kwargs["BLOCK_SIZE_M"] for cfg in GELU_TANH_BLOCK_SIZE_M_CONFIGS
)


@triton.jit
def gelu_tanh_approx(x):
    """GELU activation using the tanh approximation.

    Evaluates 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
    element-wise on ``x`` and returns the result.
    """
    sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
    x_cubed = x * x * x
    tanh_arg = sqrt_2_over_pi * (x + 0.044715 * x_cubed)
    # Single return path through the libdevice tanh; the tl.tanh variant of
    # this line was the pre-change side of the diff.
    return 0.5 * x * (1 + libdevice.tanh(tanh_arg))


@triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE_M": 1}),
triton.Config({"BLOCK_SIZE_M": 2}),
triton.Config({"BLOCK_SIZE_M": 4}),
triton.Config({"BLOCK_SIZE_M": 8}),
triton.Config({"BLOCK_SIZE_M": 12}),
triton.Config({"BLOCK_SIZE_M": 16}),
triton.Config({"BLOCK_SIZE_M": 20}),
triton.Config({"BLOCK_SIZE_M": 24}),
triton.Config({"BLOCK_SIZE_M": 32}),
],
configs=GELU_TANH_BLOCK_SIZE_M_CONFIGS,
key=["n_rows", "n_cols"],
)
@libentry()
Expand Down Expand Up @@ -85,17 +84,81 @@ def _gelu_fwd_kernel(


@triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE_M": 1}),
triton.Config({"BLOCK_SIZE_M": 2}),
triton.Config({"BLOCK_SIZE_M": 4}),
triton.Config({"BLOCK_SIZE_M": 8}),
triton.Config({"BLOCK_SIZE_M": 12}),
triton.Config({"BLOCK_SIZE_M": 16}),
triton.Config({"BLOCK_SIZE_M": 20}),
triton.Config({"BLOCK_SIZE_M": 24}),
triton.Config({"BLOCK_SIZE_M": 32}),
],
configs=GELU_TANH_BLOCK_SIZE_M_CONFIGS,
key=["n_rows", "n_cols"],
)
@libentry()
@triton.jit
def _gelu_fwd_nomask_kernel(
    x,
    y,
    stride_row,
    n_rows,
    n_cols,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
):
    """Apply tanh-approximated GELU tile by tile with unmasked loads/stores.

    Rows are split into tiles of ``BLOCK_SIZE_M`` rows; the program with
    index ``p`` handles tiles ``p, p + num_programs, ...``. Within a row tile
    the columns are walked in steps of ``BLOCK_SIZE_N``.

    No mask is applied, so every tile is read and written in full. The host
    side only picks this kernel when ``n_cols`` is a multiple of
    ``BLOCK_SIZE_N`` and ``n_rows`` is a multiple of the largest autotuned
    ``BLOCK_SIZE_M``.

    ``stride_row`` is used as the row stride for both ``x`` and ``y``.
    """
    program_idx = tl.program_id(axis=0)
    program_count = tl.num_programs(axis=0)

    # Ceiling division: number of row tiles to distribute over the programs.
    row_tile_count = (n_rows + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M

    for row_tile in range(program_idx, row_tile_count, program_count):
        row_idx = row_tile * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)

        for col_start in range(0, n_cols, BLOCK_SIZE_N):
            col_idx = col_start + tl.arange(0, BLOCK_SIZE_N)

            # One (BLOCK_SIZE_M, BLOCK_SIZE_N) offset tile shared by the
            # input and the output pointers.
            tile_off = row_idx[:, None] * stride_row + col_idx[None, :]

            x_tile = tl.load(x + tile_off)
            # Compute in float32, then cast back to the input element type.
            y_tile = gelu_tanh_approx(x_tile.to(tl.float32)).to(x_tile.dtype)

            tl.store(y + tile_off, y_tile)


@triton.autotune(
    configs=GELU_TANH_BLOCK_SIZE_M_CONFIGS,
    key=["n_rows", "n_cols"],
)
@libentry()
@triton.jit
def _gelu_fwd_nomask_single_kernel(
    x,
    y,
    stride_row,
    n_rows,
    n_cols,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
):
    """Apply tanh-approximated GELU when one column block spans a whole row.

    Rows are split into tiles of ``BLOCK_SIZE_M`` rows; the program with
    index ``p`` handles tiles ``p, p + num_programs, ...``. Each tile covers
    columns ``0 .. BLOCK_SIZE_N - 1`` in a single unmasked load/store, so
    there is no column loop.

    The host side only picks this kernel when ``n_cols == BLOCK_SIZE_N`` and
    ``n_rows`` is a multiple of the largest autotuned ``BLOCK_SIZE_M``.
    ``n_cols`` is not read in the body; it is part of the autotune key.

    ``stride_row`` is used as the row stride for both ``x`` and ``y``.
    """
    program_idx = tl.program_id(axis=0)
    program_count = tl.num_programs(axis=0)

    # Ceiling division: number of row tiles to distribute over the programs.
    row_tile_count = (n_rows + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
    # Column indices are the same for every tile, so build them once.
    col_idx = tl.arange(0, BLOCK_SIZE_N)

    for row_tile in range(program_idx, row_tile_count, program_count):
        row_idx = row_tile * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)

        # One (BLOCK_SIZE_M, BLOCK_SIZE_N) offset tile shared by the input
        # and the output pointers.
        tile_off = row_idx[:, None] * stride_row + col_idx[None, :]

        x_tile = tl.load(x + tile_off)
        # Compute in float32, then cast back to the input element type.
        y_tile = gelu_tanh_approx(x_tile.to(tl.float32)).to(x_tile.dtype)

        tl.store(y + tile_off, y_tile)


@triton.autotune(
configs=GELU_TANH_BLOCK_SIZE_M_CONFIGS,
key=["n_rows", "n_cols"],
restore_value=["dy", "dx"],
)
Expand Down Expand Up @@ -137,18 +200,44 @@ def _gelu_bwd_kernel(
sqrt_2_over_pi = 0.7978845608028654
x_cubed = x_f32 * x_f32 * x_f32
tanh_arg = sqrt_2_over_pi * (x_f32 + 0.044715 * x_cubed)
tanh_result = tl.tanh(tanh_arg)
tanh_result = libdevice.tanh(tanh_arg)

term1 = 0.5 * (1 + tanh_result)
tanh_sq = tanh_result * tanh_result
term2 = 0.5 * x_f32 * (1 - tanh_sq) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * x_f32 * x_f32))
term2 = (
0.5
* x_f32
* (1 - tanh_sq)
* (sqrt_2_over_pi * (1 + 3 * 0.044715 * x_f32 * x_f32))
)
dgelu_dx = term1 + term2

dx_chunk = dy_chunk * dgelu_dx.to(dy_chunk.dtype)

tl.store(dx_ptrs, dx_chunk, mask=block_mask)


def _rowwise_block_size_n(n_cols):
    """Pick the column block width for the row-wise GELU kernels.

    Args:
        n_cols: Number of columns in the flattened 2-D input.

    Returns:
        ``n_cols`` rounded up to a power of two, capped at
        ``MAX_BLOCK_SIZE_N`` and never smaller than 1.
    """
    capped = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE_N)
    # An empty column dimension rounds to a width of 0, which would make the
    # ``n_cols % block_size_n`` eligibility check divide by zero and would
    # hand a zero-sized block to the kernels. Clamp to a valid width instead.
    return max(1, capped)


def _rowwise_grid(n_rows, block_size_m):
    """Return the 1-D launch grid for a given row-block height.

    Args:
        n_rows: Number of rows in the flattened 2-D input.
        block_size_m: Rows handled per row tile.

    Returns:
        A one-element tuple holding the program count: the smaller of the
        value reported by ``get_num_cores("vector")`` and the number of row
        tiles, but at least 1.
    """
    # Ceiling division: how many row tiles there are to hand out.
    row_tiles = (n_rows + block_size_m - 1) // block_size_m
    programs = min(get_num_cores("vector"), row_tiles)
    if programs < 1:
        # Keep the grid non-empty even when there are no row tiles.
        programs = 1
    return (programs,)


def _rowwise_autotune_grid(n_rows):
    """Build a grid callable that sizes the launch from the autotuned config.

    Args:
        n_rows: Number of rows in the flattened 2-D input.

    Returns:
        A callable taking the launch meta-parameters mapping and returning
        ``_rowwise_grid(n_rows, META["BLOCK_SIZE_M"])``.
    """

    def grid(META):
        # BLOCK_SIZE_M is only known once a config has been selected, so the
        # grid is computed lazily from the meta-parameters.
        return _rowwise_grid(n_rows, META["BLOCK_SIZE_M"])

    return grid


def _can_use_nomask_kernel(n_rows, n_cols, block_size_n):
    """Tell whether the unmasked, column-looping forward kernel is safe.

    Args:
        n_rows: Number of rows in the flattened 2-D input.
        n_cols: Number of columns in the flattened 2-D input.
        block_size_n: Column block width chosen for the launch.

    Returns:
        True when ``n_cols`` is a multiple of ``block_size_n`` and ``n_rows``
        is a multiple of ``GELU_TANH_MAX_BLOCK_SIZE_M``; False otherwise.
    """
    if n_cols % block_size_n != 0:
        # A partial trailing column block would need a mask.
        return False
    # Checked against the largest candidate row-block height.
    return n_rows % GELU_TANH_MAX_BLOCK_SIZE_M == 0


def _can_use_nomask_single_kernel(n_rows, n_cols, block_size_n):
    """Tell whether the unmasked, single-column-block forward kernel is safe.

    Args:
        n_rows: Number of rows in the flattened 2-D input.
        n_cols: Number of columns in the flattened 2-D input.
        block_size_n: Column block width chosen for the launch.

    Returns:
        True when ``n_cols`` equals ``block_size_n`` and ``n_rows`` is a
        multiple of ``GELU_TANH_MAX_BLOCK_SIZE_M``; False otherwise.
    """
    if n_cols != block_size_n:
        # The row does not fit exactly into one column block.
        return False
    # Checked against the largest candidate row-block height.
    return n_rows % GELU_TANH_MAX_BLOCK_SIZE_M == 0


def gelu_fwd_impl(x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for GELU.
Expand All @@ -167,22 +256,36 @@ def gelu_fwd_impl(x: torch.Tensor) -> torch.Tensor:

y = torch.empty_like(x_2d)

if n_cols > COL_BLOCKING_THRESHOLD:
BLOCK_SIZE_N = 2048
block_size_n = _rowwise_block_size_n(n_cols)
grid = _rowwise_autotune_grid(n_rows)

if _can_use_nomask_single_kernel(n_rows, n_cols, block_size_n):
_gelu_fwd_nomask_single_kernel[grid](
x_2d,
y,
x_2d.stride(0),
n_rows,
n_cols,
BLOCK_SIZE_N=block_size_n,
)
elif _can_use_nomask_kernel(n_rows, n_cols, block_size_n):
_gelu_fwd_nomask_kernel[grid](
x_2d,
y,
x_2d.stride(0),
n_rows,
n_cols,
BLOCK_SIZE_N=block_size_n,
)
else:
BLOCK_SIZE_N = align(x, n_cols, VEC_ALIGN_BYTES)

num_programs = triton.runtime.driver.active.utils.get_device_properties("npu")["num_vectorcore"]
grid = (num_programs,)

_gelu_fwd_kernel[grid](
x_2d,
y,
x_2d.stride(0),
n_rows,
n_cols,
BLOCK_SIZE_N=BLOCK_SIZE_N,
)
_gelu_fwd_kernel[grid](
x_2d,
y,
x_2d.stride(0),
n_rows,
n_cols,
BLOCK_SIZE_N=block_size_n,
)

return y.reshape(*ori_shape)

Expand Down Expand Up @@ -210,13 +313,8 @@ def gelu_bwd_impl(

dx = torch.empty_like(x_2d)

if n_cols > COL_BLOCKING_THRESHOLD:
BLOCK_SIZE_N = 2048
else:
BLOCK_SIZE_N = align(dy, n_cols, VEC_ALIGN_BYTES)

num_programs = triton.runtime.driver.active.utils.get_device_properties("npu")["num_vectorcore"]
grid = (num_programs,)
block_size_n = _rowwise_block_size_n(n_cols)
grid = _rowwise_autotune_grid(n_rows)

_gelu_bwd_kernel[grid](
dy_2d,
Expand All @@ -225,7 +323,7 @@ def gelu_bwd_impl(
dy_2d.stride(0),
n_rows,
n_cols,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_N=block_size_n,
)

return dx.reshape(*ori_shape)