Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 157 additions & 56 deletions mojo_opset/backends/ttx/kernels/npu/gelu.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
import torch
import triton
import triton.language as tl

import triton.language.extra.cann.libdevice as libdevice
from .utils import libentry
Comment thread
YangLong114514 marked this conversation as resolved.
Outdated

from mojo_opset.backends.ttx.kernels.npu.utils import VEC_ALIGN_BYTES
from mojo_opset.backends.ttx.kernels.utils import align

"""
This file contains the implementation of GELU (Gaussian Error Linear Unit) for NPU.

GELU formula: gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))

Based on Liger Kernel implementation:
https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/geglu.py
Expand All @@ -19,30 +16,29 @@
"""


COL_BLOCKING_THRESHOLD = 4096
# Upper bound applied to the column tile width by _rowwise_block_size_n.
MAX_BLOCK_SIZE_N = 1024
# Largest BLOCK_SIZE_M offered to the autotuner below; the no-mask
# eligibility checks require n_rows to be a multiple of this value.
GELU_TANH_MAX_BLOCK_SIZE_M = 8


# Autotune candidates shared by the GELU kernels: one config per row tile height.
GELU_TANH_BLOCK_SIZE_M_CONFIGS = [
    triton.Config({"BLOCK_SIZE_M": block_size_m}) for block_size_m in (1, 2, 4, 8)
]
Comment thread
YangLong114514 marked this conversation as resolved.
Outdated


@triton.jit
def gelu_tanh_approx(x):
    """GELU activation using the tanh approximation.

    Computes ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))``
    elementwise.

    Args:
        x: Input block. The forward kernels in this file cast their loaded
            chunk to float32 before calling this helper.

    Returns:
        The activated block, computed elementwise from ``x``.
    """
    sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
    x_cubed = x * x * x
    tanh_arg = sqrt_2_over_pi * (x + 0.044715 * x_cubed)
    # Use the same libdevice tanh as the backward kernel so that forward and
    # backward evaluate the approximation with one implementation.
    return 0.5 * x * (1 + libdevice.tanh(tanh_arg))


@triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE_M": 1}),
triton.Config({"BLOCK_SIZE_M": 2}),
triton.Config({"BLOCK_SIZE_M": 4}),
triton.Config({"BLOCK_SIZE_M": 8}),
triton.Config({"BLOCK_SIZE_M": 12}),
triton.Config({"BLOCK_SIZE_M": 16}),
triton.Config({"BLOCK_SIZE_M": 20}),
triton.Config({"BLOCK_SIZE_M": 24}),
triton.Config({"BLOCK_SIZE_M": 32}),
],
configs=GELU_TANH_BLOCK_SIZE_M_CONFIGS,
key=["n_rows", "n_cols"],
)
@libentry()
Expand Down Expand Up @@ -85,17 +81,81 @@ def _gelu_fwd_kernel(


@triton.autotune(
    # Only the shared config list is used: its largest BLOCK_SIZE_M is what
    # the host-side no-mask eligibility check validates n_rows against.
    configs=GELU_TANH_BLOCK_SIZE_M_CONFIGS,
    key=["n_rows", "n_cols"],
)
@libentry()
@triton.jit
def _gelu_fwd_nomask_kernel(
    x,
    y,
    stride_row,
    n_rows,
    n_cols,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
):
    """Forward tanh-GELU over a 2D buffer using unmasked loads and stores.

    Every load/store covers a full ``BLOCK_SIZE_M x BLOCK_SIZE_N`` tile
    without a bounds mask, so the caller must only launch this kernel when
    ``n_rows`` is a multiple of every autotuned ``BLOCK_SIZE_M`` and
    ``n_cols`` is a multiple of ``BLOCK_SIZE_N``.

    Args:
        x: Pointer to the input.
        y: Pointer to the output; addressed with the same ``stride_row`` as ``x``.
        stride_row: Element offset between consecutive rows.
        n_rows: Number of rows to process.
        n_cols: Number of columns to process.
        BLOCK_SIZE_N: Columns per tile (compile-time constant).
        BLOCK_SIZE_M: Rows per tile (compile-time constant, autotuned).
    """
    pid = tl.program_id(axis=0)
    grid_size = tl.num_programs(axis=0)

    num_row_tasks = (n_rows + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M

    # Each program handles row tiles pid, pid + grid_size, pid + 2 * grid_size, ...
    for row_task_id in range(pid, num_row_tasks, grid_size):
        block_start_row = row_task_id * BLOCK_SIZE_M
        rows_off = block_start_row + tl.arange(0, BLOCK_SIZE_M)
        # The row offset does not depend on the column tile; compute it once.
        row_term = rows_off[:, None] * stride_row

        for col_offset in range(0, n_cols, BLOCK_SIZE_N):
            cols_off = col_offset + tl.arange(0, BLOCK_SIZE_N)

            x_ptrs = x + row_term + cols_off[None, :]
            y_ptrs = y + row_term + cols_off[None, :]

            x_chunk = tl.load(x_ptrs)
            # Evaluate in float32, then cast back to the input element type.
            x_f32 = x_chunk.to(tl.float32)
            y_f32 = gelu_tanh_approx(x_f32)
            y_chunk = y_f32.to(x_chunk.dtype)

            tl.store(y_ptrs, y_chunk)


@triton.autotune(
    configs=GELU_TANH_BLOCK_SIZE_M_CONFIGS,
    key=["n_rows", "n_cols"],
)
@libentry()
@triton.jit
def _gelu_fwd_nomask_single_kernel(
    x,
    y,
    stride_row,
    n_rows,
    n_cols,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
):
    """Forward tanh-GELU where one unmasked column tile spans the whole row.

    Columns ``0 .. BLOCK_SIZE_N - 1`` of each row tile are loaded and stored
    without a bounds mask; there is no loop over columns.

    Args:
        x: Pointer to the input.
        y: Pointer to the output; addressed with the same ``stride_row`` as ``x``.
        stride_row: Element offset between consecutive rows.
        n_rows: Number of rows to process.
        n_cols: Number of columns; not read in the body, but part of the
            autotune key.
        BLOCK_SIZE_N: Columns per tile (compile-time constant).
        BLOCK_SIZE_M: Rows per tile (compile-time constant, autotuned).
    """
    program_idx = tl.program_id(axis=0)
    program_count = tl.num_programs(axis=0)

    tile_count = (n_rows + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
    col_idx = tl.arange(0, BLOCK_SIZE_N)

    # Each program handles row tiles program_idx, program_idx + program_count, ...
    for tile_idx in range(program_idx, tile_count, program_count):
        first_row = tile_idx * BLOCK_SIZE_M
        row_idx = first_row + tl.arange(0, BLOCK_SIZE_M)
        row_term = row_idx[:, None] * stride_row

        src_ptrs = x + row_term + col_idx[None, :]
        dst_ptrs = y + row_term + col_idx[None, :]

        x_tile = tl.load(src_ptrs)
        # Evaluate in float32, then cast back to the input element type.
        y_tile = gelu_tanh_approx(x_tile.to(tl.float32)).to(x_tile.dtype)

        tl.store(dst_ptrs, y_tile)


@triton.autotune(
configs=GELU_TANH_BLOCK_SIZE_M_CONFIGS,
key=["n_rows", "n_cols"],
restore_value=["dy", "dx"],
)
Expand Down Expand Up @@ -137,18 +197,50 @@ def _gelu_bwd_kernel(
sqrt_2_over_pi = 0.7978845608028654
x_cubed = x_f32 * x_f32 * x_f32
tanh_arg = sqrt_2_over_pi * (x_f32 + 0.044715 * x_cubed)
tanh_result = tl.tanh(tanh_arg)
tanh_result = libdevice.tanh(tanh_arg)

term1 = 0.5 * (1 + tanh_result)
tanh_sq = tanh_result * tanh_result
term2 = 0.5 * x_f32 * (1 - tanh_sq) * (sqrt_2_over_pi * (1 + 3 * 0.044715 * x_f32 * x_f32))
term2 = (
0.5
* x_f32
* (1 - tanh_sq)
* (sqrt_2_over_pi * (1 + 3 * 0.044715 * x_f32 * x_f32))
)
dgelu_dx = term1 + term2

dx_chunk = dy_chunk * dgelu_dx.to(dy_chunk.dtype)

tl.store(dx_ptrs, dx_chunk, mask=block_mask)


def _rowwise_block_size_n(n_cols):
    """Return the column tile width: next power of two of ``n_cols``, capped at ``MAX_BLOCK_SIZE_N``."""
    padded_cols = triton.next_power_of_2(n_cols)
    if padded_cols > MAX_BLOCK_SIZE_N:
        return MAX_BLOCK_SIZE_N
    return padded_cols


def _num_vectorcores():
    """Return the ``num_vectorcore`` entry of the active driver's "npu" device properties."""
    driver_utils = triton.runtime.driver.active.utils
    npu_properties = driver_utils.get_device_properties("npu")
    return npu_properties["num_vectorcore"]


def _rowwise_grid(n_rows, block_size_m):
    """Return a 1-tuple launch grid for row-tiled kernels.

    The program count is the number of row tiles of height ``block_size_m``
    needed to cover ``n_rows``, limited to the value reported by
    ``_num_vectorcores()``.
    """
    row_tiles = (n_rows + block_size_m - 1) // block_size_m
    core_count = _num_vectorcores()
    program_count = row_tiles if row_tiles < core_count else core_count
    return (program_count,)
Comment thread
YangLong114514 marked this conversation as resolved.
Outdated


def _rowwise_autotune_grid(n_rows):
    """Return a grid callable that sizes the launch from the chosen ``BLOCK_SIZE_M``.

    The returned function takes the kernel meta-parameter mapping and
    forwards its ``"BLOCK_SIZE_M"`` entry to ``_rowwise_grid``.
    """

    def grid(META):
        return _rowwise_grid(n_rows, META["BLOCK_SIZE_M"])

    return grid


def _can_use_nomask_kernel(n_rows, n_cols, block_size_n):
    """Return True when the column-looping no-mask kernel is applicable.

    Requires ``n_cols`` to be a multiple of ``block_size_n`` and ``n_rows``
    to be a multiple of ``GELU_TANH_MAX_BLOCK_SIZE_M``.
    """
    if n_cols % block_size_n != 0:
        return False
    return n_rows % GELU_TANH_MAX_BLOCK_SIZE_M == 0


def _can_use_nomask_single_kernel(n_rows, n_cols, block_size_n):
    """Return True when the single-tile no-mask kernel is applicable.

    Requires ``n_cols`` to equal ``block_size_n`` exactly and ``n_rows`` to
    be a multiple of ``GELU_TANH_MAX_BLOCK_SIZE_M``.
    """
    if n_cols != block_size_n:
        return False
    return n_rows % GELU_TANH_MAX_BLOCK_SIZE_M == 0


def gelu_fwd_impl(x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for GELU.
Expand All @@ -167,22 +259,36 @@ def gelu_fwd_impl(x: torch.Tensor) -> torch.Tensor:

y = torch.empty_like(x_2d)

if n_cols > COL_BLOCKING_THRESHOLD:
BLOCK_SIZE_N = 2048
block_size_n = _rowwise_block_size_n(n_cols)
grid = _rowwise_autotune_grid(n_rows)

if _can_use_nomask_single_kernel(n_rows, n_cols, block_size_n):
_gelu_fwd_nomask_single_kernel[grid](
x_2d,
y,
x_2d.stride(0),
n_rows,
n_cols,
BLOCK_SIZE_N=block_size_n,
)
elif _can_use_nomask_kernel(n_rows, n_cols, block_size_n):
_gelu_fwd_nomask_kernel[grid](
x_2d,
y,
x_2d.stride(0),
n_rows,
n_cols,
BLOCK_SIZE_N=block_size_n,
)
else:
BLOCK_SIZE_N = align(x, n_cols, VEC_ALIGN_BYTES)

num_programs = triton.runtime.driver.active.utils.get_device_properties("npu")["num_vectorcore"]
grid = (num_programs,)

_gelu_fwd_kernel[grid](
x_2d,
y,
x_2d.stride(0),
n_rows,
n_cols,
BLOCK_SIZE_N=BLOCK_SIZE_N,
)
_gelu_fwd_kernel[grid](
x_2d,
y,
x_2d.stride(0),
n_rows,
n_cols,
BLOCK_SIZE_N=block_size_n,
)

return y.reshape(*ori_shape)

Expand Down Expand Up @@ -210,13 +316,8 @@ def gelu_bwd_impl(

dx = torch.empty_like(x_2d)

if n_cols > COL_BLOCKING_THRESHOLD:
BLOCK_SIZE_N = 2048
else:
BLOCK_SIZE_N = align(dy, n_cols, VEC_ALIGN_BYTES)

num_programs = triton.runtime.driver.active.utils.get_device_properties("npu")["num_vectorcore"]
grid = (num_programs,)
block_size_n = _rowwise_block_size_n(n_cols)
grid = _rowwise_autotune_grid(n_rows)

_gelu_bwd_kernel[grid](
dy_2d,
Expand All @@ -225,7 +326,7 @@ def gelu_bwd_impl(
dy_2d.stride(0),
n_rows,
n_cols,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_N=block_size_n,
)

return dx.reshape(*ori_shape)