Add speculative decoding implementation (draft + target verification) (v0)#2233
Open
samsat701 wants to merge 11 commits into
Open
Add speculative decoding implementation (draft + target verification) (v0)#2233samsat701 wants to merge 11 commits into
samsat701 wants to merge 11 commits into
Conversation
Adds model.draft + speculative.max_draft_tokens, Set/GetSpeculativeNumber, C ABI/Python stats surface, and speculative_sampling.h math helpers (ComputeAcceptProb, BuildCorrectionDistribution, SampleFromDistribution).
Moves FindNucleus + ComputeSampledCategorical into sampling_distribution.h; SampleTopK/TopP/TopKTopP now call it.
Moves GenerateNextToken's per-token logic out of Generator into DecodingStrategy objects (StandardDecodingStrategy + TransducerDecodingStrategy), selected via MakeDecodingStrategy. Adds GetSpeculativeStats() hook (zero for non-speculative). Behavior-preserving; existing CPU sampling tests pass.
Adds SpeculativeDecodingModel, a target + draft decoder-only pair sharing one tokenizer/search. Registered via model.type == speculative in CreateModel.
Adds SpeculativeDecodingStrategy (draft proposes K tokens, target verifies in one pass, accepted/correction/bonus tokens are buffered and emitted one per GenerateNextToken) plus the BaseSpeculativeStrategy implementation, wired in via MakeDecodingStrategy. Step now commits one token per call with a Reset() hook on rewind. Adds C++ and Python speculative decoding tests.
Contributor
There was a problem hiding this comment.
Pull request overview
This PR introduces speculative decoding for decoder-only LLMs in ONNX Runtime GenAI by composing a draft + target model, adding shared sampling utilities, and refactoring Generator::GenerateNextToken() into a pluggable decoding-strategy dispatch layer.
Changes:
- Add speculative decoding model/state/strategy (draft propose → target verify → accept/correct/bonus → KV re-anchor) plus stats instrumentation surfaced through C/C++/Python APIs.
- Refactor standard and transducer generation into
DecodingStrategyimplementations, routingGenerateNextToken()through a strategy factory. - Centralize top-k/top-p distribution construction in
sampling_distribution.hand add unit + Python e2e coverage for speculative sampling and guards.
Reviewed changes
Copilot reviewed 26 out of 26 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
| test/speculative_sampling_tests.cpp | Adds C++ unit tests for acceptance probability, correction distribution, sampling, and shared distribution-building helpers. |
| test/python/test_speculative_decoding.py | Adds Python e2e tests for config/state guards and self-speculative generation + stats/rewind behavior. |
| src/decoding_strategy.h | Introduces the DecodingStrategy interface and MakeDecodingStrategy() factory. |
| src/decoding_strategy.cpp | Implements strategy factory routing for transducer/speculative/standard decoding. |
| src/standard_decoding_strategy.h | Declares standard single-token decoding strategy. |
| src/standard_decoding_strategy.cpp | Moves classic GenerateNextToken logic into StandardDecodingStrategy::Step. |
| src/transducer_decoding_strategy.h | Declares transducer decoding strategy for RNNT/TDT. |
| src/transducer_decoding_strategy.cpp | Implements transducer per-token stepping via TransducerState. |
| src/speculative_sampling.h | Adds speculative sampling helpers (softmax, acceptance prob, correction dist, sampling). |
| src/sampling_distribution.h | Adds shared sparse/dense sampling distribution utilities and nucleus selection. |
| src/search.cpp | Refactors CPU sampling to use ComputeSampledCategorical for top-k/top-p/top-k+top-p. |
| src/config.h | Adds model.draft decoder block and speculative.max_draft_tokens config. |
| src/config.cpp | Parses speculative.max_draft_tokens and adds runtime setter SetSpeculativeNumber. |
| src/models/speculative_decoding.h | Adds speculative model/state definitions composing target + draft DecoderOnly_Model. |
| src/models/speculative_decoding.cpp | Implements speculative model/state construction, guardrails, and prefill behavior. |
| src/models/model.cpp | Wires model.type == "speculative" to construct SpeculativeDecodingModel. |
| src/speculative_decoding_strategy.h | Adds base speculative decoding strategy interface and shared round/state buffering. |
| src/speculative_decoding_strategy.cpp | Implements the round logic (verify, accept/correct/bonus, buffering, re-anchor) + stats. |
| src/base_speculative_strategy.h | Declares a concrete speculative strategy using draft decoder propose + state advance. |
| src/base_speculative_strategy.cpp | Implements draft propose/advance logic and pending draft-prob carry-over. |
| src/generators.h | Adds SpeculativeStats, speculative params accessors, and strategy_ member. |
| src/generators.cpp | Routes GenerateNextToken() through strategy_, adds speculative param accessors/stats plumbing, resets strategy on rewind. |
| src/ort_genai_c.h | Adds C ABI APIs for speculative options + OgaSpeculativeStats + stats getter. |
| src/ort_genai_c.cpp | Implements C ABI plumbing for speculative params + generator stats. |
| src/ort_genai.h | Adds C++ wrapper methods for speculative options and stats. |
| src/python/python.cpp | Exposes speculative options + stats through pybind (set_speculative_options, get_speculative_options, get_speculative_stats). |
Author
|
@microsoft-github-policy-service agree company="Microsoft" |
When a prompt length equals max_length, K clamps to 0 after the K>=1 check, causing out-of-bounds access in RunRound/Propose (tokens[0], target_dists[K-1]). Treat K<=0 as done, matching standard decoding's behavior at max_length. Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
In greedy mode the accept/correction/bonus logic only uses the argmax; the full softmax is wasted work (argmax(softmax(x)) == argmax(x)). Use argmax directly to cut round overhead. Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Compare the full ProviderOptions contents (options key/values and device_filtering_options), not just names, so mismatched device/EP configs are rejected at construction. Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
effective_speedup was committed/rounds (tokens per round), not a speedup. Replace with speedup = E[tok/round] / (1 + k*(T_draft/T_target) + x)
…/microsoft/onnxruntime-genai into t-sshrimali/speculative-decoding # Please enter a commit message to explain why this merge is necessary, # especially if it merges an updated upstream into a topic branch. # # Lines starting with '#' will be ignored, and an empty message aborts # the commit.
Each round ran the target twice: verify K draft tokens, then a second single-token pass to commit the kept token. Fold that away by carrying the kept token into the next round's verify batch (width K -> K+1), so the target runs once per round.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
This PR adds speculative decoding to ONNX Runtime GenAI, decoder-only LLMs. A small draft model proposes K candidate tokens autoregressively (cheap); the large target model verifies all K in one parallel forward pass; a rejection-sampling acceptance rule keeps the output distributionally identical to running the target alone. Literature reported speedups are 2-3x for well-matched pairs.
A round consists of:
Each round commits 1...K+1 tokens from a single target call. A round is computed up front and tokens are buffered and emitted one per GenerateNextToken().
Public API/config:
Main additions:
Motivation and Context
Currently, LLM inference latency is based on per-token forward pass of a large model. For autoregressive generation of n tokens, we have n expensive forward passes. Many of the tokens are “easy” to predict. A small model would predict the same next token as a large one. We pay full per-token cost for the target model even when a much cheaper model would have produced the same token. ORT-GenAI has no current support for this technique.
Known Issues
All guarded with a clear throw (fail fast, never silently wrong). Deliberate v0 cuts:
Testing