[Metal] Expose kernel_mul_mm_id wrapper and parallelize rowids fan-out #3555
Open
fiorelorenzo wants to merge 5 commits into
Adds call_quantized_matmul_mm_id in candle-metal-kernels mirroring
call_quantized_matmul_mm_t. Dispatches the single-pass MoE kernel
already vendored at quantized.metal:7267 (Q4_0..Q6K, F16, F32).
Adds a Criterion microbench at
candle-core/benches/benchmarks/qmatmul_id.rs driving the wrapper at
Q4_K, N=2048, K=4096, num_experts=128.
Reference: ggml-org/llama.cpp ggml/src/ggml-metal/ggml-metal.metal:9787
for kernel_mul_mm_id. candle vendored only the single-pass variant
(llama.cpp's separate _id_map0 pre-pass is not in tree), so this is one
dispatch rather than two.
Plumbing only; no integration into QMatmul / forward_via_f16. Per
huggingface#3444 review, fused-kernel cross-products belong in CustomOp
callers; this wrapper makes that possible.
…ne bench
The vendored kernel_mul_mm_id at quantized.metal:7267 carries an upstream
TODO at L7301: the rowids fan-out runs in every thread of every
threadgroup, redundantly scanning the full M*T ids buffer. For
Qwen3-MoE shapes (num_experts=128, T=8) at M=256 this dominates kernel
runtime.
Replace the serial scan with a stride-128 parallel scan and a
threadgroup atomic_uint counter. The downstream matmul reads rowids by
index and writes into dst[jid[0], jid[1], :], so the non-deterministic
rowid order is safe. This diverges from llama.cpp ggml-metal.metal
(same TODO upstream) and is annotated in the source comment.
Effect on MacBook Air M4 16 GB:
M=1: 1.91 ms -> 1.89 ms (decode, not scan-bound)
M=32: 35.0 ms -> 26.9 ms (~1.3x)
M=256: 800 ms -> 51.6 ms (~15.5x)
Also extends the qmatmul_id bench with a QMatMul::forward baseline at
the equivalent per-expert load (M_per_expert in {1, 2, 16}) and changes
Throughput::Bytes to Throughput::Elements so Criterion's "Gelem/s"
figure can be read as GFLOPs/s.
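The strided scan with an atomic counter described above can be sketched on the CPU. This is a minimal host-side analogy in Rust, not the Metal shader: OS threads stand in for the 128 threads of a threadgroup, `AtomicUsize` stands in for the threadgroup `atomic_uint`, and the function names and the u32 packing of `(slot, token)` are assumptions made for illustration.

```rust
use std::sync::atomic::{AtomicU32, AtomicUsize, Ordering};
use std::thread;

/// Serial reference: every (slot, token) pair routed to `expert`, in scan order.
fn rowids_serial(ids: &[u32], num_tokens: usize, top_k: usize, expert: u32) -> Vec<(u16, u16)> {
    let mut out = Vec::new();
    for token in 0..num_tokens {
        for slot in 0..top_k {
            if ids[token * top_k + slot] == expert {
                out.push((slot as u16, token as u16));
            }
        }
    }
    out
}

/// Strided scan: worker `tid` visits flat indices tid, tid + stride, ... and
/// claims an output position with an atomic fetch_add, so the order of the
/// resulting rowids depends on thread scheduling.
fn rowids_parallel(ids: &[u32], top_k: usize, expert: u32, stride: usize) -> Vec<(u16, u16)> {
    let counter = AtomicUsize::new(0);
    let rowids: Vec<AtomicU32> = (0..ids.len()).map(|_| AtomicU32::new(0)).collect();
    thread::scope(|s| {
        for tid in 0..stride {
            let (counter, rowids) = (&counter, &rowids);
            s.spawn(move || {
                let mut i = tid;
                while i < ids.len() {
                    if ids[i] == expert {
                        let pos = counter.fetch_add(1, Ordering::Relaxed);
                        let (token, slot) = (i / top_k, i % top_k);
                        // Pack (slot, token) into one u32, loosely mirroring ushort2.
                        rowids[pos].store(((slot as u32) << 16) | token as u32, Ordering::Relaxed);
                    }
                    i += stride;
                }
            });
        }
    });
    // All workers have joined when the scope ends, so the reads below see every store.
    let n = counter.load(Ordering::Relaxed);
    rowids[..n]
        .iter()
        .map(|r| {
            let v = r.load(Ordering::Relaxed);
            ((v >> 16) as u16, (v & 0xffff) as u16)
        })
        .collect()
}
```

Both functions return the same set of `(slot, token)` pairs; only the order can differ, which is the property the commit relies on when it says the downstream matmul reads rowids by index.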
Validates call_quantized_matmul_mm_id against a naive CPU reference on
small shapes (E=4, M=3, T=2, N=8, K=32) with deterministic routing
ids[t,s] = (t*2 + s) % E. Uses GgmlDType::F32 so there is no
quantization rounding; tolerance is per-element abs diff < 1e-3.
The CPU reference mirrors what the kernel actually does (not just the
wrapper's docstring claim), in particular:
- src1 row is indexed by the slot dimension, not the token: the rowids
fan-out stores ushort2(slot, token) and the matmul reads
nb12*id[1] + nb11*(id[0] % ne11). For src1_shape = [num_tokens, k]
the wrapper sets nb12 = 0, so the input row used is `slot %
num_tokens` (this falls out of the upstream llama.cpp single-pass
convention where src1 is per-slot, not per-token).
- dst writes land at flat offset
j + slot * ne0 + token * (ne0 * ne1)
with ne0 = n and ne1 = experts_per_token * num_tokens, so the
per-token stride is n*T*M, not the contiguous n*T. The output buffer
in the test is sized accordingly; the comparison loop only touches
the n-wide row per (token, slot) pair that the kernel writes.
The test covers the parallelized rowids fan-out introduced in the
previous commit (rowids order is non-deterministic but the matmul reads
by index, so the final output is stable; this is the property the test
asserts).
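The index arithmetic described in this commit message can be written out as plain Rust. This is a sketch of the formulas quoted above, not code from the test; the function names are hypothetical.

```rust
/// Flat dst element offset: j + slot * ne0 + token * (ne0 * ne1).
fn dst_offset(j: usize, slot: usize, token: usize, ne0: usize, ne1: usize) -> usize {
    j + slot * ne0 + token * (ne0 * ne1)
}

/// src1 row read when nb12 = 0 and ne11 = num_tokens: slot % num_tokens.
fn src1_row(slot: usize, num_tokens: usize) -> usize {
    slot % num_tokens
}

/// Smallest dst length (in elements) that covers every write for the given ne1.
fn dst_len(n: usize, top_k: usize, num_tokens: usize, ne1: usize) -> usize {
    dst_offset(n - 1, top_k - 1, num_tokens - 1, n, ne1) + 1
}
```

For the test shape (M=3, T=2, N=8), `ne1 = T * M = 6` gives a per-token stride of 48 elements and `dst_len(8, 2, 3, 6)` of 112, whereas a contiguous per-token stride of `n * T = 16` would need only 48.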
Companion to huggingface#3555.
The wrapper was wiring the kernel for a layout the caller does not get.
Two related bugs, both exposed by writing the correctness test:
1. `ne1` was set to `nei0 * nei1` (= experts_per_token * num_tokens).
   The kernel writes at `jid[0] * ne0 + jid[1] * (ne0 * ne1)`, where
   `jid = (slot, token)`. With `ne1 = nei0 * nei1`, the per-token
   stride becomes `n * experts_per_token * num_tokens`, which for
   `num_tokens > 1` walks off the end of the contiguous
   `[num_tokens, experts_per_token, n]` buffer a caller naturally
   allocates. The microbench in `qmatmul_id.rs` never read its output
   so the out-of-bounds writes went unnoticed.
2. `src1` was wired as `[ne11 = num_tokens, ne10 = k]`. The kernel
   reads `src1 + nb12 * id[1] + nb11 * (id[0] % ne11)`, so with
   `nb12 = 0` and `ne11 = num_tokens` the row used was
   `slot % num_tokens`, not the token. Per-(token, slot) work was
   therefore reading the wrong input vector when `slot != token`.
The fix treats `src1` as `[ne12 = num_tokens, ne11 = 1, ne10 = k]`
(singleton slot dim, broadcast across slots) and sets `ne1 = nei0` so
the dst write stride matches the contiguous
`[num_tokens, experts_per_token, n]` layout the wrapper docstring
describes. The grid X dim is decoupled from `ne1` and now derives from
`nei0 * nei1` (a safe upper bound on the per-expert routing count); the
kernel's existing per-expert `r1 * BLOCK_SIZE_N >= _ne1` early-exit
clamps unused tiles. The wrapper also now rejects 3D src1 since
per-slot input layout is not supported.
Also updates the F32 correctness test added in d8efd5e to use the
contiguous output layout and the new per-token src1 indexing.
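A naive CPU reference for the corrected layout might look like the following. This is a sketch, not the reference used in `src/tests.rs`: it assumes row-major weights of shape `[num_experts, n, k]`, `src1` of shape `[num_tokens, k]` indexed by token, and a contiguous `[num_tokens, experts_per_token, n]` output; the function names are hypothetical.

```rust
/// dst[token, slot, j] = sum over kk of weights[ids[token, slot], j, kk] * src1[token, kk]
fn moe_matmul_ref(
    weights: &[f32], // [num_experts, n, k], row-major (assumed layout)
    src1: &[f32],    // [num_tokens, k]
    ids: &[u32],     // [num_tokens, top_k]
    num_tokens: usize,
    top_k: usize,
    n: usize,
    k: usize,
) -> Vec<f32> {
    let mut dst = vec![0.0f32; num_tokens * top_k * n];
    for token in 0..num_tokens {
        // The input row depends on the token only; it is shared by all slots.
        let x = &src1[token * k..(token + 1) * k];
        for slot in 0..top_k {
            let expert = ids[token * top_k + slot] as usize;
            for j in 0..n {
                let row = (expert * n + j) * k;
                let w = &weights[row..row + k];
                let acc: f32 = w.iter().zip(x).map(|(a, b)| a * b).sum();
                dst[(token * top_k + slot) * n + j] = acc;
            }
        }
    }
    dst
}

/// Largest per-element absolute difference between two equally sized buffers.
fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| (x - y).abs()).fold(0.0, f32::max)
}
```

A comparison against device output would then check `max_abs_diff(&gpu, &cpu) < 1e-3`, matching the per-element tolerance the F32 test states.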
Refactors run_bench_metal to take the GgmlDType as a parameter and adds
F16 cases at M in {1, 32}. The F16 path exercises the wrapper's
non-quantized branch (kernel_mul_mm_id_f16_f32 host_name) without the
dequant overhead the Q4_K path pays.
F16 M=256 is intentionally skipped: the F16 weight stack at
[E=128, N=2048, K=4096] is ~2.15 GB, and on a 16 GB host the M=256
setup leaves too little headroom for the input/output buffers to be
worth the extra case.
Q4_K coverage (M=1, 32, 256) is unchanged.
Numbers on MacBook Air M4 16 GB, macOS 26.3.1:
qmatmul_id_f16_n2048_k4096_m1 1.71 ms / 78 GFLOPs/s
qmatmul_id_f16_n2048_k4096_m32 25.3 ms / 170 GFLOPs/s
For comparison, the Q4_K cases at the same shapes:
qmatmul_id_q4k_n2048_k4096_m1 1.94 ms / 69 GFLOPs/s
qmatmul_id_q4k_n2048_k4096_m32 27.9 ms / 154 GFLOPs/s
F16 is ~10% faster across both shapes, attributable to skipping the
per-row Q4_K dequant in the kernel's K-loop.
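The reported throughput and weight-size figures can be reproduced with a few lines of arithmetic. The sketch below assumes the bench counts `2 * M * T * N * K` floating-point operations per dispatch with `T=8`; that formula is inferred from the numbers above rather than taken from the bench source, and the function names are hypothetical.

```rust
/// Multiply-add count for M tokens routed to T experts each, times 2 for FLOPs.
fn moe_flops(m: u64, top_k: u64, n: u64, k: u64) -> u64 {
    2 * m * top_k * n * k
}

/// Convert a FLOP count and a wall time in milliseconds to GFLOPs/s.
fn gflops_per_s(flops: u64, millis: f64) -> f64 {
    flops as f64 / (millis * 1e-3) / 1e9
}

/// Size in bytes of an F16 weight stack of shape [e, n, k] (2 bytes per element).
fn f16_weight_bytes(e: u64, n: u64, k: u64) -> u64 {
    e * n * k * 2
}
```

With these, `gflops_per_s(moe_flops(1, 8, 2048, 4096), 1.71)` rounds to 78 and `f16_weight_bytes(128, 2048, 4096)` is 2,147,483,648 bytes, i.e. the ~2.15 GB quoted above.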
The single-pass MoE-fused matmul shader `kernel_mul_mm_id` is vendored at `candle-metal-kernels/src/metal_src/quantized.metal:7267` but is not callable from Rust today. This PR:
1. Adds `call_quantized_matmul_mm_id` in `candle-metal-kernels` mirroring `call_quantized_matmul_mm_t`. Public dispatch function for callers that own their own MoE routing (CustomOp users, candle-transformers MoE models). Not auto-wired into `QMatMul::forward`; that is a follow-up once the wrapper has callers.
2. Parallelizes the rowids fan-out at `quantized.metal:7301`. The vendored shader carries a `TODO: parallelize this loop` here, inherited from the older single-pass `kernel_mul_mm_id` shape. The serial scan runs in every thread of every threadgroup, redundantly scanning the full `M * T` ids buffer; for Qwen3-MoE shapes (`num_experts=128`, `T=8`) at `M=256` this dominates the kernel runtime. The replacement uses a stride-128 scan plus a threadgroup atomic counter. The downstream matmul reads rowids by index and writes into `dst[jid[0], jid[1], :]`, so the resulting non-deterministic rowid order does not affect output. Upstream `ggml-org/llama.cpp` has since refactored to a two-pass `kernel_mul_mm_id_map0` + `kernel_mul_mm_id` design (see `ggml/src/ggml-metal/ggml-metal.metal:9721-9784`) where the first pass is already parallel, so there is no longer anything to upstream this change to.
3. Adds a Criterion microbench at `candle-core/benches/benchmarks/qmatmul_id.rs` driving the wrapper at Q4_K and F16, `N=2048`, `K=4096`, `num_experts=128`, `T=8`. Q4_K sweeps `M in {1, 32, 256}`; F16 sweeps `M in {1, 32}` only (the F16 weight stack at these shapes is ~2.15 GB and crowds a 16 GB host at `M=256`). Plus a `QMatMul::forward` baseline at the equivalent per-expert load for comparison.
4. Adds an F32 correctness test (`src/tests.rs::qmatmul_mm_id_f32_correctness`) that validates the wrapper output against a naive CPU reference. Writing this test surfaced a wrapper layout bug (commit `0223447`): `ne1` was set to `nei0 * nei1` instead of `nei0`, and `src1` was wired with `ne11 = num_tokens` instead of `ne11 = 1`, so the kernel wrote dst at `+token * (n * num_tokens * T)` and read src1 at row `slot % num_tokens`. The microbenches never read their output, so the prior iterations silently wrote out of bounds for `M > 1`. The fix and the correctness test landed together; the rewritten wrapper now produces a contiguous `[num_tokens, experts_per_token, n]` output that the test asserts element-wise against the CPU reference (abs diff < 1e-3).
Benchmarks on MacBook Air M4, 16 GB unified memory, macOS 26.3.1 (post wrapper-layout fix):
Q4_K (the realistic MoE quantization):
[Table: results, including an `mm_t * num_experts` baseline column]
F16 (skips the per-row dequant the Q4_K path pays):
Effect of the rowids fan-out parallelization in isolation (with vs without point 2 above): Q4_K `M=256` goes from 800 ms (43 GFLOPs/s) to 52.9 ms (649 GFLOPs/s), a 15.1x speedup attributable to the scan rewrite.
The fused dispatch wins over the per-expert `mm_t` loop at `M=32` (1.25x faster on Q4_K) and is within 1.5x of the loop at `M=256`. Both extremes are still limited by per-threadgroup overhead from launching all `num_experts` expert-threadgroups regardless of routing density; closing that is a follow-up (the M=1 microbench at `num_experts=8` simulating an ideal active-expert pre-pack only gave ~9% over the current 128-launch path, so the structural fix would be a matvec kernel variant).
Bench command:
References:
- `kernel_mul_mm_id` upstream: `ggml-org/llama.cpp`. llama.cpp has since refactored to a two-pass `_id_map0` + `mm_id` design (see `ggml-metal-ops.cpp:2329`); candle vendored only the older single-pass kernel.