Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
3b0830e
Fix CUDA QMoE INT4 export for Qwen3.5/3.6 MoE models
tianleiwu Jun 9, 2026
0c8d4e5
Address review feedback on CUDA QMoE INT4 export
tianleiwu Jun 10, 2026
3e54917
Add Qwen3.6 MTP head export to the model builder
tianleiwu Jun 11, 2026
695277e
Add MTP sub-model section to genai config
tianleiwu Jun 11, 2026
e6dbc9c
Add recurrent-state snapshot/restore for speculative decoding
tianleiwu Jun 11, 2026
a646df6
Add in-engine hidden_states input feeder for the MTP head
tianleiwu Jun 11, 2026
fa8ead7
Add Qwen3.6 MTP self-speculative decoding example and design doc
tianleiwu Jun 11, 2026
e14265e
Make the hidden_states output CUDA-graph-safe
tianleiwu Jun 11, 2026
bee4c80
Feed hidden_states device-to-device on the shared stream when possible
tianleiwu Jun 11, 2026
2c01367
Document MTP synchronization and CUDA-graph model
tianleiwu Jun 11, 2026
1872a44
Add in-engine MtpGenerator (C++ draft/verify orchestrator)
tianleiwu Jun 11, 2026
107a813
update example with MtpGenerator
tianleiwu Jun 11, 2026
a21ce17
Update Qwen3.6 MTP example/docs for in-engine MtpGenerator + record b…
tianleiwu Jun 11, 2026
9506249
Graph-capture the 2-token MTP verify shape
tianleiwu Jun 11, 2026
2f4d193
MTP: on-device argmax via genai Top-K kernel (no full-logits D2H)
tianleiwu Jun 11, 2026
152ebd7
MTP: skip the wasted argmax on the post-accept KV-advance draft
tianleiwu Jun 11, 2026
13988da
MTP: fuse post-accept KV-advance + next draft into one 2-token MTP fo…
tianleiwu Jun 11, 2026
3e5962b
MTP: share embedding + lm_head between main model and MTP head (disk/…
tianleiwu Jun 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions examples/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ python model-mm.py -m {path to model folder} -e {execution provider}
python model-mm.py -m {path to model folder} -e {execution provider} --image_paths image1.jpg image2.jpg --non_interactive
```

```bash
# The `qwen-3.6-mtp` script runs Qwen3.6 with its multi-token-prediction (MTP) head for
# self-speculative decoding. See qwen-3.6-mtp.md for export instructions and design details.
python qwen-3.6-mtp.py -m {path to main model folder} -d {path to MTP head folder}
```

## Execution Providers

The ONNX Runtime GenAI Python package supports the following execution providers (EPs):
Expand Down
312 changes: 312 additions & 0 deletions examples/python/qwen-3.6-mtp.md

Large diffs are not rendered by default.

209 changes: 209 additions & 0 deletions examples/python/qwen-3.6-mtp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Qwen3.6 MTP (multi-token prediction) self-speculative decoding example.

The Qwen3.6-A3B model ships a built-in MTP head: a single extra decoder layer that,
given the main model's last hidden state and the just-emitted token, predicts the
*next-next* token. Used as a draft model, it lets the main model verify two tokens
per forward pass and accept the draft when it matches greedy decoding -- a lossless
speedup over plain autoregressive decoding.

Two ways to run it are shown:

* the built-in ``og.MtpGenerator`` (default) -- a first-class C++ generator that runs
the whole draft/verify loop in-engine, keeping the hidden-state handoff on-device
(no per-step host round-trip). This is the recommended, fastest path.
* ``--reference`` -- an equivalent hand-rolled Python loop (``ReferenceMtpGenerator``)
that drives two ``og.Model`` instances through the public API. It documents the
algorithm and is useful for experimentation, but is slower (per-step Python +
host round-trips).

Both require:
* the main Qwen3.6 decoder exported with ``include_hidden_states=true`` (so it exposes
a ``hidden_states`` output), and
* the MTP head (``mtp.onnx``, exported with ``enable_mtp=true``), loaded as a standalone
model whose ``hidden_states`` input is fed the main model's last hidden state.

See ``qwen-3.6-mtp.md`` for the design and export instructions.
"""

import argparse
import time

import numpy as np
import onnxruntime_genai as og


class ReferenceMtpGenerator:
    """Pure-Python reference implementation of greedy MTP self-speculative decoding.

    One draft token is proposed per step. The class holds a main ``og.Model`` and an
    MTP-head ``og.Model`` and drives them through ``og.Generator`` objects; it is meant
    as a readable description of the algorithm (the module offers ``og.MtpGenerator``
    as the in-engine alternative).

    Usage: ``tokens, stats = ReferenceMtpGenerator(main, mtp).generate(prompt, n)``.
    """

    def __init__(self, main_model: og.Model, mtp_model: og.Model, max_length: int = 4096):
        """Store the two models and the ``max_length`` search option used for every generator."""
        self.max_length = max_length
        self.main_model = main_model
        self.mtp_model = mtp_model

    def _greedy_generator(self, model):
        """Create an ``og.Generator`` for ``model`` with ``do_sample=False`` and ``self.max_length``."""
        search = og.GeneratorParams(model)
        search.set_search_options(max_length=self.max_length, do_sample=False)
        return og.Generator(model, search)

    @staticmethod
    def _read_outputs(generator):
        """Return the generator's ``logits`` and ``hidden_states`` outputs as NumPy arrays."""
        step_logits = np.asarray(generator.get_output("logits"))
        step_hidden = np.asarray(generator.get_output("hidden_states"))
        return step_logits, step_hidden

    def _new_main(self) -> og.Generator:
        """Return a fresh greedy generator over the main model."""
        return self._greedy_generator(self.main_model)

    def _mtp_draft(self, hidden_context, token_context) -> int:
        """Return the MTP head's greedy prediction for the next token.

        A new MTP generator is built on every call and fed the whole accumulated
        context: ``hidden_context`` is a list of per-position hidden-state vectors from
        the main model, and ``token_context`` is the list of tokens paired with them
        (entry i is the token emitted from hidden state i).
        """
        mtp_gen = self._greedy_generator(self.mtp_model)
        # Stack the per-position vectors, cast to float16 and add a leading axis.
        stacked = np.stack(hidden_context).astype(np.float16)
        mtp_gen.set_hidden_states(stacked[np.newaxis, :, :])
        mtp_gen.append_tokens(np.asarray(token_context, dtype=np.int32))
        head_logits = np.asarray(mtp_gen.get_output("logits"))[0]
        # Greedy pick from the logits of the last position.
        return int(head_logits[-1].argmax(-1))

    def generate(self, prompt_tokens, max_new_tokens):
        """Greedily generate up to ``max_new_tokens`` tokens after ``prompt_tokens``.

        Returns:
            ``(tokens, stats)``. ``tokens`` is the list of generated token ids (prompt
            excluded). ``stats`` is a dict with ``forwards`` (main-model forward passes,
            including the prompt pass), ``accepts``, ``trials``, ``accept_rate`` and
            ``tokens_per_forward``.
        """
        main_gen = self._new_main()
        main_gen.append_tokens(np.asarray(prompt_tokens, dtype=np.int32))
        committed = len(prompt_tokens)  # tokens the main generator has kept so far

        step_logits, step_hidden = self._read_outputs(main_gen)
        pending_token = int(step_logits[0, -1].argmax(-1))  # next token to emit
        pending_hidden = step_hidden[0, -1].astype(np.float16)  # hidden state it came from

        emitted = []
        mtp_hiddens, mtp_tokens = [], []
        n_forwards = 1  # the prompt pass above
        n_accepts = 0
        n_trials = 0

        while len(emitted) < max_new_tokens:
            emitted.append(pending_token)
            mtp_hiddens.append(pending_hidden)
            mtp_tokens.append(pending_token)
            if len(emitted) >= max_new_tokens:
                break

            # Step 1: ask the MTP head for a guess at the token after `pending_token`.
            guess = self._mtp_draft(mtp_hiddens, mtp_tokens)

            # Step 2: snapshot, then run [pending_token, guess] through the main model
            # in a single forward pass.
            main_gen.snapshot_state()
            main_gen.append_tokens(np.array([pending_token, guess], dtype=np.int32))
            n_forwards += 1
            step_logits, step_hidden = self._read_outputs(main_gen)  # expected [1, 2, V] / [1, 2, H]
            n_trials += 1

            # Position 0 holds the main model's own choice for the token after `pending_token`.
            if int(step_logits[0, 0].argmax(-1)) != guess:
                # Step 3b (reject): rewind to the committed length (intended to undo
                # the speculative forward) and re-run with only the verified token.
                main_gen.rewind_to(committed)
                main_gen.append_tokens(np.array([pending_token], dtype=np.int32))
                n_forwards += 1
                step_logits, step_hidden = self._read_outputs(main_gen)
                pending_token = int(step_logits[0, -1].argmax(-1))
                pending_hidden = step_hidden[0, -1].astype(np.float16)
                committed += 1
                continue

            # Step 3a (accept): the guess matched, so it is emitted as well and
            # position 1 of the verify pass already supplies the following token.
            n_accepts += 1
            emitted.append(guess)
            mtp_hiddens.append(step_hidden[0, 0].astype(np.float16))
            mtp_tokens.append(guess)
            pending_token = int(step_logits[0, 1].argmax(-1))
            pending_hidden = step_hidden[0, 1].astype(np.float16)
            committed += 2

        emitted = emitted[:max_new_tokens]
        stats = {
            "forwards": n_forwards,
            "accepts": n_accepts,
            "trials": n_trials,
            "accept_rate": n_accepts / max(n_trials, 1),
            "tokens_per_forward": len(emitted) / n_forwards,
        }
        return emitted, stats


def run_builtin(main_model, mtp_model, tokenizer, prompt_tokens, args):
    """Generate with the in-engine ``og.MtpGenerator`` and collect its statistics.

    Args:
        main_model: main ``og.Model``.
        mtp_model: MTP-head ``og.Model``.
        tokenizer: unused here; kept so the call signature stays stable.
        prompt_tokens: encoded prompt token ids.
        args: parsed CLI namespace; ``max_length`` and ``max_new_tokens`` are read.

    Returns:
        ``(tokens, stats, elapsed)``: the generated token ids (prompt excluded, at most
        ``args.max_new_tokens`` of them), a stats dict with the same keys as
        ``ReferenceMtpGenerator.generate`` returns, and the wall-clock seconds spent in
        the generation loop. The timer starts after the prompt has been appended, so
        prompt ingestion is not included in ``elapsed``.
    """
    params = og.GeneratorParams(main_model)
    params.set_search_options(max_length=args.max_length, do_sample=False)
    gen = og.MtpGenerator(main_model, mtp_model, params)
    n_prompt = len(prompt_tokens)
    target_len = n_prompt + args.max_new_tokens

    gen.append_tokens(np.asarray(prompt_tokens, dtype=np.int32))
    start = time.perf_counter()
    while not gen.is_done() and len(gen.get_sequence()) < target_len:
        gen.generate_next_token()
    elapsed = time.perf_counter() - start

    # The bound is only checked before each step, so a step that commits more than one
    # token (presumably possible when a draft is accepted -- TODO confirm) can run past
    # it. Clamp so the result never exceeds max_new_tokens, matching the reference path.
    tokens = gen.get_sequence().tolist()[n_prompt:target_len]
    s = gen.get_stats()
    stats = {
        "forwards": s["forwards"],
        "accepts": s["accepts"],
        "trials": s["trials"],
        "accept_rate": s["accepts"] / max(s["trials"], 1),
        "tokens_per_forward": len(tokens) / max(s["forwards"], 1),
    }
    return tokens, stats, elapsed


def main(args):
    """Load both models, run every prompt, and print the output with decoding statistics.

    ``args`` is the parsed CLI namespace. With ``args.reference`` set, the pure-Python
    ``ReferenceMtpGenerator`` is used and timed around its whole ``generate`` call;
    otherwise ``run_builtin`` is used and its own elapsed time is reported.
    """
    print("Loading main model...")
    main_model = og.Model(args.main_model_path)
    tokenizer = og.Tokenizer(main_model)
    print("Loading MTP head...")
    mtp_model = og.Model(args.mtp_model_path)

    # Only build the Python reference decoder when it was requested.
    reference = None
    if args.reference:
        reference = ReferenceMtpGenerator(main_model, mtp_model, max_length=args.max_length)

    default_prompts = [
        "Explain how photosynthesis works in plants, step by step.",
    ]
    for prompt in args.prompts or default_prompts:
        # Wrap the prompt in the chat markup before tokenizing.
        text = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
        prompt_tokens = tokenizer.encode(text)

        if reference is None:
            tokens, stats, elapsed = run_builtin(main_model, mtp_model, tokenizer, prompt_tokens, args)
        else:
            started = time.perf_counter()
            tokens, stats = reference.generate(prompt_tokens, args.max_new_tokens)
            elapsed = time.perf_counter() - started

        summary = (
            f"accept rate: {stats['accept_rate']:.1%} "
            f"({stats['accepts']}/{stats['trials']}) | "
            f"tokens/forward: {stats['tokens_per_forward']:.2f} | "
            f"{len(tokens)} tokens in {stats['forwards']} forwards, "
            f"{len(tokens) / elapsed:.1f} tok/s"
        )

        print("\n" + "=" * 80)
        print(f"Prompt: {prompt}")
        print(tokenizer.decode(tokens))
        print("-" * 80)
        print(summary)


if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    cli = argparse.ArgumentParser(description="Qwen3.6 MTP self-speculative decoding")

    # Required model locations.
    cli.add_argument(
        "-m",
        "--main_model_path",
        required=True,
        help="Path to the main model folder (exported with include_hidden_states=true)",
    )
    cli.add_argument(
        "-d",
        "--mtp_model_path",
        required=True,
        help="Path to the MTP head model folder (mtp.onnx + a genai_config.json declaring its hidden_states input)",
    )

    # Generation limits.
    cli.add_argument(
        "-n",
        "--max_new_tokens",
        type=int,
        default=128,
        help="Number of tokens to generate per prompt",
    )
    cli.add_argument("--max_length", type=int, default=4096, help="Max sequence length")

    # Prompts and decoder selection.
    cli.add_argument("-p", "--prompts", nargs="*", default=None, help="Prompt(s) to run")
    cli.add_argument(
        "--reference",
        action="store_true",
        help="Use the pure-Python ReferenceMtpGenerator instead of the built-in og.MtpGenerator",
    )

    main(cli.parse_args())
Loading
Loading