Qwen3.6 MTP #2218

Draft
tianleiwu wants to merge 18 commits into main from tlwu/20260610/qwen_3.6_mtp

Conversation

@tianleiwu

Contributor

No description provided.

tianleiwu added 18 commits June 9, 2026 16:54
The CUDA QMoE expert weights produced by the model builder were not in the
layout/encoding the fpA_intB mixed-GEMM kernel consumes, so INT4 exports of
Qwen3.5/3.6-MoE generated incoherent output while fp16 was correct.

base.py:
- Add _cutlass_prepacked_blockwise_quantize: quantize each expert weight with
  ORT's blockwise quantize_matmul_4bits (on the transposed [K, N] weight),
  keep the SIGNED scales, and offline CUTLASS-prepack via
  pack_weights_for_cuda_mixed_gemm (force_arch=80, which the kernel expects for
  all SM >= 80). This is the encoding validated by the com.microsoft QMoE CUDA
  parity tests. Taking abs() of the scales (as the previous path did) corrupts
  every block whose anchor element is negative and yields garbage weights.
- make_qmoe_weights: route CUDA QMoE through the new prepacked path and assert
  block_size is 64 or 128 (the only sizes the CUDA kernel supports).
- Plumb a tri-state weights_prepacked QMoE attribute (default None = omit =
  kernel's prepacked default; override via extra_options qmoe_weights_prepacked).
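
To illustrate why the sign of the scale matters, here is a minimal sketch of symmetric 4-bit block quantization. The convention (scale = anchor / 7, where the anchor is the largest-magnitude element of the block) and the helper names are assumptions chosen for illustration; the exact encoding produced by quantize_matmul_4bits may differ. Under this convention a block with a negative anchor gets a negative scale, and dequantizing with abs(scale) flips the sign of every element in that block.

```python
def quantize_block(block):
    """Symmetric 4-bit quantization of one block (illustrative convention only)."""
    anchor = max(block, key=abs)                  # largest-magnitude element, sign kept
    scale = anchor / 7.0 if anchor != 0.0 else 1.0
    q = [max(-7, min(7, round(w / scale))) for w in block]
    return q, scale


def dequantize_block(q, scale):
    return [v * scale for v in q]


block = [-0.7, 0.3, 0.1, -0.2]                    # anchor is -0.7, so the scale is negative
q, scale = quantize_block(block)
good = dequantize_block(q, scale)                 # signed scale: close to the original
bad = dequantize_block(q, abs(scale))             # abs(scale): every sign is flipped
print(good)
print(bad)
```

Blocks whose anchor is positive are unaffected by abs() in this sketch, which is why such a bug can pass casual spot checks.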

qwen.py:
- Exclude the MoE router and shared-expert gate MatMuls from INT4/INT8
  quantization; 4-bit rounding of these tiny routing matmuls flips top-k expert
  selection and injects large error into every MoE layer.
- Fix a node-name collision in the shared expert (rename the gated Mul to
  .../gate/Mul) that produced a duplicate value name and a ShapeInferenceError.
- make_qmoe_weights: treat weights_prepacked in {None, 1} as prepacked so an
  explicit qmoe_weights_prepacked=1 still produces CUTLASS-prepacked weights
  (previously only None did, while make_qmoe_op emitted weights_prepacked=1 on
  raw weights).
- Replace the block_size assert with a ValueError (asserts are stripped by
  python -O) and apply it to all CUDA QMoE paths.
- Gate the raw (weights_prepacked=0) path and the emitted weights_prepacked
  attribute on the CUDA EP.
- Fix the misleading 'ships raw weights' comment and document that the
  blockwise scales are SIGNED (abs() reintroduces the garbage-output bug).
- Add test/python/models/test_qmoe_weights.py covering path dispatch, EP
  gating, block-size validation, and the signed-scale regression guard.
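
The router exclusion above can be motivated with a toy example. The sketch below uses a hypothetical 3-expert router over a 2-dimensional hidden state and an assumed per-row symmetric 4-bit rounding; none of the names or numbers come from the builder. It shows how rounding error that is small per weight can still change which experts win top-k when two logits are close.

```python
def quantize_row(row, qmax=7):
    """Round a weight row to symmetric 4-bit levels and dequantize it again."""
    scale = max(abs(w) for w in row) / qmax
    return [round(w / scale) * scale for w in row]


def router_logits(weights, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]


def top_k(logits, k):
    return sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]


# Hypothetical router weights: experts 0 and 1 have nearly equal logits.
router = [[0.50, 0.17], [0.40, 0.26], [0.90, 0.90]]
x = [1.0, 1.0]

exact = top_k(router_logits(router, x), k=2)
quantized = top_k(router_logits([quantize_row(r) for r in router], x), k=2)
print(exact, quantized)
```

With full-precision weights the second selected expert is 0; after rounding it becomes 1, so a different expert processes the token.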
Add a Qwen35MtpHead builder that emits a separate mtp.onnx (one full-attention
+ MoE decoder layer + fc + pre/post RMSNorms) for multi-token-prediction
self-speculative decoding, gated behind --extra_options enable_mtp=true. It
reuses the parent Qwen35MoeTextModel machinery and loads the mtp.* + shared
embedding/lm_head weights directly from safetensors (HF transformers discards
mtp.* on load). The builder also emits the 'mtp' section (and the decoder's
hidden_states output) into genai_config.json.

Validated: the exported fp16 mtp.onnx reproduces 88.3% greedy acceptance,
bit-identical to the PyTorch reference.

Fixes: the MTP head must not inherit include_hidden_states/exclude_lm_head
(would make the final-norm output double as a graph output and feed the
lm_head, creating a graph cycle); guard the shared save_model cache-dir
cleanup so a multi-model builder does not fail on the second save.
Add Config::Model::Mtp (filename, session/run options, I/O names) plus JSON
parsing in config.cpp, mirroring the existing Encoder/Decoder pattern. Also add
a hidden_states field to the decoder outputs so a model exported with
include_hidden_states exposes the last hidden state (consumed by the MTP head).
Backward compatible: configs without an 'mtp' section parse unchanged.
Hybrid (GatedDeltaNet) recurrent state cannot be partially cropped like the
attention KV cache, so a rejected speculative draft cannot be rolled back with
the existing RewindTo. Add RecurrentState::Snapshot()/RestoreSnapshot() that
copy the conv + linear-attention state buffers out and back in place (stable
addresses for CUDA-graph replay), and make RewindTo(index != 0) restore from
the snapshot. A 'rewind by 1' then rolls back both the KV cache (crop) and the
recurrent state (snapshot restore).
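
The difference between cropping and snapshot restore can be sketched with a toy state object. ToyHybridState and its methods are hypothetical stand-ins, not the genai RecurrentState API; the point is only that the KV cache can be truncated, while the recurrent buffer has to be copied out beforehand and copied back in place so its identity (standing in for a stable device address) never changes.

```python
class ToyHybridState:
    """Toy stand-in for a hybrid model's decoding state (not the genai classes)."""

    def __init__(self, size):
        self.kv_cache = []               # attention KV: one entry per token, croppable
        self.recurrent = [0.0] * size    # recurrent state: overwritten on every step
        self._snapshot = [0.0] * size    # preallocated so no buffer is ever replaced
        self._snapshot_len = None

    def step(self, token):
        self.kv_cache.append(token)
        for i in range(len(self.recurrent)):
            self.recurrent[i] = 0.5 * self.recurrent[i] + token + i

    def snapshot(self):
        self._snapshot[:] = self.recurrent           # copy out
        self._snapshot_len = len(self.kv_cache)

    def rewind_to(self, length):
        if length != self._snapshot_len:
            raise ValueError("recurrent state can only return to the snapshot point")
        del self.kv_cache[length:]                   # KV cache: crop
        self.recurrent[:] = self._snapshot           # recurrent state: restore in place


state = ToyHybridState(4)
state.step(3)
buffer_id = id(state.recurrent)
state.snapshot()
expected = list(state.recurrent)
state.step(9)         # speculative draft token
state.rewind_to(1)    # draft rejected: roll back both kinds of state
```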

Expose the snapshot through every layer: State::SnapshotState (virtual) +
DecoderOnly_State override, Generator::SnapshotState, OgaGenerator_SnapshotState
(C API), OgaGenerator::SnapshotState (C++ wrapper), and snapshot_state (Python).

Validated end to end on the real genai engine: a draft/verify MTP loop (main
model via genai, mtp.onnx via onnxruntime) reaches 84.5% accept / 1.57x
effective tokens-per-main-forward and matches plain greedy exactly through
draft rejections, confirming the recurrent rollback is correct.
Let mtp.onnx run as a standalone genai model so the MTP draft executes
in-engine (real genai kernels + KV cache) instead of a separate session.

Add a HiddenStatesInputs feeder that owns a resizable [batch, seq, hidden]
device tensor and is refreshed each step from a caller-staged value (the main
model's last hidden state). DecoderOnly_State creates it only when
config.model.decoder.inputs.hidden_states is set, so normal models are
unaffected. Expose staging through Generator::SetHiddenStates ->
OgaGenerator_SetHiddenStates (C API) -> OgaGenerator::SetHiddenStates (C++)
-> set_hidden_states (Python).

Verified: mtp.onnx loaded as a genai og.Model with set_hidden_states produces
draft logits matching the raw onnxruntime reference (identical argmax and
top-5; max abs diff 0.012, expected fp16 kernel variance).
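
A check of this kind can be sketched as follows. The helper names and the sample values are made up for illustration; this is not the script used for the verification above.

```python
def top_k(logits, k):
    return sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]


def compare_logits(candidate, reference, k=5):
    """Summarize agreement between two logit rows of equal length."""
    return {
        "same_argmax": top_k(candidate, 1) == top_k(reference, 1),
        "same_top_k": top_k(candidate, k) == top_k(reference, k),
        "max_abs_diff": max(abs(a - b) for a, b in zip(candidate, reference)),
    }


# Made-up reference row and a slightly perturbed copy standing in for fp16 noise.
reference = [0.10, 2.50, -1.00, 4.00, 3.20, 0.70, 1.90]
candidate = [v + (0.01 if i % 2 else -0.01) for i, v in enumerate(reference)]
report = compare_logits(candidate, reference)
print(report)
```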
Add examples/python/qwen-3.6-mtp.py: a consolidated, reusable MtpGenerator that
drives the draft/verify loop entirely through the genai Python API (both the
main model and the MTP head as og.Model, using set_hidden_states / snapshot_state
/ rewind_to). Add examples/python/qwen-3.6-mtp.md documenting the MTP head
architecture, export (enable_mtp + include_hidden_states), config layout, the
draft/verify algorithm, the hybrid-model recurrent-state rollback problem, the
API additions, and measured results. Link the example from the examples README.

Verified end to end: ~88% accept, ~1.67 tokens/forward, output matching greedy.
The hidden_states output was previously only retrievable as an ORT-allocated
extra output, which fails under CUDA graph (capture binds outputs to static
buffers): get_output("hidden_states") returned an unconstructed tensor.

Add a managed HiddenStatesOutputs that owns the output tensor and makes it
static on single-token (graph-captured) steps, mirroring Logits. Created by
DecoderOnly_State only when the decoder emits a hidden_states output, so models
without one are unaffected. With this, get_output("hidden_states") works with
CUDA graph enabled (verified: plain greedy decode stays coherent and the hidden
state is retrievable), which is a prerequisite for accelerating MTP with CUDA
graph -- isolated latency shows the MTP draft step drops 18.4ms -> 0.52ms once
graph-captured.
The MTP hidden_states feeder previously always staged through host memory
(memcpy to a pinned CPU buffer + async H2D copy). When the source hidden state
is already on the model's device (the in-engine case, where the main model's
hidden_states output is consumed by the MTP head), copy device-to-device on the
shared CUDA stream instead. All genai CUDA sessions share one compute stream, so
the enqueued D2D copy is correctly ordered after the producer's Run and before
this model's Run with no host round-trip and no host synchronization (cf.
onnxruntime issue #28539 on the async IO-binding pattern). Falls back to the
host-staged copy only when the source lives on the CPU.
Add a Synchronization and CUDA graph section to the Qwen3.6 MTP design doc:
the shared single compute stream that makes the two-model handoff sync-free,
the device-to-device hidden_states feed, the static-output requirement under
graph capture, and why graph-capturing the verify step (plus an in-engine
generator) is the remaining lever for a wall-clock win. References onnxruntime
issues #28539 (async IO binding) and #28686 (async CUDA-graph replay).
Add a first-class MtpGenerator that runs the Qwen3.6 MTP self-speculative
draft/verify loop entirely in C++, composing a main-model generator and an MTP
draft-head generator on the shared compute stream. The main model's last hidden
state is handed to the MTP head device-to-device (no host round-trip): the C++
State::GetOutput returns the on-device OrtValue, which is sliced to the last
position and fed via the hidden_states input feeder. Verification runs [t, d] in
a single main forward; on reject the recurrent state is restored from a snapshot
and the correct token re-run.
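
The control flow of the draft/verify loop can be sketched with toy models. Everything below is an assumption made for illustration: main_next and draft_next are deterministic stand-ins for the main model and the MTP head, the two calls inside the loop model the two rows of a single [t, d] forward, and a rejection is assumed to cost one extra main forward for the re-run. It is not the C++ MtpGenerator.

```python
VOCAB = 50


def main_next(prefix):
    """Toy main model: greedy next token as a deterministic function of the prefix."""
    return (7 * prefix[-1] + 3 * len(prefix)) % VOCAB


def draft_next(prefix):
    """Toy MTP head: agrees with the main model except at every fifth position."""
    guess = main_next(prefix)
    return (guess + 1) % VOCAB if len(prefix) % 5 == 0 else guess


def greedy(prompt, n_new):
    seq = list(prompt)
    for _ in range(n_new):
        seq.append(main_next(seq))
    return seq


def mtp_generate(prompt, n_new):
    seq = list(prompt)
    seq.append(main_next(seq))       # first token t from a plain main forward
    forwards = 1
    while len(seq) - len(prompt) < n_new:
        d = draft_next(seq)          # draft the token that should follow t
        # One main forward over [t, d] yields both predictions below.
        after_t = main_next(seq)
        after_d = main_next(seq + [d])
        forwards += 1
        if after_t == d:
            seq += [d, after_d]      # accept: two tokens from one main forward
        else:
            seq.append(after_t)      # reject: keep the main model's own token
            forwards += 1            # assumed cost of the re-run after rollback
    return seq[: len(prompt) + n_new], forwards


prompt = [1, 2, 3]
tokens, forwards = mtp_generate(prompt, 20)
print(tokens == greedy(prompt, 20), forwards)
```

Because every committed token is the main model's own greedy choice, the output matches plain greedy decoding regardless of how often the draft is wrong; only the number of main forwards changes.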

Expose it through the C API (OgaCreateMtpGenerator, OgaMtpGenerator_AppendTokens
/ GenerateNextToken / IsDone / GetSequence{Count,Data} / Get{Forward,Accept,
Trial}Count / OgaDestroyMtpGenerator), the C++ wrapper, and Python
(og.MtpGenerator).

Verified: output matches plain greedy decoding (one prompt bit-exact through
draft accept/reject; others match modulo fp16 near-ties), ~84-94% accept,
1.6-1.8 tokens per main forward.

Fix: ArgmaxLogitsRow kept a CopyDeviceToCpu result from a temporary device span
whose pinned host buffer was freed at end of expression (dangling read); name
the span so it outlives the read.
…enchmarks

- examples/python/qwen-3.6-mtp.py: default to the built-in og.MtpGenerator; keep
  the pure-Python loop as ReferenceMtpGenerator behind --reference.
- examples/python/qwen-3.6-mtp.md: document og.MtpGenerator in the API table and
  run section; record wall-clock results (Python ref ~0.3x, in-engine ~1.0x
  break-even); explain the eager 2-token verify and the multi-gpu_graph_id fix;
  expand future work with embedding/lm_head memory sharing.
- qwen_3.6.md: add an 'MTP self-speculative decoding — implemented' experiment
  log with the shipped pieces, correctness, the wall-clock benchmark, per-step
  latency, the verify-shape graph-capture lever, and memory-saving future work.
Give each captured per-step input length its own CUDA-graph annotation id so the
1-token decode and the 2-token speculative verify each replay an independent
captured graph bound to their own static buffers (ORT CUDA EP keys captured
graphs by gpu_graph_id).

- State: graph_ids_ map (id per captured length, lazily generated);
  State::Run(session, capture?, capture_length) selects the id.
- DecoderOnly_State::Run captures shape[1] in [1, max_graph_capture_length].
- GeneratorParams::max_graph_capture_length (default 1; MtpGenerator sets 2).
- Tensor::CreateTensor gains static_capacity_bytes to pre-size the static buffer
  to the max captured shape so its base address is stable across both graphs.
- input_ids/position_ids/logits/hidden_states static I/O pre-size to 2 tokens.
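
The per-length id selection can be sketched as a small lookup table. GraphIdTable and id_for are hypothetical names; the sketch only mirrors the behaviour described above (one lazily generated id per captured length, eager execution outside the capture range) and does not call into ORT.

```python
import itertools


class GraphIdTable:
    """Toy per-length CUDA-graph id table (illustrative, not the genai State class)."""

    _next_id = itertools.count(1)    # shared counter so ids stay unique across tables

    def __init__(self, max_graph_capture_length=1):
        self.max_graph_capture_length = max_graph_capture_length
        self._ids = {}

    def id_for(self, seq_len):
        """Return the annotation id for a captured length, or None to run eagerly."""
        if not 1 <= seq_len <= self.max_graph_capture_length:
            return None
        if seq_len not in self._ids:
            self._ids[seq_len] = next(GraphIdTable._next_id)   # lazily generated
        return self._ids[seq_len]


table = GraphIdTable(max_graph_capture_length=2)
decode_id = table.id_for(1)    # 1-token decode
verify_id = table.id_for(2)    # 2-token speculative verify
print(decode_id, verify_id, table.id_for(3))
```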

Fixes the graph-on output corruption (mixed graphed-1-token / eager-2-token
shared static buffers) and recovers the eager-verify cost. With CUDA graph on,
the in-engine og.MtpGenerator goes ~115 -> ~127 tok/s; remaining gap vs the
(also graph-accelerated) baseline is host-side draft/verify orchestration.
The MTP draft/verify loop selected tokens on the host: cast the main model's
[1,S,248320] logits to fp32, copy whole rows (~2MB each, up to 3/step) to CPU,
std::max_element over 248K elements -- a synchronizing D2H plus a serial host
scan on the critical path.

Add DeviceInterface::ArgMax (k=1 top-1) backed by the existing high-performance
genai CUDA Top-K (distributed_select_sort, which is well suited to k=1), so only the
4-byte token id crosses the bus. The virtual is appended at the END of
DeviceInterface to keep the vtable layout / ABI stable. MtpGenerator uses it for
both the batched verify rows (logits@L and logits@L+1 in one launch) and the MTP
head draft; the host argmax remains as a fallback for devices without ArgMax.

Result (fp16, 1xH200, batch1, greedy, CUDA graph ON): in-engine og.MtpGenerator
~127 -> ~160 tok/s (128 tok) / ~177 tok/s (200 tok) = 1.14x / 1.24x vs the
graph-accelerated baseline (~140-143). Output coherent, accept unchanged
(numerically identical argmax). This crosses MTP from break-even into a real
wall-clock speedup.
On an accepted step the MTP head is run once on the accepted (hidden@L, d) pair
only to keep its KV cache aligned -- the next committed token comes from the
verify pass's row-1 argmax, not this draft. The draft's 248K-vocab argmax and
its stream sync were pure waste on ~86% of steps. Add a need_draft flag to
DraftNextToken; the post-accept call now runs KV-advance only (no GetLogits, no
ArgMax, no sync). Output unchanged; removes a provably-unnecessary full-vocab
argmax + stream sync per accepted step.
…rward

The accept path ran two MTP-head forwards per step: one to draft d from t, and a
second single-token forward only to fold the accepted d into the head's KV. Those
two feeds are sequential and consecutive -- (hidden@L, d) for the KV-advance and
(hidden@L+1, t_next) for the next draft -- so fuse them into a single 2-token MTP
forward (the MTP head is full-attention, so a 2-token forward is identical to two
1-token feeds). The next step's draft is computed ahead and stashed in
pending_draft_; the next GenerateNextToken reuses it. On reject the pending draft
is invalidated and the loop falls back to a single-token draft.

Made the MTP hidden_states INPUT feeder static-buffer aware (pre-sized to 2,
mirroring the output) so the head's 1- and 2-token graph captures keep stable
buffer addresses.

Removes one MTP forward per accepted step. Output bit-identical (same accept
stats). In-engine og.MtpGenerator: ~1.18x -> ~1.26x (128 tok), ~1.20-1.25x
(200 tok); speedup scales with accept rate.
…host)

The MTP head's embed_tokens and lm_head are bit-identical to the main model's
(~2 GB of mtp.onnx's ~3.8 GB in fp16). After saving both ONNX files, the builder
now redirects mtp.onnx's copies of those two initializers to the main model's
external data file (model.onnx.data, at the main tensor's offset/length) and
packs their bytes out of mtp.onnx.data, rebuilding it with tight offsets.

Sharing applies only to byte-identical tensors (same name/dtype/shape + matching
sampled bytes), so a quantized-main vs fp16-MTP lm_head mismatch is left
untouched; any failure is non-fatal (models stay valid, just larger).
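
A sharing check along these lines might look like the sketch below. TensorInfo, sampled_bytes, can_share and the sampling parameters are all hypothetical, and the tensor name is only an example; the builder's actual sampling scheme is not shown in this PR description. Note that a sampled comparison is a heuristic and can miss differences that fall between the sampled windows.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TensorInfo:
    name: str
    dtype: str
    shape: tuple
    data: bytes


def sampled_bytes(data, samples=64, window=16):
    """Collect short windows at evenly spaced offsets instead of reading everything."""
    if len(data) <= samples * window:
        return data                                   # small tensor: compare it all
    stride = (len(data) - window) // (samples - 1)
    return b"".join(data[i * stride : i * stride + window] for i in range(samples))


def can_share(a, b):
    """True when metadata matches and the sampled bytes agree."""
    same_meta = (a.name, a.dtype, a.shape, len(a.data)) == (b.name, b.dtype, b.shape, len(b.data))
    return same_meta and sampled_bytes(a.data) == sampled_bytes(b.data)


main = TensorInfo("lm_head.weight", "float16", (8, 4), bytes(range(64)))
mtp_same = TensorInfo("lm_head.weight", "float16", (8, 4), bytes(range(64)))
mtp_other_dtype = TensorInfo("lm_head.weight", "uint8", (8, 4), bytes(range(64)))
print(can_share(main, mtp_same), can_share(main, mtp_other_dtype))
```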

Measured (fp16): mtp.onnx.data 3.79 GB -> 1.76 GB (-2.03 GB disk). Both sessions
resolve the two tensors from the same model.onnx.data, so the host mmaps a single
copy (host-RAM dedup). Output bit-identical, throughput unchanged (~1.23x@128).

Scope: saves disk + host RAM. GPU VRAM is unchanged -- ORT's public API uploads
each session's initializers to the device independently, so cross-session device
sharing is not exposed.
@tianleiwu tianleiwu requested a review from a team as a code owner June 11, 2026 16:04
Copilot AI review requested due to automatic review settings June 11, 2026 16:04
@tianleiwu tianleiwu marked this pull request as draft June 11, 2026 16:04

super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)

self.model_type = "Qwen3_5_Moe_textForCausalLM"

Copilot AI left a comment

Contributor


Pull request overview

This PR adds end-to-end support for Qwen3.6 “MTP” (multi-token prediction) self-speculative decoding across the model builder, runtime (C/C++), and Python bindings, alongside CUDA QMoE expert-weight quantization fixes/guards and CUDA-graph capture improvements for multi-shape replay.

Changes:

  • Add an in-engine MTP draft/verify generator (C API + C++ wrapper + Python bindings) and the required runtime primitives (hidden-states I/O staging, recurrent-state snapshot/restore, on-device argmax).
  • Extend model builder/export for Qwen3.6 MTP head generation (mtp.onnx) and config emission (genai_config.json mtp section + hidden_states wiring), plus weight dedup optimization.
  • Add CUDA QMoE quantization path selection (CUTLASS-prepacked vs raw) with regression tests and safer cache-dir cleanup.

Reviewed changes

Copilot reviewed 33 out of 33 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| test/python/models/test_qmoe_weights.py | Adds regression tests for CUDA QMoE quantization dispatch/validation and signed scales behavior. |
| src/tensor.h | Extends Tensor::CreateTensor API with optional static-capacity hint for CUDA-graph-stable buffers. |
| src/tensor.cpp | Implements static buffer pre-sizing via static_capacity_bytes to keep addresses stable across captures. |
| src/smartptrs.h | Adds DeviceInterface::ArgMax hook (default false) for on-device argmax. |
| src/python/python.cpp | Exposes generator snapshot/hidden-states staging and adds MtpGenerator Python binding. |
| src/python/py/models/README.md | Documents builder options to export Qwen3.6 MTP head. |
| src/python/py/models/builders/qwen.py | Implements Qwen3.6 MTP head export, config augmentation, and embedding/lm_head external-data sharing. |
| src/python/py/models/builders/base.py | Adds QMoE weights_prepacked tri-state plumbing and CUDA-specific quantization paths; improves cache-dir cleanup guard. |
| src/ort_genai.h | Adds C++ wrapper methods for snapshot/hidden-states and a new OgaMtpGenerator wrapper. |
| src/ort_genai_c.h | Introduces C API surface for MTP generator + snapshot/hidden-states APIs. |
| src/ort_genai_c.cpp | Implements the new C API entrypoints and wires them to the C++ runtime. |
| src/mtp_generator.h | Declares the in-engine MTP draft/verify orchestrator. |
| src/mtp_generator.cpp | Implements MTP draft/verify loop, including on-device argmax usage and pipelined draft fusion. |
| src/models/recurrent_state.h | Adds snapshot/restore APIs for recurrent state rollback in speculative decoding. |
| src/models/recurrent_state.cpp | Implements recurrent-state snapshot/restore and updates partial rewind behavior/messages. |
| src/models/position_inputs.cpp | Pre-sizes static position_ids buffers to max captured length for multi-shape CUDA graphs. |
| src/models/model.h | Adds SnapshotState/SetHiddenStates hooks and multi-length CUDA graph-id support. |
| src/models/model.cpp | Implements per-length CUDA graph annotation IDs and updates State::Run signature. |
| src/models/logits.cpp | Pre-sizes static logits buffers for multi-length CUDA-graph capture. |
| src/models/input_ids.cpp | Pre-sizes static input_ids (and int64 cast buffer) for multi-length CUDA-graph capture. |
| src/models/hidden_states_inputs.h | Introduces managed hidden_states input/output helpers for MTP head staging and CUDA-graph-safe outputs. |
| src/models/hidden_states_inputs.cpp | Implements hidden_states staging with D2D/CPU fallback and graph-capture static buffer sizing. |
| src/models/decoder_only.h | Wires hidden_states input/output helpers into decoder-only state. |
| src/models/decoder_only.cpp | Enables graph capture for lengths up to max_graph_capture_length and updates hidden_states I/O each step. |
| src/generators.h | Adds max_graph_capture_length and generator methods for snapshot/hidden-states staging. |
| src/generators.cpp | Implements Generator::SnapshotState and Generator::SetHiddenStates plumbing. |
| src/cuda/interface.cpp | Implements CUDA ArgMax via Top-K (k=1) with cached scratch + pinned host staging. |
| src/config.h | Adds hidden_states names and new model.mtp config section schema. |
| src/config.cpp | Adds JSON parsing for decoder hidden_states and the mtp section. |
| qwen_3.6.md | Adds a detailed optimization log / design and benchmark notes for Qwen3.6 + MTP. |
| examples/python/README.md | Documents the new Qwen3.6 MTP example script. |
| examples/python/qwen-3.6-mtp.py | Adds runnable Python example for built-in and reference MTP decoding. |
| examples/python/qwen-3.6-mtp.md | Adds detailed design/export/runtime documentation for Qwen3.6 MTP in GenAI. |


Comment on lines +307 to +313
# For CUDA QMoE the builder ships expert weights already CUTLASS-prepacked
# (offline via pack_weights_for_cuda_mixed_gemm, see make_qmoe_weights), so
# the QMoE op's default interpretation (weights_prepacked=-1/auto =
# prepacked) is exactly what we want and the attribute is omitted (None).
# Override via extra_options["qmoe_weights_prepacked"] (e.g. 0 to ship raw
# [E, N, K/pack] weights and let the runtime PrePack hook transform them).
weights_prepacked = int(extra_options["qmoe_weights_prepacked"]) if "qmoe_weights_prepacked" in extra_options else None
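
The tri-state handling described in the commit messages can be sketched in isolation. The helper names below are hypothetical and the functions are a simplified reading of the behaviour described in this PR (None or 1 selects the prepacked path, 0 selects raw weights, and an unsupported block size raises ValueError rather than relying on assert); they are not the builder's code.

```python
def resolve_weights_prepacked(extra_options):
    """Return None (omit the attribute) or the integer given in extra_options."""
    if "qmoe_weights_prepacked" not in extra_options:
        return None
    return int(extra_options["qmoe_weights_prepacked"])


def uses_prepacked_path(weights_prepacked):
    """None (kernel default) and an explicit 1 both mean CUTLASS-prepacked weights."""
    return weights_prepacked in (None, 1)


def check_block_size(block_size):
    """Raise instead of assert so the check survives python -O."""
    if block_size not in (64, 128):
        raise ValueError(f"CUDA QMoE supports block_size 64 or 128, got {block_size}")


print(resolve_weights_prepacked({}))                               # attribute omitted
print(uses_prepacked_path(resolve_weights_prepacked({"qmoe_weights_prepacked": "1"})))
print(uses_prepacked_path(resolve_weights_prepacked({"qmoe_weights_prepacked": "0"})))
```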
Comment on lines +2188 to +2202
# MTP (multi-token prediction) self-speculative head.
# When ``enable_mtp`` is set, an auxiliary ``mtp.onnx`` model is exported
# alongside the main model (see ``Qwen35MtpHead``). It is disabled for the
# MTP head itself (``is_mtp_head``) to avoid infinite recursion.
self.mtp_head = None
self.enable_mtp = bool(extra_options.get("enable_mtp", False)) and not getattr(self, "is_mtp_head", False)
if self.enable_mtp:
    # Stash the constructor arguments so the MTP head can be built from a
    # pristine config after the main model has been generated.
    self._mtp_config = copy.deepcopy(config)
    self._mtp_io_dtype = io_dtype
    self._mtp_onnx_dtype = onnx_dtype
    self._mtp_ep = ep
    self._mtp_cache_dir = cache_dir
    self._mtp_extra_options = copy.deepcopy(extra_options)
Comment thread src/mtp_generator.cpp
Comment on lines +27 to +39
MtpGenerator::MtpGenerator(const Model& main_model, const Model& mtp_model, const GeneratorParams& params)
    : main_model_{main_model}, mtp_model_{mtp_model} {
  // MTP runs both a 1-token decode and a 2-token verify on the main model. Allow CUDA graph
  // capture of both shapes (each captured under its own annotation id with pre-sized static
  // buffers). Harmless for the MTP head, which only ever runs a single token per step.
  const_cast<GeneratorParams&>(params).max_graph_capture_length = 2;

  main_ = CreateGenerator(main_model_, params);
  mtp_ = CreateGenerator(mtp_model_, params);

  hidden_size_ = main_model_.config_->model.decoder.hidden_size;
  vocab_size_ = main_model_.config_->model.vocab_size;
  max_length_ = params.search.max_length;

3 participants