Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 135 additions & 42 deletions mojo_opset/backends/ttx/kernels/ilu/moe_quant_experts.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,20 @@
# ---------------------------------------------------------------------------
# Triton: grouped int8 matmul with per-group weight scales and per-token input scales.
#
# Int8 matmul uses per-K rank-1 int32 accumulation, not tl.dot: ILU Triton
# int8 tl.dot can fail LLVM layout conversion / segfault.
# Int8 matmul uses tl.dot on the matrix engine. ILU's int8 tl.dot still miscompiles
# (SharedToDotOperand lowering emits invalid <2xf32><->4xi8> bitcasts -> segfault in
# make_llir), so the int8 operands (|v| <= 127) are cast losslessly to fp16 and fed to
# an fp16 MMA. The int8->fp16 cast is lossless (|v| <= 127), and ILU's dot does a true
# fp32 multiply-accumulate (verified: dot of two fp16 vectors [64,1] returns exactly 4097,
# which fp16 intra-dot accumulation would round to 4096; large-magnitude sums that exceed
# the fp16 max of 65504 also come back finite/exact rather than inf). Each BLOCK_K tile
# (partial sum <= 128*127^2 << 2^24) is thus computed exactly in the fp32 dot output,
# rounded back to int32, and accumulated into an int32
# partial (so the per-group total, which can exceed 2^24 for large groups, stays exact);
# the int32 partial is then dequantized by the per-group weight scale. An autotune
# prune (_prune_block_k_gt_group) keeps BLOCK_K <= QUANT_GROUP_SIZE so a tile never spans
# group boundaries. The weight tile is loaded as [BLOCK_N, BLOCK_K] (B is row-major
# [N, K]) and transposed via tl.trans before the dot.
#
# EPILOGUE enum selects post-matmul activation:
# EPILOGUE_NONE – plain dequant output
Expand All @@ -28,15 +40,37 @@ def _quant_moe_autotune_configs():
(32, 32, 4), (32, 64, 4), (64, 32, 4),
(64, 64, 4), (64, 128, 4), (128, 64, 4),
]:
for ns in [2, 3]:
configs.append(triton.Config(
{"BLOCK_M": BM, "BLOCK_N": BN},
num_warps=nw, num_stages=ns,
))
for BK in [32, 64, 128]:
for ns in [2, 3]:
configs.append(triton.Config(
{"BLOCK_M": BM, "BLOCK_N": BN, "BLOCK_K": BK},
num_warps=nw, num_stages=ns,
))
return configs


@smart_triton_autotune(configs=_quant_moe_autotune_configs(), selected_idx=0, key=["N", "K", "MAX_M"])
def _prune_block_k_gt_group(configs, named_args, **kwargs):
"""Drop configs whose BLOCK_K exceeds the quant group size.

A BLOCK_K larger than QUANT_GROUP_SIZE is numerically correct (the
``k_in_group < QUANT_GROUP_SIZE`` mask zeros the out-of-group lanes) but
wastes up to BLOCK_K/QUANT_GROUP_SIZE of the MMA work on masked lanes, so
such configs must never be picked by the autotuner. QUANT_GROUP_SIZE is
already normalized to a positive value (``<= 0`` -> K) before launch.
"""
qgs = named_args.get("QUANT_GROUP_SIZE") or kwargs.get("QUANT_GROUP_SIZE")
if not qgs or qgs <= 0:
return list(configs)
kept = [c for c in configs if c.kwargs.get("BLOCK_K", 1) <= qgs]
return kept or list(configs)


@smart_triton_autotune(
configs=_quant_moe_autotune_configs(),
selected_idx=0,
key=["N", "K", "MAX_M", "QUANT_GROUP_SIZE"],
prune_configs_by={"early_config_prune": _prune_block_k_gt_group},
)
@libentry()
@triton.jit
def _quant_moe_gemm_kernel(
Expand All @@ -60,12 +94,18 @@ def _quant_moe_gemm_kernel(
EPILOGUE: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
"""Grouped int8 matmul with per-group dequant and optional epilogue.

EPILOGUE_NONE: output = dequant(A @ B.T)
EPILOGUE_SWIGLU: B has N columns (gate + up), output HALF_N = N//2 columns
after silu(gate) * up.

The int8 matmul uses tl.dot (int32 accumulator). The weight tile is loaded
as [BLOCK_N, BLOCK_K] from row-major B[N, K] and transposed with tl.trans.
Each quant group accumulates int32 over BLOCK_K tiles, then is dequantized
by the per-group weight scale before being summed into the fp32 accumulator.
"""
n_tile_id = tl.program_id(0)
m_tile_id = tl.program_id(1)
Expand All @@ -85,6 +125,7 @@ def _quant_moe_gemm_kernel(

offs_m = group_start + m_tile_id * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = n_tile_id * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
mask_m = offs_m < group_end

b_base = B + group_id * stride_bg
Expand All @@ -100,34 +141,49 @@ def _quant_moe_gemm_kernel(
out_N = N

acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
mask_n = offs_n < out_N

for kg in range(NUM_GROUPS_K):
k_start = kg * QUANT_GROUP_SIZE
# Number of BLOCK_K tiles needed to cover one quant group (constexpr).
K_TILES_PER_GROUP: tl.constexpr = (QUANT_GROUP_SIZE + BLOCK_K - 1) // BLOCK_K

for kg in range(NUM_GROUPS_K):
partial = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)

if EPILOGUE == _SWIGLU:
partial_up = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)

for k_idx in range(k_start, k_start + QUANT_GROUP_SIZE):
if k_idx < K:
a_col = tl.load(A + offs_m * K + k_idx, mask=mask_m, other=0)
b_row = tl.load(
b_base + offs_n * strideBN + k_idx * strideBK,
mask=offs_n < out_N, other=0,
)
partial += a_col.to(tl.int32)[:, None] * b_row.to(tl.int32)[None, :]

if EPILOGUE == _SWIGLU:
bu_row = tl.load(
b_base + offs_n_up * strideBN + k_idx * strideBK,
mask=offs_n_up < N, other=0,
)
partial_up += a_col.to(tl.int32)[:, None] * bu_row.to(tl.int32)[None, :]
for kt in range(K_TILES_PER_GROUP):
k_in_group = kt * BLOCK_K + offs_k
k_off = kg * QUANT_GROUP_SIZE + k_in_group
# Stay inside both the current quant group and the global K extent.
k_mask = (k_in_group < QUANT_GROUP_SIZE) & (k_off < K)

a = tl.load(
A + offs_m[:, None] * K + k_off[None, :],
mask=mask_m[:, None] & k_mask[None, :], other=0,
) # [BLOCK_M, BLOCK_K] int8
a_f16 = a.to(tl.float16)
b = tl.load(
b_base + offs_n[:, None] * strideBN + k_off[None, :] * strideBK,
mask=mask_n[:, None] & k_mask[None, :], other=0,
) # [BLOCK_N, BLOCK_K] int8
tile = tl.dot(a_f16, tl.trans(b).to(tl.float16), out_dtype=tl.float32)
# tile holds the exact integer A@B.T for this BLOCK_K tile (|sum| <=
# BLOCK_K*127^2 << 2^24, fp32-exact). Round (not truncate) before the
# int32 cast so any sub-0.5 fp drift cannot flip the integer.
partial += (tile + tl.where(tile >= 0, 0.5, -0.5)).to(tl.int32)

if EPILOGUE == _SWIGLU:
bu = tl.load(
b_base + offs_n_up[:, None] * strideBN + k_off[None, :] * strideBK,
mask=(offs_n_up[:, None] < N) & k_mask[None, :], other=0,
) # [BLOCK_N, BLOCK_K] int8
tile_up = tl.dot(a_f16, tl.trans(bu).to(tl.float16), out_dtype=tl.float32)
partial_up += (tile_up + tl.where(tile_up >= 0, 0.5, -0.5)).to(tl.int32)

ws = tl.load(
ws_base + offs_n * stride_ws_n + kg * stride_ws_k,
mask=offs_n < out_N, other=0.0,
mask=mask_n, other=0.0,
)
acc += partial.to(tl.float32) * ws[None, :]

Expand Down Expand Up @@ -155,7 +211,25 @@ def _quant_moe_gemm_kernel(
tl.store(c_ptrs, c, mask=c_mask)


def _prepare_quant_gemm_args(A, input_scale, weight_scale, size_per_group, num_groups, K, quant_group_size):
def _make_group_offsets(
size_per_group: torch.Tensor, num_groups: int, device: torch.device
) -> tuple[torch.Tensor, int]:
"""Build [num_groups + 1] int32 prefix sums and the per-group max.

The single ``.max().item()`` here is the only device->host sync in the
experts pipeline; compute it once and reuse for every quant/GEMM launch.
"""
cum = size_per_group.cumsum(0, dtype=torch.int32)
group_offsets = torch.zeros(num_groups + 1, dtype=torch.int32, device=device)
group_offsets[1:] = cum
max_m = int(size_per_group.max().item())
return group_offsets, max_m


def _prepare_quant_gemm_args(
A, input_scale, weight_scale, size_per_group, num_groups, K, quant_group_size,
group_offsets=None, max_m=None,
):
if quant_group_size <= 0:
quant_group_size = K
num_groups_k = (K + quant_group_size - 1) // quant_group_size
Expand All @@ -164,11 +238,8 @@ def _prepare_quant_gemm_args(A, input_scale, weight_scale, size_per_group, num_g
weight_scale = weight_scale.unsqueeze(-1)
input_scale_flat = input_scale.reshape(-1).float()

cum = size_per_group.cumsum(0, dtype=torch.int32)
group_offsets = torch.empty(num_groups + 1, dtype=torch.int32, device=A.device)
group_offsets[0] = 0
group_offsets[1:] = cum
max_m = size_per_group.max().item()
if group_offsets is None or max_m is None:
group_offsets, max_m = _make_group_offsets(size_per_group, num_groups, A.device)

return quant_group_size, num_groups_k, weight_scale, input_scale_flat, group_offsets, max_m

Expand All @@ -185,9 +256,14 @@ def _quant_m_grouped_matmul(
N: int,
K: int,
quant_group_size: int,
group_offsets: torch.Tensor | None = None,
max_m: int | None = None,
) -> torch.Tensor:
quant_group_size, num_groups_k, weight_scale, input_scale_flat, group_offsets, max_m = \
_prepare_quant_gemm_args(A, input_scale, weight_scale, size_per_group, num_groups, K, quant_group_size)
_prepare_quant_gemm_args(
A, input_scale, weight_scale, size_per_group, num_groups, K, quant_group_size,
group_offsets=group_offsets, max_m=max_m,
)

def grid(META):
return (
Expand Down Expand Up @@ -222,10 +298,15 @@ def _quant_m_grouped_matmul_swiglu(
N: int,
K: int,
quant_group_size: int,
group_offsets: torch.Tensor | None = None,
max_m: int | None = None,
) -> torch.Tensor:
half_n = N // 2
quant_group_size, num_groups_k, weight_scale, input_scale_flat, group_offsets, max_m = \
_prepare_quant_gemm_args(A, input_scale, weight_scale, size_per_group, num_groups, K, quant_group_size)
_prepare_quant_gemm_args(
A, input_scale, weight_scale, size_per_group, num_groups, K, quant_group_size,
group_offsets=group_offsets, max_m=max_m,
)

def grid(META):
return (
Expand Down Expand Up @@ -323,6 +404,8 @@ def _moe_smooth_dynamic_quant(
inv_smooth_scale: torch.Tensor,
tokens_per_expert: torch.Tensor,
num_experts: int,
group_offsets: torch.Tensor | None = None,
max_tokens: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Triton: grouped smooth + per-token dynamic int8 quantization.

Expand All @@ -337,12 +420,8 @@ def _moe_smooth_dynamic_quant(
output = torch.empty(total_tokens, K, dtype=torch.int8, device=device)
scale = torch.empty(total_tokens, dtype=torch.float32, device=device)

cum = tokens_per_expert.cumsum(0, dtype=torch.int32)
group_offsets = torch.empty(num_experts + 1, dtype=torch.int32, device=device)
group_offsets[0] = 0
group_offsets[1:] = cum

max_tokens = int(tokens_per_expert.max().item())
if group_offsets is None or max_tokens is None:
group_offsets, max_tokens = _make_group_offsets(tokens_per_expert, num_experts, device)
if max_tokens == 0:
return output, scale.unsqueeze(-1)

Expand Down Expand Up @@ -463,12 +542,19 @@ def quant_moe_experts_impl(
dtype = sorted_hidden_states.dtype
device = sorted_hidden_states.device

# All four launches below share the same token-to-expert layout, so the
# group offsets and per-expert max token count (the only device->host sync)
# are computed once here and threaded through every step.
group_offsets, max_m = _make_group_offsets(tokens_per_expert, num_experts, device)

# --- Step 1: smooth + dynamic quant for up_proj input ---
x_int8, x_scale = _moe_smooth_dynamic_quant(
sorted_hidden_states,
module.up_proj_quantize.inv_smooth_scale,
tokens_per_expert,
num_experts,
group_offsets=group_offsets,
max_tokens=max_m,
)

up_w = _cached_unpacked_proj_weight(module, "up_proj_weight")
Expand All @@ -478,8 +564,9 @@ def quant_moe_experts_impl(
inter = module.intermediate_size

# --- Step 2: int8 grouped GEMM + SwiGLU epilogue ---
# Use fp32 output to match core precision chain (core keeps activated in fp32
# before passing to down_proj_quantize).
# Use fp32 output to match the core precision chain (core keeps the activated
# intermediate in fp32 before passing it to down_proj_quantize, whose per-token
# dynamic scale is derived from this tensor's amax).
fc1_out = torch.empty(t_tokens, inter, device=device, dtype=torch.float32)
_quant_m_grouped_matmul_swiglu(
x_int8,
Expand All @@ -493,6 +580,8 @@ def quant_moe_experts_impl(
n_up,
k_in,
module.up_quant_group_size,
group_offsets=group_offsets,
max_m=max_m,
)

# --- Step 3: smooth + dynamic quant for down_proj input ---
Expand All @@ -501,6 +590,8 @@ def quant_moe_experts_impl(
module.down_proj_quantize.inv_smooth_scale,
tokens_per_expert,
num_experts,
group_offsets=group_offsets,
max_tokens=max_m,
)

down_w = _cached_unpacked_proj_weight(module, "down_proj_weight")
Expand All @@ -522,5 +613,7 @@ def quant_moe_experts_impl(
h_out,
k_inter,
module.down_quant_group_size,
group_offsets=group_offsets,
max_m=max_m,
)
return out
Loading