
Add speculative decoding implementation (draft + target verification) (v0) #2233

Open
samsat701 wants to merge 11 commits into main from t-sshrimali/speculative-decoding

@samsat701

Description

This PR adds speculative decoding to ONNX Runtime GenAI for decoder-only LLMs. A small draft model proposes K candidate tokens autoregressively (cheap); the large target model verifies all K in one parallel forward pass; a rejection-sampling acceptance rule keeps the output distributionally identical to running the target alone. Speedups reported in the literature are 2-3x for well-matched draft/target pairs.

A round consists of:

  1. Propose - draft produces K tokens autoregressively
  2. Verify - target scores all K in one pass (K+1 distributions)
  3. Accept / correct / bonus - walk draft tokens left to right:
     • Greedy: accept token i if it equals target's argmax at position i
     • Sampling: accept with probability min(1, p_target/p_draft)
     • First rejection -> correction token from normalize(max(0, p_target-p_draft))
     • All K accepted -> 1 bonus token from the target's trailing distribution
  4. Re-anchor - rewind target KV cache to accepted prefix, advance on committed token
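
The sampling branch of step 3 can be sketched in a few lines of Python. This is a minimal illustration of the rule as described above, not the code in speculative_sampling.h; the function names and the fallback for identical distributions are assumptions made for the example.

```python
import random


def accept_prob(p_target: float, p_draft: float) -> float:
    """Probability of keeping a draft token: min(1, p_target / p_draft)."""
    if p_draft <= 0.0:
        # Assumption: a token the draft could not have sampled is kept only
        # if the target gives it some mass.
        return 1.0 if p_target > 0.0 else 0.0
    return min(1.0, p_target / p_draft)


def correction_distribution(target, draft):
    """normalize(max(0, p_target - p_draft)) over the vocabulary."""
    residual = [max(0.0, t - d) for t, d in zip(target, draft)]
    total = sum(residual)
    if total <= 0.0:
        # Assumption: identical distributions leave no residual mass,
        # so fall back to the target distribution.
        return list(target)
    return [r / total for r in residual]


def sample(dist, rng):
    """Inverse-CDF sampling from a dense probability vector."""
    u = rng.random()
    acc = 0.0
    for token, p in enumerate(dist):
        acc += p
        if u < acc:
            return token
    return len(dist) - 1  # guard against rounding in the running sum


def verify_round(draft_tokens, draft_dists, target_dists, rng):
    """Walk K draft tokens left to right; target_dists has K+1 entries."""
    committed = []
    for i, tok in enumerate(draft_tokens):
        if rng.random() < accept_prob(target_dists[i][tok], draft_dists[i][tok]):
            committed.append(tok)  # accepted draft token
            continue
        # First rejection: emit one correction token and stop the round.
        fix = correction_distribution(target_dists[i], draft_dists[i])
        committed.append(sample(fix, rng))
        return committed
    # All K accepted: one bonus token from the trailing target distribution.
    committed.append(sample(target_dists[len(draft_tokens)], rng))
    return committed
```

Calling `verify_round` with K draft tokens always returns between 1 and K+1 tokens, which matches the per-round commit range stated in the description.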

Each round commits between 1 and K+1 tokens from a single target call. The whole round is computed up front; its tokens are buffered and emitted one per GenerateNextToken() call.
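
A rough sketch of the one-token-per-call buffering, assuming a hypothetical `run_round` callable that returns the list of tokens committed by one round. The class and method names are illustrative and do not mirror the C++ strategy classes.

```python
from collections import deque


class BufferedRounds:
    """Emit one token per call while rounds commit several at once."""

    def __init__(self, run_round):
        self._run_round = run_round  # hypothetical: returns 1..K+1 tokens
        self._pending = deque()
        self.rounds = 0

    def next_token(self):
        # Only pay for a new round when the buffer has been drained.
        if not self._pending:
            self._pending.extend(self._run_round())
            self.rounds += 1
        # A round always commits at least one token, so this cannot be empty.
        return self._pending.popleft()

    def reset(self):
        # Drop tokens that are no longer valid, e.g. after a rewind.
        self._pending.clear()
```

The caller sees the same one-token-per-call contract as standard decoding, while the number of target calls is tracked by `rounds`.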

Public API/config:

  • Model type model.type = "speculative" + sibling model.draft decoder block
  • speculative.max_draft_tokens (K), default 4, range [1,16]
  • C ABI: OgaGeneratorParams{Set,Get}SpeculativeNumber, OgaGenerator_GetSpeculativeStats + OgaSpeculativeStats
  • C++: OgaGeneratorParams::SetSpeculativeOption/GetSpeculativeNumber, OgaGenerator::GetSpeculativeStats
  • Python: params.set_speculative_options(...)/get_speculative_options(), generator.get_speculative_stats()
  • Stats: rounds, draft proposed/accepted, correction/bonus counts, avg draft & target ms/token, acceptance rate, effective tokens per round
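
To make the derived stats concrete, here is a small sketch of how acceptance rate and effective tokens per round could be computed from the raw counters. The field names are hypothetical, and the sketch assumes committed tokens are the sum of accepted, correction, and bonus tokens; the actual OgaSpeculativeStats layout may differ.

```python
from dataclasses import dataclass


@dataclass
class RoundCounters:
    # Hypothetical counter names chosen for this example.
    rounds: int = 0
    draft_proposed: int = 0
    draft_accepted: int = 0
    corrections: int = 0
    bonuses: int = 0

    def acceptance_rate(self) -> float:
        """Fraction of proposed draft tokens that the target accepted."""
        if self.draft_proposed == 0:
            return 0.0
        return self.draft_accepted / self.draft_proposed

    def tokens_per_round(self) -> float:
        """Committed tokens (accepted + correction + bonus) per target round."""
        if self.rounds == 0:
            return 0.0
        committed = self.draft_accepted + self.corrections + self.bonuses
        return committed / self.rounds
```

For example, two rounds with K = 4 where one round accepts all four drafts (plus a bonus) and the other accepts two (plus a correction) give an acceptance rate of 0.75 and 4.0 tokens per round.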

Main additions:

  • Config, API, sampling math: speculative_sampling.h (ComputeAcceptProb, BuildCorrectionDistribution, SampleFromDistribution, Softmax, SamplingDistributionFromProbs) + the config/ABI/Python surface
  • Refactor sampling helpers: nucleus + top-k/top-p moved from search.cpp to shared sampling_distribution.h (ComputeSampledCategorical, SampledCategorical, ScatterToFullVocab); standard decode is behavior-preserving
  • Pluggable DecodingStrategy: GenerateNextToken()'s body extracted into StandardDecodingStrategy/TransducerDecodingStrategy, selected via MakeDecodingStrategy; adds GetStats() and Reset() hooks. Behavior-preserving.
  • SpeculativeDecodingModel/State: composes target + draft DecoderOnly_Model (each its own cloned Config), shared tokenizer/search; Run() does prefill and caches the draft's pending distribution
  • SpeculativeDecodingStrategy + BaseSpeculativeStrategy + tests: the propose->verify->accept->re-anchor loop, one-token-per-call buffering
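
As an illustration of the shared sampling helpers, the sketch below builds a sparse top-k/top-p distribution and scatters it back to the full vocabulary. The Python names only echo the C++ helpers listed above and are assumptions; the real signatures in sampling_distribution.h are not shown in this PR description.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sampled_categorical(logits, top_k=0, top_p=1.0):
    """Return (token indices, renormalized probs) after top-k then top-p."""
    probs = softmax(logits)
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:  # smallest prefix whose mass reaches top_p
            break
    return kept, [probs[i] / mass for i in kept]


def scatter_to_full_vocab(indices, probs, vocab_size):
    """Expand a sparse distribution to a dense vector with zeros elsewhere."""
    full = [0.0] * vocab_size
    for i, p in zip(indices, probs):
        full[i] = p
    return full
```

A dense, full-vocabulary form is convenient for the acceptance rule, since the ratio p_target/p_draft is looked up per token id in both distributions.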

Motivation and Context

Currently, LLM inference latency is dominated by the per-token forward pass of a large model: autoregressive generation of n tokens requires n expensive forward passes. Many of those tokens are “easy” to predict, meaning a small model would predict the same next token as the large one, yet we pay the full per-token cost of the target model even then. ORT-GenAI currently has no support for speculative decoding.

Known Issues

All guarded with a clear throw (fail fast, never silently wrong). Deliberate v0 cuts:

  • Pipeline/multimodal model support
  • Guidance/constrained decoding
  • repetition_penalty != 1.0 and min_length > 0
  • Sliding-window / LFM2 (hybrid) KV cache / legacy combined KV exports (requires separate KV cache format)
  • batch_size > 1, num_beams > 1
  • Continuous decoding (RewindToLength -> GenerateNextToken isn't supported yet; the current workaround is rewind -> AppendTokens (re-prefill) -> GenerateNextToken)
  • CPU is the main target provider for v0; robust GPU/NPU functional testing will be conducted in v1
  • Cross-EP draft/target (e.g. draft on CPU, target on GPU); draft and target must currently use the same EP
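
The fail-fast idea can be sketched as a single up-front check. This is a hypothetical Python illustration covering only the search-option guards from the list above; the dictionary keys and error text are assumptions, and the real guards live in the C++ model/state construction.

```python
def check_speculative_supported(opts: dict) -> None:
    """Raise early instead of producing silently wrong output."""
    problems = []
    if opts.get("repetition_penalty", 1.0) != 1.0:
        problems.append("repetition_penalty != 1.0")
    if opts.get("min_length", 0) > 0:
        problems.append("min_length > 0")
    if opts.get("batch_size", 1) > 1:
        problems.append("batch_size > 1")
    if opts.get("num_beams", 1) > 1:
        problems.append("num_beams > 1")
    if problems:
        raise ValueError(
            "speculative decoding (v0) does not support: " + ", ".join(problems)
        )
```

Collecting every violation before raising lets a user fix all unsupported options in one pass rather than discovering them one at a time.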

Testing

  • C++ unit tests (speculative_sampling_tests.cpp): accept-prob, correction distribution, CDF sampling, shared ComputeSampledCategorical
  • Python e2e tests (test_speculative_decoding.py): config guards, self-speculative generation (draft==target so output is comparable to plain greedy), state guards, rewind
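
The reasoning behind the self-speculative test can be shown with a toy model: under the greedy rule, every committed token is the target's argmax for its prefix, so the output should match plain greedy decoding. The sketch below uses a hypothetical deterministic `toy_next` function in place of a real model and does not use the ONNX Runtime GenAI API.

```python
def toy_next(prefix):
    """Deterministic stand-in for the argmax of a model's next-token logits."""
    return (sum(prefix) * 31 + len(prefix)) % 7


def plain_greedy(next_token, prompt, n):
    seq = list(prompt)
    for _ in range(n):
        seq.append(next_token(seq))
    return seq[len(prompt):]


def speculative_greedy(draft_next, target_next, prompt, n, k=4):
    seq = list(prompt)
    while len(seq) - len(prompt) < n:
        # Propose: the draft produces k tokens autoregressively.
        proposal = []
        for _ in range(k):
            proposal.append(draft_next(seq + proposal))
        # Verify: a real target scores all positions in one pass; this toy
        # version queries the target once per position instead.
        committed = []
        for tok in proposal:
            expected = target_next(seq + committed)
            if tok != expected:
                committed.append(expected)  # correction token
                break
            committed.append(tok)  # accepted draft token
        else:
            committed.append(target_next(seq + committed))  # bonus token
        seq.extend(committed)
    return seq[len(prompt):len(prompt) + n]
```

With `draft_next` equal to `target_next` every draft token is accepted; with a poor draft the output is unchanged and only the number of rounds grows.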

Adds model.draft + speculative.max_draft_tokens, Set/GetSpeculativeNumber, C ABI/Python stats surface, and speculative_sampling.h math helpers (ComputeAcceptProb, BuildCorrectionDistribution, SampleFromDistribution).
Moves FindNucleus + ComputeSampledCategorical into sampling_distribution.h; SampleTopK/TopP/TopKTopP now call it.
Moves GenerateNextToken's per-token logic out of Generator into DecodingStrategy objects (StandardDecodingStrategy + TransducerDecodingStrategy), selected via MakeDecodingStrategy. Adds GetSpeculativeStats() hook (zero for non-speculative). Behavior-preserving; existing CPU sampling tests pass.
Adds SpeculativeDecodingModel, a target + draft decoder-only pair sharing one tokenizer/search. Registered via model.type == speculative in CreateModel.
Adds SpeculativeDecodingStrategy (draft proposes K tokens, target verifies in one pass, accepted/correction/bonus tokens are buffered and emitted one per GenerateNextToken) plus the BaseSpeculativeStrategy implementation, wired in via MakeDecodingStrategy. Step now commits one token per call with a Reset() hook on rewind. Adds C++ and Python speculative decoding tests.
@samsat701 samsat701 requested a review from Copilot June 19, 2026 02:17
@samsat701 samsat701 requested a review from a team as a code owner June 19, 2026 02:17

Copilot AI left a comment


Pull request overview

This PR introduces speculative decoding for decoder-only LLMs in ONNX Runtime GenAI by composing a draft + target model, adding shared sampling utilities, and refactoring Generator::GenerateNextToken() into a pluggable decoding-strategy dispatch layer.

Changes:

  • Add speculative decoding model/state/strategy (draft propose → target verify → accept/correct/bonus → KV re-anchor) plus stats instrumentation surfaced through C/C++/Python APIs.
  • Refactor standard and transducer generation into DecodingStrategy implementations, routing GenerateNextToken() through a strategy factory.
  • Centralize top-k/top-p distribution construction in sampling_distribution.h and add unit + Python e2e coverage for speculative sampling and guards.

Reviewed changes

Copilot reviewed 26 out of 26 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| test/speculative_sampling_tests.cpp | Adds C++ unit tests for acceptance probability, correction distribution, sampling, and shared distribution-building helpers. |
| test/python/test_speculative_decoding.py | Adds Python e2e tests for config/state guards and self-speculative generation + stats/rewind behavior. |
| src/decoding_strategy.h | Introduces the DecodingStrategy interface and MakeDecodingStrategy() factory. |
| src/decoding_strategy.cpp | Implements strategy factory routing for transducer/speculative/standard decoding. |
| src/standard_decoding_strategy.h | Declares standard single-token decoding strategy. |
| src/standard_decoding_strategy.cpp | Moves classic GenerateNextToken logic into StandardDecodingStrategy::Step. |
| src/transducer_decoding_strategy.h | Declares transducer decoding strategy for RNNT/TDT. |
| src/transducer_decoding_strategy.cpp | Implements transducer per-token stepping via TransducerState. |
| src/speculative_sampling.h | Adds speculative sampling helpers (softmax, acceptance prob, correction dist, sampling). |
| src/sampling_distribution.h | Adds shared sparse/dense sampling distribution utilities and nucleus selection. |
| src/search.cpp | Refactors CPU sampling to use ComputeSampledCategorical for top-k/top-p/top-k+top-p. |
| src/config.h | Adds model.draft decoder block and speculative.max_draft_tokens config. |
| src/config.cpp | Parses speculative.max_draft_tokens and adds runtime setter SetSpeculativeNumber. |
| src/models/speculative_decoding.h | Adds speculative model/state definitions composing target + draft DecoderOnly_Model. |
| src/models/speculative_decoding.cpp | Implements speculative model/state construction, guardrails, and prefill behavior. |
| src/models/model.cpp | Wires model.type == "speculative" to construct SpeculativeDecodingModel. |
| src/speculative_decoding_strategy.h | Adds base speculative decoding strategy interface and shared round/state buffering. |
| src/speculative_decoding_strategy.cpp | Implements the round logic (verify, accept/correct/bonus, buffering, re-anchor) + stats. |
| src/base_speculative_strategy.h | Declares a concrete speculative strategy using draft decoder propose + state advance. |
| src/base_speculative_strategy.cpp | Implements draft propose/advance logic and pending draft-prob carry-over. |
| src/generators.h | Adds SpeculativeStats, speculative params accessors, and strategy_ member. |
| src/generators.cpp | Routes GenerateNextToken() through strategy_, adds speculative param accessors/stats plumbing, resets strategy on rewind. |
| src/ort_genai_c.h | Adds C ABI APIs for speculative options + OgaSpeculativeStats + stats getter. |
| src/ort_genai_c.cpp | Implements C ABI plumbing for speculative params + generator stats. |
| src/ort_genai.h | Adds C++ wrapper methods for speculative options and stats. |
| src/python/python.cpp | Exposes speculative options + stats through pybind (set_speculative_options, get_speculative_options, get_speculative_stats). |

Comment thread src/speculative_decoding_strategy.cpp
Comment thread src/models/speculative_decoding.cpp
Comment thread src/speculative_decoding_strategy.cpp Outdated
@samsat701

Author

@microsoft-github-policy-service agree company="Microsoft"

samsat701 and others added 6 commits June 19, 2026 13:42
When a prompt length equals max_length, K clamps to 0 after the K>=1 check, causing out-of-bounds access in RunRound/Propose (tokens[0], target_dists[K-1]). Treat K<=0 as done, matching standard decoding's behavior at max_length.

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
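
The guard described in the commit above can be pictured as follows. This is a hedged sketch with a hypothetical helper name; it assumes K is clamped to the number of tokens still allowed before max_length, which is consistent with K becoming 0 when the prompt already fills max_length.

```python
def draft_budget(max_draft_tokens: int, current_length: int, max_length: int) -> int:
    """Number of draft tokens to propose this round, never negative."""
    remaining = max_length - current_length
    return max(0, min(max_draft_tokens, remaining))


def should_stop(max_draft_tokens: int, current_length: int, max_length: int) -> bool:
    # K <= 0 means there is no room left: treat the generator as done
    # instead of indexing tokens[0] or target_dists[K-1].
    return draft_budget(max_draft_tokens, current_length, max_length) <= 0
```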
In greedy mode the accept/correction/bonus logic only uses the argmax; the full softmax is wasted work (argmax(softmax(x)) == argmax(x)). Use argmax directly to cut round overhead.

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
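
The identity used by the commit above follows from softmax being strictly increasing in each logit, so it preserves the position of the maximum (up to floating-point ties between nearly equal logits). A small check, with illustrative helper names:

```python
import math


def softmax(logits):
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    """Index of the first maximum element."""
    return max(range(len(values)), key=values.__getitem__)


logits = [1.5, -0.3, 4.2, 0.0]
# Greedy mode only needs the winning index, so the softmax can be skipped.
same_choice = argmax(softmax(logits)) == argmax(logits)
```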
Compare the full ProviderOptions contents (options key/values and device_filtering_options), not just names, so mismatched device/EP configs are rejected at construction.

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
effective_speedup was committed/rounds (tokens per round), not a speedup. Replace with speedup = E[tok/round] / (1 + k*(T_draft/T_target) + x)
…/microsoft/onnxruntime-genai into t-sshrimali/speculative-decoding

Each round ran the target twice: verify K draft tokens, then a second single-token pass to commit the kept token. Fold that away by carrying the kept token into the next round's verify batch (width K -> K+1), so the target runs once per round.
2 participants