Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions tests/kernels/quantization/test_hybrid_w4a16_triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,75 @@ def test_triton_w4a16_skinny_fmt_gemm_asymmetric(dtype, M, K, N, G, random_seed:

# bf16 accumulation at larger shapes needs slightly looser tolerance
torch.testing.assert_close(out, ref, rtol=1e-2, atol=5e-2)


# ---------------------------------------------------------------------------
# Performance regression test
# ---------------------------------------------------------------------------

# Baseline throughput for the prefill perf-regression test below.
#
# Numbers were measured on gfx1151 (Strix Halo, 40 CUs) with the tuned
# kernel configuration: num_stages=1, UNROLL_K=4 and, for M > 1024,
# BM=64 / BN=256 / BK=64 / num_warps=8.
#
# Mapping: (M, K, N, group_size, has_zp) -> reference TFLOPS
_PERF_REFERENCE_TFLOPS: dict[tuple[int, int, int, int, bool], float] = {
    # Symmetric quantization (compressed-tensors w4a16), Qwen2.5-7B shapes.
    (1606, 3584, 37888, 128, False): 25.0,
    (1606, 3584, 18944, 128, False): 26.0,
    (1606, 3584, 4608, 128, False): 27.0,
    (1606, 3584, 3584, 128, False): 26.0,
    # Asymmetric quantization (AWQ, zero_point=True), Qwen2.5-7B shapes.
    (1606, 3584, 37888, 128, True): 24.5,
    (1606, 3584, 18944, 128, True): 24.5,
    (1606, 3584, 4608, 128, True): 25.5,
    (1606, 3584, 3584, 128, True): 24.5,
}

# Allowed relative shortfall against the reference before the test fails (5%).
PERF_TOLERANCE = 0.05


@pytest.mark.skipif(not current_platform.is_rocm(), reason="ROCm only")
@pytest.mark.parametrize("has_zp", [False, True], ids=["symmetric", "asymmetric"])
@pytest.mark.parametrize(
    "M,K,N,G",
    [
        (1606, 3584, 37888, 128),
        (1606, 3584, 18944, 128),
        (1606, 3584, 4608, 128),
        (1606, 3584, 3584, 128),
    ],
)
def test_triton_w4a16_prefill_perf_regression(M, K, N, G, has_zp):
    """Fail if prefill TFLOPS drops more than 5% below reference.

    The kernel is benchmarked with ``triton.testing.do_bench`` on random
    fp16 activations ``[M, K]``, random packed int4 weights ``[N, K // 8]``
    and per-group scales (plus zero-points when ``has_zp`` is true).
    Throughput is computed as ``2 * M * N * K`` FLOPs over the measured
    time and compared against ``_PERF_REFERENCE_TFLOPS``.

    The reference values are absolute numbers measured on gfx1151, so the
    test is skipped on any other GPU architecture instead of comparing
    against a baseline that does not apply to the hardware under test.
    """
    triton_testing = pytest.importorskip("triton.testing")

    # Absolute TFLOPS only mean something on the GPU they were measured on.
    # ROCm builds of PyTorch expose the architecture as ``gcnArchName``,
    # optionally followed by ":feature" suffixes.
    reference_arch = "gfx1151"
    arch_name = getattr(torch.cuda.get_device_properties(device), "gcnArchName", "")
    if arch_name.split(":")[0] != reference_arch:
        pytest.skip(
            f"reference TFLOPS were measured on {reference_arch}, "
            f"found {arch_name or 'unknown architecture'}"
        )

    ref_tflops = _PERF_REFERENCE_TFLOPS[(M, K, N, G, has_zp)]
    num_groups = K // G

    a = torch.randn((M, K), device=device, dtype=torch.float16)
    # Eight int4 nibbles are packed into each int32 along K.
    b_q_i32 = torch.randint(0, 2**31, (N, K // 8), dtype=torch.int32, device=device)
    scales = torch.randn(N, num_groups, dtype=torch.float16, device=device) * 0.01
    zp = None
    if has_zp:
        zp = torch.randint(0, 16, (N, num_groups), dtype=torch.int32, device=device).to(
            torch.float16
        )

    def run():
        triton_w4a16_skinny_fmt_gemm(a, b_q_i32, scales, G, zp=zp)

    # Warm up to trigger Triton JIT compilation before timing.
    for _ in range(3):
        run()
    torch.accelerator.synchronize()

    ms = triton_testing.do_bench(run, warmup=50, rep=100)
    # A GEMM performs 2 * M * N * K floating-point operations; ms -> s.
    tflops = (2 * M * N * K) * 1e-12 / (ms * 1e-3)

    mode = "asymmetric" if has_zp else "symmetric"
    min_tflops = ref_tflops * (1 - PERF_TOLERANCE)
    assert tflops >= min_tflops, (
        f"Performance regression ({mode}): {tflops:.2f} TFLOPS < "
        f"{min_tflops:.2f} TFLOPS (reference {ref_tflops:.1f}, "
        f"tolerance {PERF_TOLERANCE * 100:.0f}%) for "
        f"M={M} K={K} N={N} G={G} ({ms:.3f} ms)"
    )
110 changes: 65 additions & 45 deletions vllm/model_executor/kernels/linear/mixed_precision/hybrid_w4a16.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def _triton_w4a16_skinny_fmt_kernel(
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
UNROLL_K: tl.constexpr, # number of BLOCK_K tiles to unroll per loop iter
):
"""
Fused W4A16 GEMM reading weights from skinny format [N, K//8].
Expand All @@ -79,6 +80,11 @@ def _triton_w4a16_skinny_fmt_kernel(
When HAS_ZP=True, raw zero-points zp_raw are loaded from zp_ptr [N, K//G]
and subtracted directly: (nibble - zp_raw) * scale.
When HAS_ZP=False, only the constant ZP_BIAS is subtracted (symmetric).

UNROLL_K controls how many BLOCK_K tiles are statically unrolled per
outer loop iteration. This amortises loop overhead and gives the
compiler more scheduling freedom. The caller clamps BLOCK_K to at most
group_size, so with UNROLL_K > 1 a single outer iteration may span
several quantization groups; the group index is recomputed per sub-tile.
"""
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
Expand All @@ -99,48 +105,51 @@ def _triton_w4a16_skinny_fmt_kernel(

accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

for k_start in range(0, tl.cdiv(K, BLOCK_K)):
offs_k = k_start * BLOCK_K + tl.arange(0, BLOCK_K)
mask_k = offs_k < K

# ---- Load activations A: [BLOCK_M, BLOCK_K] ----
a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
mask_a = (offs_m[:, None] < M) & mask_k[None, :]
a = tl.load(a_ptrs, mask=mask_a, other=0.0)

# ---- Load packed weights B: [BLOCK_N, BLOCK_K//8] int32 ----
offs_k8 = k_start * (BLOCK_K // 8) + tl.arange(0, BLOCK_K // 8)
b_ptrs = b_ptr + offs_n[:, None] * K8 + offs_k8[None, :]
mask_b = (offs_n[:, None] < N) & (offs_k8[None, :] < K8)
b_packed = tl.load(b_ptrs, mask=mask_b, other=0)

# ---- Unpack int4 weights with ExLlama unshuffle ----
b = tl.interleave(b_packed, b_packed)
b = tl.interleave(b, b)
b = tl.interleave(b, b)
b = (b >> shifts_full) & 0xF # [BLOCK_N, BLOCK_K]

# ---- Load scales from [N, K//G] layout ----
g_idx = (k_start * BLOCK_K) // group_size
scale_ptrs = scales_ptr + offs_n * num_groups + g_idx
scale_mask = offs_n < N
scales = tl.load(scale_ptrs, mask=scale_mask, other=1.0)

# ---- Dequantize ----
if HAS_ZP:
# Asymmetric: (nibble - zp_raw) * scale (single subtraction)
zp_ptrs = zp_ptr + offs_n * num_groups + g_idx
zp_raw = tl.load(zp_ptrs, mask=scale_mask, other=0.0)
b_fp = (b.to(scales.dtype) - zp_raw[:, None]) * scales[:, None]
else:
# Symmetric: (w - 8) * scale
b_fp = (b - ZP_BIAS).to(scales.dtype) * scales[:, None]
total_k_tiles = tl.cdiv(K, BLOCK_K)
num_outer = tl.cdiv(total_k_tiles, UNROLL_K)

for outer in range(0, num_outer):
for u in tl.static_range(UNROLL_K):
k_start = outer * UNROLL_K + u
offs_k = k_start * BLOCK_K + tl.arange(0, BLOCK_K)
mask_k = offs_k < K

# ---- Load activations A: [BLOCK_M, BLOCK_K] ----
a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
mask_a = (offs_m[:, None] < M) & mask_k[None, :]
a = tl.load(a_ptrs, mask=mask_a, other=0.0)

# ---- Load packed weights B: [BLOCK_N, BLOCK_K//8] int32 ----
offs_k8 = k_start * (BLOCK_K // 8) + tl.arange(0, BLOCK_K // 8)
b_ptrs = b_ptr + offs_n[:, None] * K8 + offs_k8[None, :]
mask_b = (offs_n[:, None] < N) & (offs_k8[None, :] < K8)
b_packed = tl.load(b_ptrs, mask=mask_b, other=0)

# ---- Unpack int4 weights with ExLlama unshuffle ----
b = tl.interleave(b_packed, b_packed)
b = tl.interleave(b, b)
b = tl.interleave(b, b)
b = (b >> shifts_full) & 0xF # [BLOCK_N, BLOCK_K]

# ---- Load scales from [N, K//G] layout ----
g_idx = (k_start * BLOCK_K) // group_size
scale_ptrs = scales_ptr + offs_n * num_groups + g_idx
scale_mask = offs_n < N
scales = tl.load(scale_ptrs, mask=scale_mask, other=1.0)

# ---- Dequantize ----
if HAS_ZP:
zp_ptrs = zp_ptr + offs_n * num_groups + g_idx
zp_raw = tl.load(zp_ptrs, mask=scale_mask, other=0.0)
b_fp = (b.to(scales.dtype) - zp_raw[:, None]) * scales[:, None]
else:
b_fp = (b - ZP_BIAS).to(scales.dtype) * scales[:, None]

# ---- Transpose to [BLOCK_K, BLOCK_N] for matmul ----
b_fp_t = tl.trans(b_fp)
# ---- Transpose to [BLOCK_K, BLOCK_N] for matmul ----
b_fp_t = tl.trans(b_fp)

# ---- Accumulate: [BLOCK_M, BLOCK_K] @ [BLOCK_K, BLOCK_N] ----
accumulator += tl.dot(a, b_fp_t, out_dtype=tl.float32)
# ---- Accumulate: [BLOCK_M, BLOCK_K] @ [BLOCK_K, BLOCK_N] ----
accumulator += tl.dot(a, b_fp_t, out_dtype=tl.float32)

# ---- Store output C: [BLOCK_M, BLOCK_N] ----
c = accumulator.to(c_ptr.type.element_ty)
Expand Down Expand Up @@ -197,7 +206,7 @@ def triton_w4a16_skinny_fmt_gemm(

if on_gfx1x():
# Tuned on gfx1151 (Strix Halo, 40 CUs, 32-wide wavefronts)
# using Qwen3-4B weight shapes with group_size=128.
# using Qwen3-4B and Qwen2.5-7B weight shapes with group_size=128.
if M <= 32:
BLOCK_M, BLOCK_N, BLOCK_K, num_warps = 32, 32, 128, 4
elif M <= 64:
Expand All @@ -217,10 +226,8 @@ def triton_w4a16_skinny_fmt_gemm(
else:
BLOCK_M, BLOCK_N, BLOCK_K, num_warps = 64, 128, 32, 4
else:
if K >= 2 * N: # tall K (e.g. down_proj)
BLOCK_M, BLOCK_N, BLOCK_K, num_warps = 128, 512, 32, 16
else:
BLOCK_M, BLOCK_N, BLOCK_K, num_warps = 128, 64, 64, 8
# M > 1024: wider N-tiles improve occupancy and L2 reuse.
BLOCK_M, BLOCK_N, BLOCK_K, num_warps = 64, 256, 64, 8
else:
num_warps = 4
if M <= 32:
Expand All @@ -235,6 +242,17 @@ def triton_w4a16_skinny_fmt_gemm(
# a different group would get the wrong scale.
BLOCK_K = min(BLOCK_K, group_size)

# Static K-loop unrolling: process multiple BLOCK_K tiles per outer
# loop iteration. Reduces loop overhead and gives the compiler more
# scheduling freedom. Use 4 for large prefills (M > 1024) where
# K-loop iterations dominate; 1 elsewhere to avoid register pressure.
UNROLL_K = 4 if M > 1024 else 1
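# Fall back to no unrolling unless the K tiles divide evenly by UNROLL_K.
# Activation and weight loads are masked along K, but the per-group
# scale / zero-point loads are masked only along N, so a sub-tile past the
# end of K would index a group that does not exist.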
total_k_tiles = triton.cdiv(K, BLOCK_K)
if total_k_tiles % UNROLL_K != 0:
UNROLL_K = 1

grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))

_triton_w4a16_skinny_fmt_kernel[grid](
Expand All @@ -254,7 +272,9 @@ def triton_w4a16_skinny_fmt_gemm(
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K,
UNROLL_K=UNROLL_K,
num_warps=num_warps,
num_stages=1,
)
return c

Expand Down
Loading