Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 111 additions & 46 deletions src/liger_kernel/ops/rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def _rms_norm_forward_kernel(
eps,
offset,
casting_mode: tl.constexpr, # constexpr so the `if` blocks can be optimized out
elementwise_affine: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
Expand All @@ -75,15 +76,17 @@ def _rms_norm_forward_kernel(

X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
X_row_dtype = X_row.dtype
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
if elementwise_affine:
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)

# On Llama, only rstd is computed on fp32
if casting_mode == _CASTING_MODE_LLAMA:
X_row = X_row.to(tl.float32)

# Gemma computes everything on fp32, and then casts back the output to the original dtype
if casting_mode == _CASTING_MODE_GEMMA:
W_row = W_row.to(tl.float32)
if elementwise_affine:
W_row = W_row.to(tl.float32)
X_row = X_row.to(tl.float32)

if casting_mode == _CASTING_MODE_NONE:
Expand All @@ -104,7 +107,10 @@ def _rms_norm_forward_kernel(
if casting_mode == _CASTING_MODE_LLAMA:
X_row = X_row.to(X_row_dtype)

Y_row = X_row * (offset + W_row)
if elementwise_affine:
Y_row = X_row * (offset + W_row)
else:
Y_row = X_row

if casting_mode == _CASTING_MODE_GEMMA:
Y_row = Y_row.to(X_row_dtype)
Expand Down Expand Up @@ -132,6 +138,7 @@ def _rms_norm_backward_kernel(
offset,
rows_per_program: tl.constexpr,
casting_mode: tl.constexpr,
elementwise_affine: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
Expand All @@ -145,16 +152,18 @@ def _rms_norm_backward_kernel(
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols

dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
if elementwise_affine:
dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

dY_ptr += row_start * dY_row_stride
dX_ptr += row_start * dX_row_stride

X_ptr += row_start * X_row_stride
RSTD_ptr += row_start

W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
W_row = W_row + offset
if elementwise_affine:
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
W_row = W_row + offset

for _ in range(row_start, row_end):
dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0)
Expand All @@ -167,24 +176,34 @@ def _rms_norm_backward_kernel(

# Different backward graphs for different casting modes
if casting_mode == _CASTING_MODE_LLAMA:
m = (dY_row * W_row).to(tl.float32)
if elementwise_affine:
m = (dY_row * W_row).to(tl.float32)
else:
m = dY_row.to(tl.float32)

elif casting_mode == _CASTING_MODE_GEMMA:
dY_row = dY_row.to(tl.float32)
m = dY_row * W_row
if elementwise_affine:
m = dY_row * W_row
else:
m = dY_row
else:
m = dY_row * W_row
if elementwise_affine:
m = dY_row * W_row
else:
m = dY_row

dX_row = rstd_row * m

dX_row += (rstd_row) * (-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row)

# calculate the gradient of W
if casting_mode == _CASTING_MODE_LLAMA:
dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
else:
# here X_row is already in fp32 (see previous if block)
dW_row += dY_row * (X_row * rstd_row)
if elementwise_affine:
# calculate the gradient of W
if casting_mode == _CASTING_MODE_LLAMA:
dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
else:
# here X_row is already in fp32 (see previous if block)
dW_row += dY_row * (X_row * rstd_row)

tl.store(dX_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)

Expand All @@ -193,7 +212,8 @@ def _rms_norm_backward_kernel(
X_ptr += X_row_stride
RSTD_ptr += RSTD_row_stride

tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
if elementwise_affine:
tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)


@triton.jit
Expand All @@ -211,6 +231,7 @@ def _block_rms_norm_forward_kernel(
eps,
offset,
casting_mode: tl.constexpr, # constexpr so the `if` blocks can be optimized out
elementwise_affine: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
BLOCK_ROW: tl.constexpr,
):
Expand All @@ -234,15 +255,17 @@ def _block_rms_norm_forward_kernel(
other=0,
)
X_row_dtype = X_row.dtype
W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)
if elementwise_affine:
W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0)

# On Llama, only rstd is computed on fp32
if casting_mode == _CASTING_MODE_LLAMA:
X_row = X_row.to(tl.float32)

# Gemma computes everything on fp32, and then casts back the output to the original dtype
if casting_mode == _CASTING_MODE_GEMMA:
W_row = W_row.to(tl.float32)
if elementwise_affine:
W_row = W_row.to(tl.float32)
X_row = X_row.to(tl.float32)

if casting_mode == _CASTING_MODE_NONE:
Expand All @@ -263,7 +286,10 @@ def _block_rms_norm_forward_kernel(
if casting_mode == _CASTING_MODE_LLAMA:
X_row = X_row.to(X_row_dtype)

Y_row = X_row * (offset + W_row)[None, :]
if elementwise_affine:
Y_row = X_row * (offset + W_row)[None, :]
else:
Y_row = X_row

if casting_mode == _CASTING_MODE_GEMMA:
Y_row = Y_row.to(X_row_dtype)
Expand Down Expand Up @@ -295,6 +321,7 @@ def _block_rms_norm_backward_kernel(
offset,
rows_per_program: tl.constexpr,
casting_mode: tl.constexpr,
elementwise_affine: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
BLOCK_ROW: tl.constexpr,
):
Expand All @@ -309,10 +336,11 @@ def _block_rms_norm_backward_kernel(
col_offsets = tl.arange(0, BLOCK_SIZE)
col_mask = col_offsets < n_cols

dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
if elementwise_affine:
dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
W_row = W_row + offset
W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
W_row = W_row + offset

for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
row_idx = start + tl.arange(0, BLOCK_ROW)
Expand All @@ -335,35 +363,45 @@ def _block_rms_norm_backward_kernel(

# Different backward graphs for different casting modes
if casting_mode == _CASTING_MODE_LLAMA:
m = (dY_row * W_row[None, :]).to(tl.float32)
if elementwise_affine:
m = (dY_row * W_row[None, :]).to(tl.float32)
else:
m = dY_row.to(tl.float32)

elif casting_mode == _CASTING_MODE_GEMMA:
dY_row = dY_row.to(tl.float32)
m = dY_row * W_row[None, :]
if elementwise_affine:
m = dY_row * W_row[None, :]
else:
m = dY_row
else:
m = dY_row * W_row[None, :]
if elementwise_affine:
m = dY_row * W_row[None, :]
else:
m = dY_row

dX_row = rstd_row[:, None] * m

dX_row += (rstd_row[:, None]) * (
-(1 / n_cols) * (rstd_row * rstd_row * tl.sum(m * X_row, axis=1))[:, None] * X_row
)

# calculate the gradient of W
if casting_mode == _CASTING_MODE_LLAMA:
# TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
else:
# here X_row is already in fp32 (see previous if block)
dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)
if elementwise_affine:
if casting_mode == _CASTING_MODE_LLAMA:
# TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
else:
# here X_row is already in fp32 (see previous if block)
dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)

tl.store(
dX_ptr + row_idx[:, None] * dX_row_stride + col_offsets[None, :],
dX_row,
mask=row_mask[:, None] & col_mask[None, :],
)

tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)
if elementwise_affine:
tl.store(dW_ptr + pid * dW_row_stride + col_offsets, dW_row, mask=col_mask)


_str_to_casting_mode = {
Expand Down Expand Up @@ -392,8 +430,14 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)

# Check constraints.
assert X.shape[1] == W.shape[0], "Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
if W is not None:
# Check constraints.
assert X.shape[1] == W.shape[0], (
"Incompatible hidden size dimension between tensor1.shape[1] and tensor2.shape[0]"
)
elementwise_affine = True
else:
elementwise_affine = False

# XPU-specific optimization
kernel_args = {}
Expand All @@ -406,13 +450,14 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
X,
X.stride(0),
W,
W.stride(0),
W.stride(0) if elementwise_affine else 0,
RSTD,
RSTD.stride(0),
n_cols,
eps,
offset,
casting_mode,
elementwise_affine=elementwise_affine,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
**kernel_args, # XPU-specific optimization
Expand All @@ -426,14 +471,15 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
X,
X.stride(0),
W,
W.stride(0),
W.stride(0) if elementwise_affine else 0,
RSTD,
RSTD.stride(0),
n_rows,
n_cols,
eps,
offset,
casting_mode,
elementwise_affine=elementwise_affine,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
**kernel_args, # XPU-specific optimization
Expand All @@ -455,8 +501,13 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
elif X.device.type == "npu":
sm_count = get_npu_multi_processor_count()

# fp32 for numerical stability especially.
_dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
if W is not None:
# fp32 for numerical stability especially.
_dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
elementwise_affine = True
else:
_dW = None
elementwise_affine = False

if n_cols > BLOCK_SIZE:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
Expand All @@ -483,16 +534,17 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
X.stride(0),
torch_to_triton_dtype[X.dtype],
W,
W.stride(0),
W.stride(0) if elementwise_affine else 0,
RSTD,
RSTD.stride(0),
_dW,
_dW.stride(0),
_dW.stride(0) if elementwise_affine else 0,
n_rows,
n_cols,
offset,
rows_per_program,
casting_mode,
elementwise_affine=elementwise_affine,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
**kernel_args, # XPU-specific optimization
Expand All @@ -509,22 +561,27 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
X.stride(0),
torch_to_triton_dtype[X.dtype],
W,
W.stride(0),
W.stride(0) if elementwise_affine else 0,
RSTD,
RSTD.stride(0),
_dW,
_dW.stride(0),
_dW.stride(0) if elementwise_affine else 0,
n_rows,
n_cols,
offset,
rows_per_program,
casting_mode,
elementwise_affine=elementwise_affine,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
**kernel_args, # XPU-specific optimization
)
dX = dX.view(*shape)
dW = _dW.sum(dim=0).to(W.dtype)

if elementwise_affine:
dW = _dW.sum(dim=0).to(W.dtype)
else:
dW = None

return dX, dW

Expand Down Expand Up @@ -565,7 +622,11 @@ def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row
ctx.row_mode = row_mode
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.save_for_backward(X, W, RSTD)
ctx.elementwise_affine = W is not None
if W is not None:
ctx.save_for_backward(X, W, RSTD)
else:
ctx.save_for_backward(X, RSTD)
return Y

@staticmethod
Expand All @@ -574,7 +635,11 @@ def backward(ctx, dY):
"""
Y: (B, T, H) or (BxT, H)
"""
X, W, RSTD = ctx.saved_tensors
if ctx.elementwise_affine:
X, W, RSTD = ctx.saved_tensors
else:
X, RSTD = ctx.saved_tensors
W = None
dX, dW = rms_norm_backward(
dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode
)
Expand Down
Loading
Loading