[Metal] Expose kernel_mul_mm_id wrapper and parallelize rowids fan-out #3555

Open: fiorelorenzo wants to merge 5 commits into huggingface:main from fiorelorenzo:feat/metal-mul-mm-id
fiorelorenzo commented May 20, 2026

The single-pass MoE-fused matmul shader kernel_mul_mm_id is vendored at candle-metal-kernels/src/metal_src/quantized.metal:7267 but is not callable from Rust today. This PR:

  1. Adds call_quantized_matmul_mm_id in candle-metal-kernels, mirroring call_quantized_matmul_mm_t. It is a public dispatch function for callers that own their own MoE routing (CustomOp users, candle-transformers MoE models). It is not auto-wired into QMatMul::forward; that is a follow-up once the wrapper has callers.

  2. Parallelizes the rowids fan-out at quantized.metal:7301. The vendored shader carries a TODO: parallelize this loop here, inherited from the older single-pass kernel_mul_mm_id shape. The serial scan runs in every thread of every threadgroup, redundantly scanning the full M * T ids buffer; for Qwen3-MoE shapes (num_experts=128, T=8) at M=256 this dominates the kernel runtime. The replacement uses a stride-128 scan plus a threadgroup atomic counter. The downstream matmul reads rowids by index and writes into dst[jid[0], jid[1], :], so the resulting non-deterministic rowid order does not affect output. Upstream ggml-org/llama.cpp has since refactored to a two-pass kernel_mul_mm_id_map0 + kernel_mul_mm_id design (see ggml/src/ggml-metal/ggml-metal.metal:9721-9784) where the first pass is already parallel, so there is no longer anything to upstream this change to.

  3. Adds a Criterion microbench at candle-core/benches/benchmarks/qmatmul_id.rs driving the wrapper at Q4_K and F16, N=2048, K=4096, num_experts=128, T=8. Q4_K sweeps M in {1, 32, 256}; F16 sweeps M in {1, 32} only (the F16 weight stack at these shapes is ~2.15 GB and crowds a 16 GB host at M=256). The bench also includes a QMatMul::forward baseline at the equivalent per-expert load for comparison.

  4. Adds an F32 correctness test (src/tests.rs::qmatmul_mm_id_f32_correctness) that validates the wrapper output against a naive CPU reference. Writing this test surfaced a wrapper layout bug (fixed in commit 0223447): ne1 was set to nei0 * nei1 instead of nei0, and src1 was wired with ne11 = num_tokens instead of ne11 = 1. As a result the kernel wrote dst at +token * (n * num_tokens * T) and read src1 at row slot % num_tokens. Because the microbenches never read their output, the prior iterations silently wrote out of bounds for M > 1. The fix and the correctness test landed together; the rewritten wrapper now produces a contiguous [num_tokens, experts_per_token, n] output the test asserts byte-for-byte against the CPU reference.
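The dst stride bug from point 4 can be reproduced with plain index arithmetic. The sketch below is a CPU-side model rather than the wrapper code: `dst_offset` and `max_dst_offset` are hypothetical helpers that apply the kernel's write formula `j + slot * ne0 + token * (ne0 * ne1)` to the correctness test's small shapes (M=3, T=2, N=8) and compare the highest offset written against the length of a contiguous [num_tokens, experts_per_token, n] buffer.

```rust
/// Flat element offset written for output column `j` of a (slot, token)
/// pair, following the formula quoted in the commit messages:
/// j + slot * ne0 + token * (ne0 * ne1).
fn dst_offset(j: usize, slot: usize, token: usize, ne0: usize, ne1: usize) -> usize {
    j + slot * ne0 + token * (ne0 * ne1)
}

/// Highest offset touched over all (token, slot, j) for a given `ne1`.
/// The formula is monotonic in every index, so the last index of each
/// dimension gives the maximum.
fn max_dst_offset(num_tokens: usize, experts_per_token: usize, n: usize, ne1: usize) -> usize {
    dst_offset(n - 1, experts_per_token - 1, num_tokens - 1, n, ne1)
}

fn main() {
    // Shapes taken from the F32 correctness test: M=3, T=2, N=8.
    let (num_tokens, experts_per_token, n) = (3usize, 2usize, 8usize);
    let buffer_len = num_tokens * experts_per_token * n;

    // Before the fix: ne1 = nei0 * nei1 = experts_per_token * num_tokens.
    let buggy = max_dst_offset(num_tokens, experts_per_token, n, experts_per_token * num_tokens);
    // After the fix: ne1 = nei0 = experts_per_token.
    let fixed = max_dst_offset(num_tokens, experts_per_token, n, experts_per_token);

    println!("buffer length      = {buffer_len}");
    println!("max offset (buggy) = {buggy} (out of bounds: {})", buggy >= buffer_len);
    println!("max offset (fixed) = {fixed} (out of bounds: {})", fixed >= buffer_len);
}
```

With `ne1 = nei0` the last write lands on the final element of the contiguous buffer, which is the layout the wrapper docstring describes.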

Benchmarks on MacBook Air M4, 16 GB unified memory, macOS 26.3.1 (post wrapper-layout fix):

Q4_K (the realistic MoE quantization):

| Case | this PR (fused) | mm_t * num_experts baseline |
| --- | --- | --- |
| M=1 (decode, 8 experts hit) | 1.94 ms / 69 GFLOPs/s | 8 * 40.2 us = 0.32 ms |
| M=32 (small prefill, all 128) | 27.9 ms / 154 GFLOPs/s | 128 * 273 us = 35.0 ms |
| M=256 (chunked, all 128) | 52.9 ms / 649 GFLOPs/s | 128 * 273 us = 35.0 ms |

F16 (skips the per-row dequant the Q4_K path pays):

| Case | this PR (fused) |
| --- | --- |
| M=1 | 1.71 ms / 78 GFLOPs/s |
| M=32 | 25.3 ms / 170 GFLOPs/s |

Effect of the rowids fan-out parallelization in isolation (with vs without point 2 above): Q4_K M=256 goes from 800 ms (43 GFLOPs/s) to 52.9 ms (649 GFLOPs/s), a 15.1x speedup attributable to the scan rewrite.
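The throughput figures can be sanity-checked from the timings. The sketch below assumes the bench counts 2 * M * T * N * K floating-point operations per dispatch (one multiply and one add per weight element per routed row); that formula is not stated in the PR, but it reproduces the reported numbers to within rounding. `moe_matmul_flops` and `gflops_per_s` are hypothetical helper names.

```rust
/// FLOPs for one fused MoE matmul dispatch, assuming 2 * M * T * N * K
/// (an assumption that matches the reported figures, not a quoted formula).
fn moe_matmul_flops(m: u64, t: u64, n: u64, k: u64) -> u64 {
    2 * m * t * n * k
}

/// Throughput in GFLOP/s for a measured wall time in seconds.
fn gflops_per_s(flops: u64, seconds: f64) -> f64 {
    flops as f64 / seconds / 1e9
}

fn main() {
    let (t, n, k) = (8u64, 2048u64, 4096u64);
    // (M, measured time in seconds) for the Q4_K cases reported above.
    let cases = [(1u64, 1.94e-3), (32, 27.9e-3), (256, 52.9e-3), (256, 800e-3)];
    for (m, secs) in cases {
        let rate = gflops_per_s(moe_matmul_flops(m, t, n, k), secs);
        println!("M={m:<3} {:>7.2} ms -> {rate:6.1} GFLOP/s", secs * 1e3);
    }
    // Speedup of the parallel rowids scan at M=256.
    println!("scan rewrite speedup: {:.1}x", 800.0 / 52.9);
}
```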

The fused dispatch wins over the per-expert mm_t loop at M=32 (1.25x faster on Q4_K), but is about 1.5x slower than the loop at M=256 (52.9 ms vs 35.0 ms) and about 6x slower at M=1 (1.94 ms vs 0.32 ms). Both extremes are still limited by per-threadgroup overhead from launching all num_experts expert-threadgroups regardless of routing density; closing that gap is a follow-up (the M=1 microbench at num_experts=8, simulating an ideal active-expert pre-pack, only gave ~9% over the current 128-launch path, so the structural fix would be a matvec kernel variant).

Bench command:

cd candle-core && cargo bench --bench bench_main --features metal -- qmatmul_id

References:

  • kernel_mul_mm_id upstream: ggml-org/llama.cpp. llama.cpp has since refactored to a two-pass _id_map0 + mm_id design (see ggml-metal-ops.cpp:2329); candle vendored only the older single-pass kernel.
  • Related: Metal: fused MoE kernels, QTensor buffer access, encoder reuse, Q4K tuning #3444 (bounced for scope); this PR addresses only the MoE-fused matmul wrapper, no fused-quant cross-products, no encoder reuse.

Adds call_quantized_matmul_mm_id in candle-metal-kernels mirroring
call_quantized_matmul_mm_t. Dispatches the single-pass MoE kernel
already vendored at quantized.metal:7267 (Q4_0..Q6K, F16, F32).

Adds a Criterion microbench at candle-core/benches/benchmarks/qmatmul_id.rs
driving the wrapper at Q4_K, N=2048, K=4096, num_experts=128.

Reference: ggml-org/llama.cpp ggml/src/ggml-metal/ggml-metal.metal:9787
for kernel_mul_mm_id. candle vendored only the single-pass variant
(llama.cpp's separate _id_map0 pre-pass is not in tree), so this is one
dispatch rather than two.

Plumbing only; no integration into QMatMul / forward_via_f16. Per huggingface#3444
review, fused-kernel cross-products belong in CustomOp callers; this
wrapper makes that possible.
…ne bench

The vendored kernel_mul_mm_id at quantized.metal:7267 carries an upstream
TODO at L7301: the rowids fan-out runs in every thread of every
threadgroup, redundantly scanning the full M*T ids buffer. For
Qwen3-MoE shapes (num_experts=128, T=8) at M=256 this dominates kernel
runtime.

Replace the serial scan with a stride-128 parallel scan and a
threadgroup atomic_uint counter. The downstream matmul reads rowids by
index and writes into dst[jid[0], jid[1], :], so the non-deterministic
rowid order is safe. This diverges from llama.cpp ggml-metal.metal
(same TODO upstream) and is annotated in the source comment.
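The scan rewrite can be modelled on the CPU. The sketch below is not the Metal shader: it uses OS threads in place of threadgroup threads and `AtomicUsize` in place of the threadgroup `atomic_uint`, and the function names are hypothetical. Each worker visits indices `tid, tid + stride, ...` of the flat ids buffer and claims an output position with `fetch_add`, so the rowids come out in a non-deterministic order but form the same set as the serial scan.

```rust
use std::sync::atomic::{AtomicU32, AtomicUsize, Ordering};
use std::thread;

/// Serial reference: one pass over the flat [M * T] ids buffer,
/// collecting (slot, token) pairs routed to `expert`.
fn rowids_serial(ids: &[u32], t: usize, expert: u32) -> Vec<(u16, u16)> {
    ids.iter()
        .enumerate()
        .filter(|&(_, &e)| e == expert)
        .map(|(i, _)| ((i % t) as u16, (i / t) as u16))
        .collect()
}

/// Strided parallel scan: worker `tid` checks indices tid, tid + stride, ...
/// and appends matches through a shared atomic counter.
fn rowids_parallel(ids: &[u32], t: usize, expert: u32, stride: usize) -> Vec<(u16, u16)> {
    let counter = AtomicUsize::new(0);
    // Each entry packs (slot, token) as two u16 halves, like a ushort2.
    let rowids: Vec<AtomicU32> = (0..ids.len()).map(|_| AtomicU32::new(0)).collect();

    thread::scope(|s| {
        for tid in 0..stride {
            let (counter, rowids) = (&counter, &rowids);
            s.spawn(move || {
                let mut i = tid;
                while i < ids.len() {
                    if ids[i] == expert {
                        // Claim a unique output position; order is arbitrary.
                        let pos = counter.fetch_add(1, Ordering::Relaxed);
                        let packed = (i % t) as u32 | (((i / t) as u32) << 16);
                        rowids[pos].store(packed, Ordering::Relaxed);
                    }
                    i += stride;
                }
            });
        }
    });

    // All workers have been joined once the scope returns.
    let count = counter.load(Ordering::Relaxed);
    rowids[..count]
        .iter()
        .map(|p| {
            let v = p.load(Ordering::Relaxed);
            ((v & 0xffff) as u16, (v >> 16) as u16)
        })
        .collect()
}

fn main() {
    // Toy routing table: M=32 tokens, T=8 slots, 16 experts (illustrative values).
    let ids: Vec<u32> = (0..256u32).map(|i| (i * 7) % 16).collect();
    let mut serial = rowids_serial(&ids, 8, 3);
    let mut parallel = rowids_parallel(&ids, 8, 3, 8);
    serial.sort();
    parallel.sort();
    println!("same set of rowids: {}", serial == parallel);
}
```

Because the matmul addresses dst through the (slot, token) pair stored in each rowid, only the set of entries matters, which is why the comparison sorts both results first.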

Effect on MacBook Air M4 16 GB:
  M=1:   1.91 ms -> 1.89 ms   (decode, not scan-bound)
  M=32:  35.0 ms -> 26.9 ms   (~1.3x)
  M=256: 800 ms -> 51.6 ms   (~15.5x)

Also bumps the qmatmul_id bench to add a QMatMul::forward baseline at
the equivalent per-expert load (M_per_expert in {1, 2, 16}) and changes
Throughput::Bytes to Throughput::Elements so Criterion's "Gelem/s"
correctly reports FLOPs/s.

Validates call_quantized_matmul_mm_id against a naive CPU reference on
small shapes (E=4, M=3, T=2, N=8, K=32) with deterministic routing
ids[t,s] = (t*2 + s) % E. Uses GgmlDType::F32 so there is no
quantization rounding; tolerance is per-element abs diff < 1e-3.

The CPU reference mirrors what the kernel actually does (not just the
wrapper's docstring claim), in particular:

  - src1 row is indexed by the slot dimension, not the token: the rowids
    fan-out stores ushort2(slot, token) and the matmul reads
    nb12*id[1] + nb11*(id[0] % ne11). For src1_shape = [num_tokens, k]
    the wrapper sets nb12 = 0, so the input row used is `slot %
    num_tokens` (this falls out of the upstream llama.cpp single-pass
    convention where src1 is per-slot, not per-token).

  - dst writes land at flat offset
        j + slot * ne0 + token * (ne0 * ne1)
    with ne0 = n and ne1 = experts_per_token * num_tokens, so the
    per-token stride is n*T*M, not the contiguous n*T. The output buffer
    in the test is sized accordingly; the comparison loop only touches
    the n-wide row per (token, slot) pair that the kernel writes.

The test covers the parallelized rowids fan-out introduced in the
previous commit (rowids order is non-deterministic but the matmul reads
by index, so the final output is stable; this is the property the test
asserts).

Companion to huggingface#3555.

The wrapper was wiring the kernel for a layout the caller does not get.
Two related bugs, both exposed by writing the correctness test:

1. `ne1` was set to `nei0 * nei1` (= experts_per_token * num_tokens).
   The kernel writes at `jid[0] * ne0 + jid[1] * (ne0 * ne1)`, where
   `jid = (slot, token)`. With `ne1 = nei0 * nei1`, the per-token byte
   multiplier becomes `n * experts_per_token * num_tokens`, which for
   `num_tokens > 1` walks off the end of the contiguous
   `[num_tokens, experts_per_token, n]` buffer a caller naturally
   allocates. The microbench in `qmatmul_id.rs` never read its output
   so the out-of-bounds writes went unnoticed.

2. `src1` was wired as `[ne11 = num_tokens, ne10 = k]`. The kernel
   reads `src1 + nb12 * id[1] + nb11 * (id[0] % ne11)`, so with
   `nb12 = 0` and `ne11 = num_tokens` the row used was `slot % num_tokens`,
   not the token. Per-(token, slot) work was therefore reading the
   wrong input vector when `slot != token`.
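The src1 indexing bug in item 2 can be illustrated the same way. The sketch below is a simplified model, assuming strides expressed in whole rows rather than bytes; `src1_row` and `count_wrong_rows` are hypothetical helpers. It evaluates `nb12 * id[1] + nb11 * (id[0] % ne11)` for every (token, slot) pair and counts how often the selected row differs from the token.

```rust
/// Row of src1 selected for a (slot, token) pair, with strides given in
/// rows instead of bytes (a simplification for illustration).
fn src1_row(slot: usize, token: usize, nb12: usize, nb11: usize, ne11: usize) -> usize {
    nb12 * token + nb11 * (slot % ne11)
}

/// Number of (token, slot) pairs whose selected row is not the token's row.
fn count_wrong_rows(
    num_tokens: usize,
    experts_per_token: usize,
    nb12: usize,
    nb11: usize,
    ne11: usize,
) -> usize {
    let mut wrong = 0;
    for token in 0..num_tokens {
        for slot in 0..experts_per_token {
            if src1_row(slot, token, nb12, nb11, ne11) != token {
                wrong += 1;
            }
        }
    }
    wrong
}

fn main() {
    let (num_tokens, experts_per_token) = (3usize, 2usize);
    // Before the fix: nb12 = 0 and ne11 = num_tokens, so row = slot % num_tokens.
    let before = count_wrong_rows(num_tokens, experts_per_token, 0, 1, num_tokens);
    // After the fix: src1 is [ne12 = num_tokens, ne11 = 1, k], so row = token.
    let after = count_wrong_rows(num_tokens, experts_per_token, 1, 1, 1);
    println!("pairs reading the wrong row: before = {before}, after = {after}");
}
```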

The fix treats `src1` as `[ne12 = num_tokens, ne11 = 1, ne10 = k]`
(singleton slot dim, broadcast across slots) and sets `ne1 = nei0`
so the dst write stride matches the contiguous
`[num_tokens, experts_per_token, n]` layout the wrapper docstring
describes. The grid X dim is decoupled from `ne1` and now derives
from `nei0 * nei1` (a safe upper bound on the per-expert routing
count); the kernel's existing per-expert `r1 * BLOCK_SIZE_N >= _ne1`
early-exit clamps unused tiles. The wrapper also now rejects 3D src1
since per-slot input layout is not supported.
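For the fixed layout, a naive CPU reference is short. The sketch below is not the test from `src/tests.rs`; it is a minimal illustration under the assumptions that weights are stored row-major as [E, N, K], inputs as [num_tokens, K], and the output as contiguous [num_tokens, experts_per_token, N]. The function names are hypothetical; the routing formula `ids[t,s] = (t*2 + s) % E` is the one used by the correctness test.

```rust
/// Deterministic routing from the correctness test: ids[t, s] = (t*2 + s) % E.
fn deterministic_ids(m: usize, t: usize, e: usize) -> Vec<usize> {
    (0..m * t).map(|i| ((i / t) * 2 + i % t) % e).collect()
}

/// Naive reference: out[token, slot, :] = W[ids[token, slot]] * x[token].
/// w is [e, n, k] row-major, x is [m, k], ids is [m, t], out is [m, t, n].
fn mul_mm_id_reference(
    w: &[f32],
    x: &[f32],
    ids: &[usize],
    (e, m, t, n, k): (usize, usize, usize, usize, usize),
) -> Vec<f32> {
    assert_eq!(w.len(), e * n * k);
    assert_eq!(x.len(), m * k);
    assert_eq!(ids.len(), m * t);
    let mut out = vec![0f32; m * t * n];
    for token in 0..m {
        let x_row = &x[token * k..][..k];
        for slot in 0..t {
            let expert = ids[token * t + slot];
            for j in 0..n {
                let w_row = &w[(expert * n + j) * k..][..k];
                // Dot product of one expert weight row with the token's input.
                out[(token * t + slot) * n + j] =
                    w_row.iter().zip(x_row).map(|(a, b)| a * b).sum();
            }
        }
    }
    out
}

fn main() {
    // Small shapes from the correctness test: E=4, M=3, T=2, N=8, K=32.
    let shape = (4usize, 3usize, 2usize, 8usize, 32usize);
    let (e, m, t, n, k) = shape;
    // Illustrative data: expert i has every weight equal to i + 1; inputs are all 1.
    let w: Vec<f32> = (0..e * n * k).map(|i| (i / (n * k) + 1) as f32).collect();
    let x = vec![1.0f32; m * k];
    let ids = deterministic_ids(m, t, e);
    let out = mul_mm_id_reference(&w, &x, &ids, shape);
    println!("ids = {ids:?}");
    println!("out[token=2, slot=1, 0] = {}", out[(2 * t + 1) * n]);
}
```

A GPU result in the same contiguous layout can then be compared element by element against `out` with whatever tolerance the dtype calls for.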

Also updates the F32 correctness test added in d8efd5e to use the
contiguous output layout and the new per-token src1 indexing.

Refactors run_bench_metal to take the GgmlDType as a parameter and adds
F16 cases at M in {1, 32}. The F16 path exercises the wrapper's
non-quantized branch (kernel_mul_mm_id_f16_f32 host_name) without the
dequant overhead the Q4_K path pays.

F16 M=256 is intentionally skipped: the F16 weight stack at
[E=128, N=2048, K=4096] is ~2.15 GB, and on a 16 GB host the M=256
setup leaves too little headroom for the input/output buffers to
justify the extra data point.
Q4_K coverage (M=1, 32, 256) is unchanged.
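The ~2.15 GB figure follows directly from the shape. The sketch below recomputes it; the Q4_K line is added for comparison only and assumes the ggml Q4_K block layout of 144 bytes per 256 weights, which is not stated in this PR. Both helper names are hypothetical.

```rust
/// Bytes for an [e, n, k] weight stack stored as F16 (2 bytes per weight).
fn f16_stack_bytes(e: u64, n: u64, k: u64) -> u64 {
    e * n * k * 2
}

/// Bytes for the same stack in Q4_K, assuming 144-byte blocks of 256
/// weights (an assumption about the ggml layout, not taken from this PR).
fn q4k_stack_bytes(e: u64, n: u64, k: u64) -> u64 {
    e * n * k / 256 * 144
}

fn main() {
    let (e, n, k) = (128u64, 2048u64, 4096u64);
    println!("F16 : {:.2} GB", f16_stack_bytes(e, n, k) as f64 / 1e9);
    println!("Q4_K: {:.2} GB", q4k_stack_bytes(e, n, k) as f64 / 1e9);
}
```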

Numbers on MacBook Air M4 16 GB, macOS 26.3.1:
  qmatmul_id_f16_n2048_k4096_m1   1.71 ms / 78 GFLOPs/s
  qmatmul_id_f16_n2048_k4096_m32  25.3 ms / 170 GFLOPs/s

For comparison, the Q4_K cases at the same shapes:
  qmatmul_id_q4k_n2048_k4096_m1   1.94 ms / 69 GFLOPs/s
  qmatmul_id_q4k_n2048_k4096_m32  27.9 ms / 154 GFLOPs/s

F16 takes roughly 9-12% less time across both shapes (1.94 -> 1.71 ms
at M=1, 27.9 -> 25.3 ms at M=32), attributable to skipping the
per-row Q4_K dequant in the kernel's K-loop.