[ROCm][MoE] W4A16 MoE routing-distribution benchmark suite for the gfx11 prefill GEMM #1020
Draft
roberteg16 wants to merge 3 commits into
…ault-on)

Adds a producer/consumer WMMA kernel (rdna_moe_gemm) for the W4A16 MoE prefill GEMM1 on AMD RDNA3 (gfx11), a faster alternative to the Triton fused_moe_kernel_gptq_awq. Enabled by default on gfx11 (VLLM_MOE_HIP=0 forces Triton, =1 forces on).

On the rdna_moe_gemm path gemm1 (up/gate proj, top_k>1) runs the rdna_moe_gemm WMMA kernel and gemm2 (down proj, top_k=1) runs Triton, both at a single block_m=32 moe_align alignment. The per-routed-token activations are stored compact (flat-topk indexed -- the kernels write C[sorted_token_ids[slot]]), so gemm2 gathers them via the shared sorted_token_ids with no re-permute. gemm1 at block_m=32 is faster than at 16 on gfx1151 (the larger WMMA tile more than pays for the extra alignment padding), so there is a single alignment for both gemms.

- csrc/rocm/moe_gemm_w4a16_wmma.cu: the WMMA kernel with K, N and the weight N-row stride as runtime args (one instantiation per tile family handles any compliant shape and any weight padding). Compiled into _rocm_C; the WMMA body is gfx11-only with a stub on other arches. Registered as torch.ops._rocm_C.moe_gemm_w4a16; an unsupported shape TORCH_CHECKs.
- vllm/envs.py: VLLM_MOE_HIP tri-state (unset = default-on on gfx11).
- vllm/model_executor/layers/fused_moe/moe_hip_w4a16.py: host-side shape predicate (prefill_uses_rdna_moe_gemm) plus a graph-safe vLLM custom op wrapper with a no-op fake, so the path is torch.compile-safe.
- vllm/model_executor/layers/fused_moe/hybrid_w4a16_moe.py: apply() dispatch -- rdna_moe_gemm gemm1 when the shape is supported, else Triton; gemm2 always Triton.
- tests/kernels/quantization/test_moe_gemm_w4a16.py: compiled op vs Triton reference across shapes and block sizes.

AI assistance (Claude) was used.

Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: Robert Esclapez Garcia <robert.garcia@amd.com>
Wrap the rdna_moe_gemm gemm1 launch (torch.ops.vllm.moe_gemm_w4a16) in apply() with an apply()-level record_function scope carrying the dims the roofline tool needs: M, N, K, E, top_k, the quant group size g, block_m, valid_blocks (= num_tokens_post_padded // block_m) and n_routed (M*top_k), plus per-expert tok_hist / vtok_hist histograms describing the routing skew that drives the block_m padding. g is the real self._group_size (the per-group scale bytes and dequant FLOPs scale with K/g), not an assumed 128.

The scope is gated on VLLM_CUSTOM_SCOPES_FOR_PROFILING / VLLM_NVTX_SCOPES_FOR_PROFILING -- the valid_blocks/histogram reads force a device->host sync, taken only when profiling; production gets a nullcontext.

AI assistance (Claude) was used.

Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: Robert Esclapez Garcia <robert.garcia@amd.com>
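The gating described in this commit can be sketched roughly as follows. This is a minimal illustration, not the code in apply(): the helper names (gemm1_scope_dims, maybe_scope), the label format, and the scope_factory parameter are hypothetical. In the real path the factory would be something like torch.profiler.record_function, and the dims would be read from device tensors, which is why they are computed lazily here.

```python
from contextlib import nullcontext


def gemm1_scope_dims(M, N, K, E, top_k, g, block_m, num_tokens_post_padded):
    """Collect the dims named in the commit message (hypothetical helper)."""
    return {
        "M": M, "N": N, "K": K, "E": E, "top_k": top_k,
        "g": g, "block_m": block_m,
        # Derived quantities, as defined in the commit message.
        "valid_blocks": num_tokens_post_padded // block_m,
        "n_routed": M * top_k,
    }


def maybe_scope(profiling_enabled, scope_factory, dims_fn):
    """Return a labelled scope when profiling, else a no-op context.

    dims_fn is only called when profiling is enabled, so any expensive
    read it performs (a device->host sync in the real path) is skipped
    in production.
    """
    if not profiling_enabled:
        return nullcontext()
    dims = dims_fn()
    label = "rdna_moe_gemm1 " + " ".join(f"{k}={v}" for k, v in dims.items())
    return scope_factory(label)
```

Passing the dims as a callable rather than a ready-made dict is the design point: the production branch never touches the values that would force a sync.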
Adds benchmarks/kernels/moe_w4a16_bench/, a routing-distribution harness for the gfx11 W4A16 MoE prefill GEMM (rdna_moe_gemm):

- gen_distributions.py: synthetic routing archetypes (balanced/uniform/zipf1/zipf2/hot16) + a moe_align padding/fill space-map. synth_topk() is importable so the benchmark builds routing in-process.
- bench_moe_gemm.py: times torch.ops._rocm_C.moe_gemm_w4a16 against the Triton reference (invoke_fused_moe_kernel_hybrid_triton) on identical inputs across the archetypes, for both prefill GEMMs (--gemm 1,2), with Triton as the correctness oracle.

Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: Robert Esclapez Garcia <robert.garcia@amd.com>
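For illustration, one way to build routing with a controlled per-expert popularity is Gumbel-top-k sampling, which draws top_k distinct experts per token. The sketch below is not the PR's synth_topk(): the function names, the Zipf exponent parameterization, and any expert count passed in are assumptions made for the example.

```python
import numpy as np


def zipf_popularity(num_experts, exponent):
    """Per-expert probabilities proportional to rank**(-exponent).

    exponent=0 gives a flat distribution; larger values are more skewed.
    """
    ranks = np.arange(1, num_experts + 1, dtype=np.float64)
    weights = ranks ** (-exponent)
    return weights / weights.sum()


def synth_topk_sketch(popularity, num_tokens, top_k, seed=0):
    """Sample top_k distinct expert ids per token (hypothetical helper).

    Adding independent Gumbel(0, 1) noise to the log-probabilities and
    taking the top_k largest scores samples without replacement, with
    more popular experts chosen more often.
    """
    rng = np.random.default_rng(seed)
    logits = np.log(np.asarray(popularity, dtype=np.float64))
    noise = rng.gumbel(size=(num_tokens, logits.shape[0]))
    scores = logits[None, :] + noise
    # argpartition on the negated scores puts the top_k largest first.
    topk_ids = np.argpartition(-scores, top_k - 1, axis=1)[:, :top_k]
    return topk_ids.astype(np.int32)
```

For example, synth_topk_sketch(zipf_popularity(64, 1.0), 4096, 8) yields a (4096, 8) array of expert ids in which low-rank experts appear far more often than high-rank ones; the 64-expert count here is illustrative only.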
Summary
Adds benchmarks/kernels/moe_w4a16_bench/, a small benchmark suite that maps where the gfx11 W4A16 MoE prefill GEMM (torch.ops._rocm_C.moe_gemm_w4a16, the rdna_moe_gemm kernel) wins or loses against the Triton reference (fused_moe_kernel_gptq_awq via invoke_fused_moe_kernel_hybrid_triton) as a function of the token→expert routing distribution.

MoE prefill performance is routing-dependent: a custom WMMA kernel's only lever over Triton is keeping a hot expert's weights cache-resident across its blocks, so the win is entirely a function of routing skew and per-block fill. This suite makes that surface measurable and reproducible across synthetic routing archetypes (balanced → heavily skewed) and a sweep of token counts, on identical inputs with Triton as the correctness oracle.

It is benchmark-only (no kernel, library, or runtime changes) and adds nothing to the build or import graph of vLLM proper.

What this adds
- benchmarks/kernels/moe_w4a16_bench/gen_distributions.py: a numpy-only routing-distribution generator. It exposes an Archetype enum (balanced, uniform, zipf1, zipf2, hot16) and an importable synth_topk(arch, T) that builds topk_ids with a controlled per-expert popularity (gumbel-top-k sampling), so the benchmark constructs routing in-process with no on-disk dataset step. Run as a script, it prints a GPU-free moe_align padding/fill space-map (per-expert load stats, dead-expert count, WMMA-useful %, and skinny-block fraction at block_m 16 and 32), useful for choosing which shapes are worth GPU time. It deliberately imports neither torch nor vLLM so it runs on any machine.
- benchmarks/kernels/moe_w4a16_bench/bench_moe_gemm.py: the GPU benchmark. For each (archetype, T) it times the rdna_moe_gemm op against the Triton reference on the same tensors and shared moe_align@32 layout, reporting Triton µs, kernel µs, speedup %, achieved kernel TFLOP/s, and the relative error vs Triton. It covers both MoE prefill GEMMs via --gemm 1,2 (gemm1 = up/gate, top_k=8, K=2048 N=1024; gemm2 = down, top_k=1, K=512 N=2048), at the single production block_m=32 the op supports. Timing uses triton.testing.do_bench (L2 flushed between iterations) so each launch is measured cold, matching the production pattern where the same GEMM does not run back-to-back. The achieved TFLOP/s is computed over the real padded work (valid_blocks * block_m, i.e. num_tokens_post_padded), not just the useful routed rows, so it reflects the hardware throughput the kernel actually delivers rather than an inflated useful-only figure.

The suite reuses existing vLLM building blocks rather than reimplementing them: pack_int4_exllama_shuffle for the ExLlama INT4 weight layout, moe_align_block_size for the block layout, and invoke_fused_moe_kernel_hybrid_triton for the Triton reference; only the routing generator and the harness are new.

Why this is useful
Real trained routers are skewed (zipf-like), and the rdna_moe_gemm kernel wins more the more skewed the routing is, while losing on the synthetic, perfectly flat balanced case that a real router never produces. Reviewers and future tuners can use this suite to confirm the kernel sits in its win regime for realistic routing, find the token-count crossover where it starts winning for a given archetype, and correlate the win with the per-block fill (use%) that the space-map predicts without a GPU.

Example crossover finding on gfx1151: for zipf1 gemm1 under the cold-cache methodology, the kernel turns positive around T≈1024 (block fill ~63%) and grows to ~+31% by T=4096.
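The padded-work accounting behind the fill and TFLOP/s figures can be sketched in a few lines of numpy. This is an illustration under stated assumptions, not the suite's code: it assumes each expert's routed rows are rounded up to a multiple of block_m, takes use% to mean routed rows divided by padded rows, and counts 2*N*K FLOPs per padded row. The function names are hypothetical.

```python
import numpy as np


def fill_stats(topk_ids, num_experts, block_m):
    """Per-layout padding summary from routing alone (no GPU needed)."""
    counts = np.bincount(np.asarray(topk_ids).ravel(), minlength=num_experts)
    # Ceil-divide each expert's row count; dead experts get zero blocks.
    blocks = -(-counts // block_m)
    valid_blocks = int(blocks.sum())
    padded_rows = valid_blocks * block_m
    routed_rows = int(counts.sum())
    use_pct = 100.0 * routed_rows / padded_rows if padded_rows else 0.0
    return {
        "dead_experts": int((counts == 0).sum()),
        "valid_blocks": valid_blocks,
        "padded_rows": padded_rows,
        "use_pct": use_pct,
    }


def padded_tflops(padded_rows, n, k, elapsed_us):
    """Achieved TFLOP/s over padded rows, assuming 2*N*K FLOPs per row."""
    flops = 2.0 * padded_rows * n * k
    return flops / (elapsed_us * 1e-6) / 1e12
```

Dividing by padded rows rather than routed rows is what keeps the throughput figure from looking better than what the hardware delivered: a layout at 50% fill would otherwise report half the work it actually performed.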