Qwen3.6 MTP #2218
Draft
tianleiwu wants to merge 18 commits into
The CUDA QMoE expert weights produced by the model builder were not in the layout/encoding the fpA_intB mixed-GEMM kernel consumes, so INT4 exports of Qwen3.5/3.6-MoE generated incoherent output while fp16 was correct.
base.py:
- Add _cutlass_prepacked_blockwise_quantize: quantize each expert weight with ORT's blockwise quantize_matmul_4bits (on the transposed [K, N] weight), keep the SIGNED scales, and offline CUTLASS-prepack via pack_weights_for_cuda_mixed_gemm (force_arch=80, which the kernel expects for all SM >= 80). This is the encoding validated by the com.microsoft QMoE CUDA parity tests. Taking abs() of the scales (as the previous path did) corrupts every block whose anchor element is negative and yields garbage weights.
- make_qmoe_weights: route CUDA QMoE through the new prepacked path and assert block_size is 64 or 128 (the only sizes the CUDA kernel supports).
- Plumb a tri-state weights_prepacked QMoE attribute (default None = omit = kernel's prepacked default; override via extra_options qmoe_weights_prepacked).
qwen.py:
- Exclude the MoE router and shared-expert gate MatMuls from INT4/INT8 quantization; 4-bit rounding of these tiny routing matmuls flips top-k expert selection and injects large error into every MoE layer.
- Fix a node-name collision in the shared expert (rename the gated Mul to .../gate/Mul) that produced a duplicate value name and a ShapeInferenceError.
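The signed-scale point can be illustrated with a toy quantizer. This is a minimal sketch, not ORT's quantize_matmul_4bits: it assumes a simplified symmetric 4-bit scheme in which the scale is the block's largest-magnitude element (the "anchor") divided by 8, so the scale inherits the anchor's sign. The function names are hypothetical.

```python
def quantize_block(block):
    """Toy symmetric 4-bit quantization of one block (assumed scheme, not ORT's)."""
    anchor = max(block, key=abs)      # element with the largest magnitude
    scale = anchor / 8.0              # signed: negative when the anchor is negative
    q = [max(-8, min(7, round(w / scale))) for w in block]
    return q, scale


def dequantize_block(q, scale):
    return [v * scale for v in q]


block = [-0.8, 0.4, -0.2, 0.1]        # anchor is -0.8, so the scale is negative
q, scale = quantize_block(block)

good = dequantize_block(q, scale)       # signed scale: signs are preserved
bad = dequantize_block(q, abs(scale))   # abs() flips the sign of every element
```

In this toy scheme only blocks with a negative anchor are affected by abs(), which is the failure pattern the commit message describes; the real kernel's exact encoding may differ.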
- make_qmoe_weights: treat weights_prepacked in {None, 1} as prepacked so an
explicit qmoe_weights_prepacked=1 still produces CUTLASS-prepacked weights
(previously only None did, while make_qmoe_op emitted weights_prepacked=1 on
raw weights).
- Replace the block_size assert with a ValueError (asserts are stripped by
python -O) and apply it to all CUDA QMoE paths.
- Gate the raw (weights_prepacked=0) path and the emitted weights_prepacked
attribute on the CUDA EP.
- Fix the misleading 'ships raw weights' comment and document that the
blockwise scales are SIGNED (abs() reintroduces the garbage-output bug).
- Add test/python/models/test_qmoe_weights.py covering path dispatch, EP
gating, block-size validation, and the signed-scale regression guard.
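The dispatch rules listed above can be sketched as a small pure function. This is an illustration only: select_qmoe_weight_path and its return labels are hypothetical and are not the builder's actual code; only the rules (CUDA-only gating, block_size in {64, 128}, None or 1 meaning prepacked, 0 meaning raw) come from the commit message.

```python
def select_qmoe_weight_path(ep, weights_prepacked, block_size):
    """Hypothetical helper mirroring the dispatch rules in the commit message."""
    if ep != "cuda":
        # The raw/prepacked distinction is gated on the CUDA EP.
        return "non_cuda"
    if block_size not in (64, 128):
        # ValueError rather than assert: asserts are stripped by `python -O`.
        raise ValueError(f"CUDA QMoE requires block_size 64 or 128, got {block_size}")
    if weights_prepacked in (None, 1):
        return "cutlass_prepacked"
    if weights_prepacked == 0:
        return "raw"
    raise ValueError(f"unsupported weights_prepacked value: {weights_prepacked}")


path_default = select_qmoe_weight_path("cuda", None, 128)
path_explicit = select_qmoe_weight_path("cuda", 1, 64)
path_raw = select_qmoe_weight_path("cuda", 0, 128)
```

Treating None and 1 identically keeps the emitted attribute and the weight encoding consistent, which is the mismatch this commit fixes.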
Add a Qwen35MtpHead builder that emits a separate mtp.onnx (one full-attention + MoE decoder layer + fc + pre/post RMSNorms) for multi-token-prediction self-speculative decoding, gated behind --extra_options enable_mtp=true. It reuses the parent Qwen35MoeTextModel machinery and loads the mtp.* + shared embedding/lm_head weights directly from safetensors (HF transformers discards mtp.* on load). The builder also emits the 'mtp' section (and the decoder's hidden_states output) into genai_config.json. Validated: the exported fp16 mtp.onnx reproduces 88.3% greedy acceptance, bit-identical to the PyTorch reference. Fixes: the MTP head must not inherit include_hidden_states/exclude_lm_head (would make the final-norm output double as a graph output and feed the lm_head, creating a graph cycle); guard the shared save_model cache-dir cleanup so a multi-model builder does not fail on the second save.
Add Config::Model::Mtp (filename, session/run options, I/O names) plus JSON parsing in config.cpp, mirroring the existing Encoder/Decoder pattern. Also add a hidden_states field to the decoder outputs so a model exported with include_hidden_states exposes the last hidden state (consumed by the MTP head). Backward compatible: configs without an 'mtp' section parse unchanged.
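The backward-compatibility claim can be pictured with a short sketch. The real parsing lives in config.cpp; the Python below only illustrates an optional section, and the exact JSON key layout (model.mtp.filename, model.decoder.outputs.hidden_states) is an assumption inferred from the field names in the commit message.

```python
import json


def read_mtp_settings(config_text):
    """Return MTP-related settings, or None values when the section is absent."""
    model = json.loads(config_text)["model"]
    mtp = model.get("mtp")  # optional: older configs have no 'mtp' section
    outputs = model.get("decoder", {}).get("outputs", {})
    return {
        "mtp_filename": mtp.get("filename") if mtp else None,
        "hidden_states_output": outputs.get("hidden_states"),
    }


old_config = '{"model": {"decoder": {"outputs": {"logits": "logits"}}}}'
new_config = json.dumps({
    "model": {
        "decoder": {"outputs": {"logits": "logits", "hidden_states": "hidden_states"}},
        "mtp": {"filename": "mtp.onnx"},
    }
})

old_settings = read_mtp_settings(old_config)
new_settings = read_mtp_settings(new_config)
```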
Hybrid (GatedDeltaNet) recurrent state cannot be partially cropped like the attention KV cache, so a rejected speculative draft cannot be rolled back with the existing RewindTo. Add RecurrentState::Snapshot()/RestoreSnapshot() that copy the conv + linear-attention state buffers out and back in place (stable addresses for CUDA-graph replay), and make RewindTo(index != 0) restore from the snapshot. A 'rewind by 1' then rolls back both the KV cache (crop) and the recurrent state (snapshot restore). Expose the snapshot through every layer: State::SnapshotState (virtual) + DecoderOnly_State override, Generator::SnapshotState, OgaGenerator_SnapshotState (C API), OgaGenerator::SnapshotState (C++ wrapper), and snapshot_state (Python). Validated end to end on the real genai engine: a draft/verify MTP loop (main model via genai, mtp.onnx via onnxruntime) reaches 84.5% accept / 1.57x effective tokens-per-main-forward and matches plain greedy exactly through draft rejections, confirming the recurrent rollback is correct.
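Why a crop is not enough for the recurrent state can be shown with a toy model. This is a minimal sketch under stated assumptions: ToyHybridState is hypothetical, its "recurrent state" is a single integer folded over the tokens, and the method names only echo the API names in this commit rather than reproducing their signatures.

```python
class ToyHybridState:
    """Toy stand-in: a croppable KV list plus a non-croppable running state."""

    def __init__(self):
        self.kv_cache = []      # one entry per token, can be cropped
        self.recurrent = 0      # order-dependent fold, cannot be "un-folded"
        self._snapshot = None   # (sequence length, recurrent value)

    def run(self, tokens):
        for tok in tokens:
            self.kv_cache.append(tok)
            self.recurrent = (self.recurrent * 31 + tok) % 1_000_003

    def snapshot_state(self):
        self._snapshot = (len(self.kv_cache), self.recurrent)

    def rewind_to(self, length):
        if length == 0:
            self.kv_cache.clear()
            self.recurrent = 0
            return
        if self._snapshot is None or self._snapshot[0] != length:
            raise RuntimeError("no snapshot taken at this length")
        del self.kv_cache[length:]           # KV cache: crop
        self.recurrent = self._snapshot[1]   # recurrent state: restore


state = ToyHybridState()
state.run([5, 7])       # committed tokens
state.snapshot_state()  # taken before the speculative verify
state.run([9, 4])       # verify [t, d]; suppose the draft d=4 is rejected
state.rewind_to(2)      # crop the KV cache and restore the recurrent state
state.run([9])          # re-run the correct token

reference = ToyHybridState()
reference.run([5, 7, 9])
```

After the rollback the toy state matches a run that never saw the rejected draft, which is the property the end-to-end validation above checks on the real engine.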
Let mtp.onnx run as a standalone genai model so the MTP draft executes in-engine (real genai kernels + KV cache) instead of a separate session. Add a HiddenStatesInputs feeder that owns a resizable [batch, seq, hidden] device tensor and is refreshed each step from a caller-staged value (the main model's last hidden state). DecoderOnly_State creates it only when config.model.decoder.inputs.hidden_states is set, so normal models are unaffected. Expose staging through Generator::SetHiddenStates -> OgaGenerator_SetHiddenStates (C API) -> OgaGenerator::SetHiddenStates (C++) -> set_hidden_states (Python). Verified: mtp.onnx loaded as a genai og.Model with set_hidden_states produces draft logits matching the raw onnxruntime reference (identical argmax and top-5; max abs diff 0.012, expected fp16 kernel variance).
Add examples/python/qwen-3.6-mtp.py: a consolidated, reusable MtpGenerator that drives the draft/verify loop entirely through the genai Python API (both the main model and the MTP head as og.Model, using set_hidden_states / snapshot_state / rewind_to). Add examples/python/qwen-3.6-mtp.md documenting the MTP head architecture, export (enable_mtp + include_hidden_states), config layout, the draft/verify algorithm, the hybrid-model recurrent-state rollback problem, the API additions, and measured results. Link the example from the examples README. Verified end to end: ~88% accept, ~1.67 tokens/forward, output matching greedy.
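The draft/verify algorithm can be sketched independently of the genai API. The following is a toy illustration, not the example script: main_next and draft_next are hypothetical deterministic stand-ins for the main model and the MTP head, and the draft is made to disagree at fixed positions so that both the accept and the reject paths run.

```python
VOCAB = 50


def main_next(seq):
    """Toy 'main model': greedy next token as a function of the sequence."""
    return (7 * seq[-1] + 3 * len(seq) + 1) % VOCAB


def draft_next(seq):
    """Toy 'MTP head': agrees with the main model except at some positions."""
    guess = main_next(seq)
    return (guess + 1) % VOCAB if len(seq) % 5 == 0 else guess


def plain_greedy(prompt, n_new):
    seq = list(prompt)
    for _ in range(n_new):
        seq.append(main_next(seq))
    return seq


def speculative_greedy(prompt, n_new):
    seq = list(prompt)
    target = len(prompt) + n_new
    t = main_next(seq)                  # first token from a normal main forward
    forwards, trials, accepts = 1, 0, 0
    while len(seq) < target:
        seq.append(t)
        if len(seq) >= target:
            break
        d = draft_next(seq)             # cheap draft for the token after t
        # One main forward over [t, d] yields the main model's choice after t
        # and, if d is accepted, its choice after d.
        after_t = main_next(seq)
        after_d = main_next(seq + [d])
        forwards += 1
        trials += 1
        if d == after_t:                # accept: commit d, next token is free
            seq.append(d)
            accepts += 1
            t = after_d
        else:                           # reject: fall back to the main model
            t = after_t
    return seq, forwards, trials, accepts


spec_seq, forwards, trials, accepts = speculative_greedy([1, 2, 3], 40)
greedy_seq = plain_greedy([1, 2, 3], 40)
```

Because every committed token is the main model's own greedy choice, the speculative output equals plain greedy regardless of how often the draft is wrong; the accept rate only changes how many main forwards are needed.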
The hidden_states output was previously only retrievable as an ORT-allocated
extra output, which fails under CUDA graph (capture binds outputs to static
buffers): get_output("hidden_states") returned an unconstructed tensor.
Add a managed HiddenStatesOutputs that owns the output tensor and makes it
static on single-token (graph-captured) steps, mirroring Logits. Created by
DecoderOnly_State only when the decoder emits a hidden_states output, so models
without one are unaffected. With this, get_output("hidden_states") works with
CUDA graph enabled (verified: plain greedy decode stays coherent and the hidden
state is retrievable), which is a prerequisite for accelerating MTP with CUDA
graph -- isolated latency shows the MTP draft step drops 18.4ms -> 0.52ms once
graph-captured.
The MTP hidden_states feeder previously always staged through host memory (memcpy to a pinned CPU buffer + async H2D copy). When the source hidden state is already on the model's device (the in-engine case, where the main model's hidden_states output is consumed by the MTP head), copy device-to-device on the shared CUDA stream instead. All genai CUDA sessions share one compute stream, so the enqueued D2D copy is correctly ordered after the producer's Run and before this model's Run with no host round-trip and no host synchronization (cf. onnxruntime issue #28539 on the async IO-binding pattern). Falls back to the host-staged copy only when the source lives on the CPU.
Add a Synchronization and CUDA graph section to the Qwen3.6 MTP design doc: the shared single compute stream that makes the two-model handoff sync-free, the device-to-device hidden_states feed, the static-output requirement under graph capture, and why graph-capturing the verify step (plus an in-engine generator) is the remaining lever for a wall-clock win. References onnxruntime issues #28539 (async IO binding) and #28686 (async CUDA-graph replay).
Add a first-class MtpGenerator that runs the Qwen3.6 MTP self-speculative
draft/verify loop entirely in C++, composing a main-model generator and an MTP
draft-head generator on the shared compute stream. The main model's last hidden
state is handed to the MTP head device-to-device (no host round-trip): the C++
State::GetOutput returns the on-device OrtValue, which is sliced to the last
position and fed via the hidden_states input feeder. Verification runs [t, d] in
a single main forward; on reject the recurrent state is restored from a snapshot
and the correct token re-run.
Expose it through the C API (OgaCreateMtpGenerator, OgaMtpGenerator_AppendTokens
/ GenerateNextToken / IsDone / GetSequence{Count,Data} / Get{Forward,Accept,
Trial}Count / OgaDestroyMtpGenerator), the C++ wrapper, and Python
(og.MtpGenerator).
Verified: output matches plain greedy decoding (one prompt bit-exact through
draft accept/reject; others match modulo fp16 near-ties), ~84-94% accept,
1.6-1.8 tokens per main forward.
Fix: ArgmaxLogitsRow kept a CopyDeviceToCpu result from a temporary device span
whose pinned host buffer was freed at end of expression (dangling read); name
the span so it outlives the read.
…enchmarks
- examples/python/qwen-3.6-mtp.py: default to the built-in og.MtpGenerator; keep the pure-Python loop as ReferenceMtpGenerator behind --reference.
- examples/python/qwen-3.6-mtp.md: document og.MtpGenerator in the API table and run section; record wall-clock results (Python ref ~0.3x, in-engine ~1.0x break-even); explain the eager 2-token verify and the multi-gpu_graph_id fix; expand future work with embedding/lm_head memory sharing.
- qwen_3.6.md: add an 'MTP self-speculative decoding — implemented' experiment log with the shipped pieces, correctness, the wall-clock benchmark, per-step latency, the verify-shape graph-capture lever, and memory-saving future work.
Give each captured per-step input length its own CUDA-graph annotation id so the 1-token decode and the 2-token speculative verify each replay an independent captured graph bound to their own static buffers (ORT CUDA EP keys captured graphs by gpu_graph_id).
- State: graph_ids_ map (id per captured length, lazily generated); State::Run(session, capture?, capture_length) selects the id.
- DecoderOnly_State::Run captures shape[1] in [1, max_graph_capture_length].
- GeneratorParams::max_graph_capture_length (default 1; MtpGenerator sets 2).
- Tensor::CreateTensor gains static_capacity_bytes to pre-size the static buffer to the max captured shape so its base address is stable across both graphs.
- input_ids/position_ids/logits/hidden_states static I/O pre-size to 2 tokens.
Fixes the graph-on output corruption (mixed graphed-1-token / eager-2-token shared static buffers) and recovers the eager-verify cost. With CUDA graph on, the in-engine og.MtpGenerator goes ~115 -> ~127 tok/s; remaining gap vs the (also graph-accelerated) baseline is host-side draft/verify orchestration.
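The per-length id selection can be sketched as follows. This is an illustration in Python of the bookkeeping only; the real implementation is C++ in State, and GraphIdRegistry, select, and the id numbering are hypothetical.

```python
import itertools


class GraphIdRegistry:
    """Hands out one annotation id per captured input length, lazily."""

    def __init__(self, max_graph_capture_length=1):
        self.max_graph_capture_length = max_graph_capture_length
        self._next_id = itertools.count(1)
        self._ids = {}

    def select(self, seq_len):
        # Lengths outside [1, max] are not captured and run eagerly.
        if not 1 <= seq_len <= self.max_graph_capture_length:
            return None
        if seq_len not in self._ids:
            self._ids[seq_len] = next(self._next_id)
        return self._ids[seq_len]


def static_capacity_bytes(max_capture_length, elements_per_token, bytes_per_element):
    """Size a static buffer for the largest captured shape so its address is stable."""
    return max_capture_length * elements_per_token * bytes_per_element


registry = GraphIdRegistry(max_graph_capture_length=2)
decode_id = registry.select(1)    # 1-token decode
verify_id = registry.select(2)    # 2-token speculative verify
prefill_id = registry.select(17)  # long prompt: not captured
```

Sizing every static buffer for the 2-token case means both captured graphs see the same base address, so replaying one never invalidates the other.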
The MTP draft/verify loop selected tokens on the host: cast the main model's [1,S,248320] logits to fp32, copy whole rows (~2MB each, up to 3/step) to CPU, std::max_element over 248K elements -- a synchronizing D2H plus a serial host scan on the critical path. Add DeviceInterface::ArgMax (k=1 top-1) backed by the existing high-performance genai CUDA Top-K (distributed_select_sort, unbeatable for k=1), so only the 4-byte token id crosses the bus. The virtual is appended at the END of DeviceInterface to keep the vtable layout / ABI stable. MtpGenerator uses it for both the batched verify rows (logits@L and logits@L+1 in one launch) and the MTP head draft; the host argmax remains as a fallback for devices without ArgMax. Result (fp16, 1xH200, batch1, greedy, CUDA graph ON): in-engine og.MtpGenerator ~127 -> ~160 tok/s (128 tok) / ~177 tok/s (200 tok) = 1.14x / 1.24x vs the graph-accelerated baseline (~140-143). Output coherent, accept unchanged (numerically identical argmax). This crosses MTP from break-even into a real wall-clock speedup.
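A rough size check of the transfer, using only figures quoted in the commit message (a 248320-entry vocabulary, fp32 logits on the host path, a 4-byte token id on the device path). The per-row and two-row totals are plain arithmetic, not measurements.

```python
VOCAB_SIZE = 248_320
FP32_BYTES = 4
TOKEN_ID_BYTES = 4

row_bytes = VOCAB_SIZE * FP32_BYTES       # one fp32 logits row
verify_bytes = 2 * row_bytes              # both rows of a 2-token verify
reduction = row_bytes // TOKEN_ID_BYTES   # bytes saved per selected token, as a ratio

print(f"one row:  {row_bytes:,} bytes")
print(f"two rows: {verify_bytes:,} bytes")
print(f"device argmax moves {reduction:,}x fewer bytes per token")
```

The byte count is only part of the saving; the commit message also attributes cost to the synchronizing copy and the serial host scan.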
On an accepted step the MTP head is run once on the accepted (hidden@L, d) pair only to keep its KV cache aligned -- the next committed token comes from the verify pass's row-1 argmax, not this draft. The draft's 248K-vocab argmax and its stream sync were pure waste on ~86% of steps. Add a need_draft flag to DraftNextToken; the post-accept call now runs KV-advance only (no GetLogits, no ArgMax, no sync). Output unchanged; removes a provably-unnecessary full-vocab argmax + stream sync per accepted step.
…rward
The accept path ran two MTP-head forwards per step: one to draft d from t, and a second single-token forward only to fold the accepted d into the head's KV. Those two feeds are sequential and consecutive -- (hidden@L, d) for the KV-advance and (hidden@L+1, t_next) for the next draft -- so fuse them into a single 2-token MTP forward (the MTP head is full-attention, so a 2-token forward is identical to two 1-token feeds). The next step's draft is computed ahead and stashed in pending_draft_; the next GenerateNextToken reuses it. On reject the pending draft is invalidated and the loop falls back to a single-token draft.
Made the MTP hidden_states INPUT feeder static-buffer aware (pre-sized to 2, mirroring the output) so the head's 1- and 2-token graph captures keep stable buffer addresses.
Removes one MTP forward per accepted step. Output bit-identical (same accept stats). In-engine og.MtpGenerator: ~1.18x -> ~1.26x (128 tok), ~1.20-1.25x (200 tok); speedup scales with accept rate.
…host)
The MTP head's embed_tokens and lm_head are bit-identical to the main model's (~2 GB of mtp.onnx's ~3.8 GB in fp16). After saving both ONNX files, the builder now redirects mtp.onnx's copies of those two initializers to the main model's external data file (model.onnx.data, at the main tensor's offset/length) and packs their bytes out of mtp.onnx.data, rebuilding it with tight offsets. Sharing applies only to byte-identical tensors (same name/dtype/shape + matching sampled bytes), so a quantized-main vs fp16-MTP lm_head mismatch is left untouched; any failure is non-fatal (models stay valid, just larger).
Measured (fp16): mtp.onnx.data 3.79 GB -> 1.76 GB (-2.03 GB disk). Both sessions resolve the two tensors from the same model.onnx.data, so the host mmaps a single copy (host-RAM dedup). Output bit-identical, throughput unchanged (~1.23x@128).
Scope: saves disk + host RAM. GPU VRAM is unchanged -- ORT's public API uploads each session's initializers to the device independently, so cross-session device sharing is not exposed.
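The sharing precondition (same name/dtype/shape plus matching sampled bytes) can be sketched as below. This is not the builder's code: TensorInfo, can_share, and the sampling parameters are hypothetical, and tensors are modelled as in-memory byte strings rather than external-data files.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TensorInfo:
    name: str
    dtype: str
    shape: tuple
    data: bytes


def can_share(main, mtp, samples=8, window=16):
    """True when metadata matches and evenly spaced byte windows are identical."""
    if (main.name, main.dtype, main.shape) != (mtp.name, mtp.dtype, mtp.shape):
        return False
    if len(main.data) != len(mtp.data):
        return False
    step = max(1, len(main.data) // samples)
    for start in range(0, len(main.data), step):
        if main.data[start:start + window] != mtp.data[start:start + window]:
            return False
    return True


payload = bytes(range(256)) * 4
main_lm_head = TensorInfo("lm_head.weight", "float16", (8, 64), payload)
mtp_lm_head = TensorInfo("lm_head.weight", "float16", (8, 64), payload)
quantized_lm_head = TensorInfo("lm_head.weight", "uint8", (8, 64), payload)
changed_lm_head = TensorInfo("lm_head.weight", "float16", (8, 64), b"\xff" + payload[1:])

share_identical = can_share(main_lm_head, mtp_lm_head)
share_quantized = can_share(main_lm_head, quantized_lm_head)
share_changed = can_share(main_lm_head, changed_lm_head)
```

Sampling keeps the check cheap on multi-gigabyte tensors, at the cost that a difference outside the sampled windows would go unnoticed in this sketch.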
| super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options) | ||
| self.model_type = "Qwen3_5_Moe_textForCausalLM" |
Pull request overview
This PR adds end-to-end support for Qwen3.6 “MTP” (multi-token prediction) self-speculative decoding across the model builder, runtime (C/C++), and Python bindings, alongside CUDA QMoE expert-weight quantization fixes/guards and CUDA-graph capture improvements for multi-shape replay.
Changes:
- Add an in-engine MTP draft/verify generator (C API + C++ wrapper + Python bindings) and the required runtime primitives (hidden-states I/O staging, recurrent-state snapshot/restore, on-device argmax).
- Extend model builder/export for Qwen3.6 MTP head generation (`mtp.onnx`) and config emission (`genai_config.json` `mtp` section + hidden_states wiring), plus weight dedup optimization.
- Add CUDA QMoE quantization path selection (CUTLASS-prepacked vs raw) with regression tests and safer cache-dir cleanup.
Reviewed changes
Copilot reviewed 33 out of 33 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
| test/python/models/test_qmoe_weights.py | Adds regression tests for CUDA QMoE quantization dispatch/validation and signed scales behavior. |
| src/tensor.h | Extends Tensor::CreateTensor API with optional static-capacity hint for CUDA-graph-stable buffers. |
| src/tensor.cpp | Implements static buffer pre-sizing via static_capacity_bytes to keep addresses stable across captures. |
| src/smartptrs.h | Adds DeviceInterface::ArgMax hook (default false) for on-device argmax. |
| src/python/python.cpp | Exposes generator snapshot/hidden-states staging and adds MtpGenerator Python binding. |
| src/python/py/models/README.md | Documents builder options to export Qwen3.6 MTP head. |
| src/python/py/models/builders/qwen.py | Implements Qwen3.6 MTP head export, config augmentation, and embedding/lm_head external-data sharing. |
| src/python/py/models/builders/base.py | Adds QMoE weights_prepacked tri-state plumbing and CUDA-specific quantization paths; improves cache-dir cleanup guard. |
| src/ort_genai.h | Adds C++ wrapper methods for snapshot/hidden-states and a new OgaMtpGenerator wrapper. |
| src/ort_genai_c.h | Introduces C API surface for MTP generator + snapshot/hidden-states APIs. |
| src/ort_genai_c.cpp | Implements the new C API entrypoints and wires them to the C++ runtime. |
| src/mtp_generator.h | Declares the in-engine MTP draft/verify orchestrator. |
| src/mtp_generator.cpp | Implements MTP draft/verify loop, including on-device argmax usage and pipelined draft fusion. |
| src/models/recurrent_state.h | Adds snapshot/restore APIs for recurrent state rollback in speculative decoding. |
| src/models/recurrent_state.cpp | Implements recurrent-state snapshot/restore and updates partial rewind behavior/messages. |
| src/models/position_inputs.cpp | Pre-sizes static position_ids buffers to max captured length for multi-shape CUDA graphs. |
| src/models/model.h | Adds SnapshotState/SetHiddenStates hooks and multi-length CUDA graph-id support. |
| src/models/model.cpp | Implements per-length CUDA graph annotation IDs and updates State::Run signature. |
| src/models/logits.cpp | Pre-sizes static logits buffers for multi-length CUDA-graph capture. |
| src/models/input_ids.cpp | Pre-sizes static input_ids (and int64 cast buffer) for multi-length CUDA-graph capture. |
| src/models/hidden_states_inputs.h | Introduces managed hidden_states input/output helpers for MTP head staging and CUDA-graph-safe outputs. |
| src/models/hidden_states_inputs.cpp | Implements hidden_states staging with D2D/CPU fallback and graph-capture static buffer sizing. |
| src/models/decoder_only.h | Wires hidden_states input/output helpers into decoder-only state. |
| src/models/decoder_only.cpp | Enables graph capture for lengths up to max_graph_capture_length and updates hidden_states I/O each step. |
| src/generators.h | Adds max_graph_capture_length and generator methods for snapshot/hidden-states staging. |
| src/generators.cpp | Implements Generator::SnapshotState and Generator::SetHiddenStates plumbing. |
| src/cuda/interface.cpp | Implements CUDA ArgMax via Top-K (k=1) with cached scratch + pinned host staging. |
| src/config.h | Adds hidden_states names and new model.mtp config section schema. |
| src/config.cpp | Adds JSON parsing for decoder hidden_states and the mtp section. |
| qwen_3.6.md | Adds a detailed optimization log / design and benchmark notes for Qwen3.6 + MTP. |
| examples/python/README.md | Documents the new Qwen3.6 MTP example script. |
| examples/python/qwen-3.6-mtp.py | Adds runnable Python example for built-in and reference MTP decoding. |
| examples/python/qwen-3.6-mtp.md | Adds detailed design/export/runtime documentation for Qwen3.6 MTP in GenAI. |
Comment on lines +307 to +313
| # For CUDA QMoE the builder ships expert weights already CUTLASS-prepacked | ||
| # (offline via pack_weights_for_cuda_mixed_gemm, see make_qmoe_weights), so | ||
| # the QMoE op's default interpretation (weights_prepacked=-1/auto = | ||
| # prepacked) is exactly what we want and the attribute is omitted (None). | ||
| # Override via extra_options["qmoe_weights_prepacked"] (e.g. 0 to ship raw | ||
| # [E, N, K/pack] weights and let the runtime PrePack hook transform them). | ||
| weights_prepacked = int(extra_options["qmoe_weights_prepacked"]) if "qmoe_weights_prepacked" in extra_options else None |
Comment on lines +2188 to +2202
| # MTP (multi-token prediction) self-speculative head. | ||
| # When ``enable_mtp`` is set, an auxiliary ``mtp.onnx`` model is exported | ||
| # alongside the main model (see ``Qwen35MtpHead``). It is disabled for the | ||
| # MTP head itself (``is_mtp_head``) to avoid infinite recursion. | ||
| self.mtp_head = None | ||
| self.enable_mtp = bool(extra_options.get("enable_mtp", False)) and not getattr(self, "is_mtp_head", False) | ||
| if self.enable_mtp: | ||
| # Stash the constructor arguments so the MTP head can be built from a | ||
| # pristine config after the main model has been generated. | ||
| self._mtp_config = copy.deepcopy(config) | ||
| self._mtp_io_dtype = io_dtype | ||
| self._mtp_onnx_dtype = onnx_dtype | ||
| self._mtp_ep = ep | ||
| self._mtp_cache_dir = cache_dir | ||
| self._mtp_extra_options = copy.deepcopy(extra_options) |
Comment on lines +27 to +39
| MtpGenerator::MtpGenerator(const Model& main_model, const Model& mtp_model, const GeneratorParams& params) | ||
| : main_model_{main_model}, mtp_model_{mtp_model} { | ||
| // MTP runs both a 1-token decode and a 2-token verify on the main model. Allow CUDA graph | ||
| // capture of both shapes (each captured under its own annotation id with pre-sized static | ||
| // buffers). Harmless for the MTP head, which only ever runs a single token per step. | ||
| const_cast<GeneratorParams&>(params).max_graph_capture_length = 2; | ||
| main_ = CreateGenerator(main_model_, params); | ||
| mtp_ = CreateGenerator(mtp_model_, params); | ||
| hidden_size_ = main_model_.config_->model.decoder.hidden_size; | ||
| vocab_size_ = main_model_.config_->model.vocab_size; | ||
| max_length_ = params.search.max_length; |
No description provided.