Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions src/base_speculative_strategy.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "base_speculative_strategy.h"

#include <algorithm>
#include <stdexcept>

#include "generators.h"
#include "search.h"
#include "speculative_sampling.h"
#include "models/speculative_decoding.h"

namespace Generators {

BaseSpeculativeStrategy::BaseSpeculativeStrategy(Generator& g)
: spec_state_{*dynamic_cast<SpeculativeDecodingState*>(g.state_.get())} {}

// Propose K draft tokens.
// Greedy: argmax, probs empty.
// Sampling: token i drawn from draft's truncated dist q_i (saved in probs[i] for the skeleton's min(1, p_i/q_i) test). d_0 reuses
// draft_pending_probs_, so only d_1..d_{K-1} run -> ~N*(K-1) passes, not N*K.
SpeculativeDecodingStrategy::Proposal BaseSpeculativeStrategy::Propose(
Generator& g, int K, int seed_length, const SamplingConfig& sampling) {
if (!spec_state_.draft_pending_valid())
throw std::runtime_error(
"BaseSpeculativeStrategy::Propose: draft pending probs not initialized. "
"AppendTokens must be called before GenerateNextToken.");

const auto& params = *g.search_->params_;
const int vocab_size = params.config.model.vocab_size;

Proposal proposal;
proposal.tokens.resize(K);
if (!sampling.greedy)
// greedy-match leaves probs empty
proposal.probs.resize(K);

auto argmax = [](std::span<const float> v) {
return static_cast<int32_t>(std::max_element(v.begin(), v.end()) - v.begin());
};

// d_0 from the carried-over pending probs.
if (sampling.greedy) {
proposal.tokens[0] = argmax(spec_state_.draft_pending_probs());
} else {
proposal.probs[0] = SamplingDistributionFromProbs(
spec_state_.draft_pending_probs(), sampling.top_k, sampling.top_p,
sampling.temperature);
proposal.tokens[0] = static_cast<int32_t>(
SampleFromDistribution({proposal.probs[0].data(), static_cast<size_t>(vocab_size)}, rng_));
}

// d_1..d_{K-1}: feed the previous draft token through the draft model.
auto single_buf = params.p_device->Allocate<int32_t>(1);
SampledCategorical sampled;
for (int i = 1; i < K; i++) {
single_buf.CpuSpan()[0] = proposal.tokens[i - 1];
single_buf.CopyCpuToDevice();
auto lgt = spec_state_.draft_state().Run(seed_length + i, single_buf, {});
auto cpu = lgt.CopyDeviceToCpu();
std::span<const float> logits{cpu.data(), static_cast<size_t>(vocab_size)};
if (sampling.greedy) {
proposal.tokens[i] = argmax(logits);
} else {
ComputeSampledCategorical(logits, sampling.top_k, sampling.top_p,
sampling.temperature, sampled);
proposal.probs[i] = ScatterToFullVocab(sampled, vocab_size);
proposal.tokens[i] = static_cast<int32_t>(
SampleFromDistribution({proposal.probs[i].data(), static_cast<size_t>(vocab_size)}, rng_));
}
}

return proposal;
}

// Re-sync draft to the committed length, then advance on final_token (its logits = next round's
// d_0). n_direct == K: all accepted, advance once more (bonus), no rewind; else rewind draft to
// seed_length + n_direct.
void BaseSpeculativeStrategy::Advance(Generator& g,
const Proposal& proposal,
int n_direct,
int32_t final_token,
int seed_length) {
const auto& params = *g.search_->params_;
const int vocab_size = params.config.model.vocab_size;
// Derive K from the proposal, not config: Step may have clamped K against max_length,
// and the draft cache was advanced by that clamped K. Avoid wrong sync.
const int K = static_cast<int>(proposal.tokens.size());

auto single_buf = params.p_device->Allocate<int32_t>(1);

int draft_kv_len = seed_length + K - 1;
const int rewind_to = seed_length + n_direct;

if (n_direct == K) {
// All proposed tokens accepted: catch draft up to where the K-th token
// would have advanced it (one extra step on last proposed token).
single_buf.CpuSpan()[0] = proposal.tokens[K - 1];
single_buf.CopyCpuToDevice();
spec_state_.draft_state().Run(seed_length + K, single_buf, {});
draft_kv_len = seed_length + K;
}

if (rewind_to < draft_kv_len)
spec_state_.draft_state().RewindTo(rewind_to);

// Advance one step on final_token; the resulting logits feed next round's d_0.
single_buf.CpuSpan()[0] = final_token;
single_buf.CopyCpuToDevice();
auto draft_lgt = spec_state_.draft_state().Run(seed_length + n_direct + 1, single_buf, {});
auto cpu_draft = draft_lgt.CopyDeviceToCpu();
spec_state_.set_draft_pending_probs(
Softmax({cpu_draft.data(), static_cast<size_t>(vocab_size)}));
}

} // namespace Generators
30 changes: 30 additions & 0 deletions src/base_speculative_strategy.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "speculative_decoding_strategy.h"

namespace Generators {

struct SpeculativeDecodingState;

// BaseSpeculativeStrategy
// Draft decoder proposes, the target decoder verifies. Greedy mode uses argmax-match; sampling mode
// samples draft from its distribution q and accepts with u < min(1, p_t/p_d).
// All counters and the propose -> verify -> commit -> re-anchor skeleton live in the base.
struct BaseSpeculativeStrategy final : SpeculativeDecodingStrategy {
explicit BaseSpeculativeStrategy(Generator& g);

protected:
Proposal Propose(Generator& g, int K, int seed_length,
const SamplingConfig& sampling) override;
void Advance(Generator& g,
const Proposal& proposal,
int n_direct,
int32_t final_token,
int seed_length) override;

private:
SpeculativeDecodingState& spec_state_;
};

} // namespace Generators
38 changes: 38 additions & 0 deletions src/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1184,6 +1184,9 @@ struct Model_Element : JSON::Element {
if (name == "decoder") {
return decoder_;
}
if (name == "draft") {
return draft_;
}
if (name == "vision") {
return vision_;
}
Expand All @@ -1206,6 +1209,7 @@ struct Model_Element : JSON::Element {
Config::Model& v_;
Encoder_Element encoder_{v_.encoder};
Decoder_Element decoder_{v_.decoder};
Decoder_Element draft_{v_.draft};
Int_Array_Element eos_token_id_{v_.eos_token_id};
Int_Array_Element tdt_durations_{v_.tdt_durations};
Vision_Element vision_{v_.vision};
Expand Down Expand Up @@ -1295,6 +1299,30 @@ struct Search_Element : JSON::Element {
Config::Search& v_;
};

struct Speculative_Element : JSON::Element {
explicit Speculative_Element(Config::Speculative& v) : v_{v} {}

// K (draft tokens per round) must be within [kMinK, kMaxK].
static constexpr int kMinK = 1;
static constexpr int kMaxK = 16;

void OnValue(std::string_view name, JSON::Value value) override {
if (name == "max_draft_tokens") {
int k = static_cast<int>(JSON::Get<double>(value));
if (k < kMinK || k > kMaxK)
throw std::runtime_error(
"speculative.max_draft_tokens must be between " + std::to_string(kMinK) + " and " +
std::to_string(kMaxK) + " Got: " + std::to_string(k) + ".");
v_.max_draft_tokens = k;
} else {
throw JSON::unknown_value_error{};
}
}

private:
Config::Speculative& v_;
};

struct DynamicBatching_Element : JSON::Element {
explicit DynamicBatching_Element(std::optional<Config::Engine::DynamicBatching>& v) : v_{v} {}

Expand Down Expand Up @@ -1372,6 +1400,14 @@ void SetSearchBool(Config::Search& search, std::string_view name, bool value) {
}
}

void SetSpeculativeNumber(Config::Speculative& speculative, std::string_view name, double value) {
try {
Speculative_Element(speculative).OnValue(name, value);
} catch (...) {
JSON::TranslateException(name);
}
}

void ClearProviders(Config& config) {
config.model.decoder.session_options.providers.clear();
}
Expand Down Expand Up @@ -1540,13 +1576,15 @@ struct Root_Element : JSON::Element {
Element& OnObject(std::string_view name) override {
if (name == "model") return model_element_;
if (name == "search") return search_element_;
if (name == "speculative") return speculative_element_;
if (name == "engine") return engine_element_;
throw JSON::unknown_value_error{};
}

Config& config_;
Model_Element model_element_{config_.model};
Search_Element search_element_{config_.search};
Speculative_Element speculative_element_{config_.speculative};
Engine_Element engine_element_{config_.engine};
};

Expand Down
7 changes: 7 additions & 0 deletions src/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,8 @@ struct Config {

} decoder;

Decoder draft;

} model;

struct Search {
Expand All @@ -428,6 +430,10 @@ struct Config {
float blank_penalty{}; // Penalty applied to blank token logits in CTC/RNNT decoding. Default 0 means no penalty.
} search;

struct Speculative {
int max_draft_tokens{4}; // Number of tokens the draft proposes per round.
} speculative;

struct Engine {
struct DynamicBatching {
size_t block_size{256}; // Total number of slots per block.
Expand All @@ -454,6 +460,7 @@ struct Config {

void SetSearchNumber(Config::Search& search, std::string_view name, double value);
void SetSearchBool(Config::Search& search, std::string_view name, bool value);
void SetSpeculativeNumber(Config::Speculative& speculative, std::string_view name, double value);
void ClearProviders(Config& config);
void SetProviderOption(Config& config, std::string_view provider_name, std::string_view option_name, std::string_view option_value);
void OverlayConfig(Config& config, std::string_view json);
Expand Down
31 changes: 31 additions & 0 deletions src/decoding_strategy.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "decoding_strategy.h"

#include <memory>

#include "generators.h"
#include "standard_decoding_strategy.h"
#include "transducer_decoding_strategy.h"
#include "base_speculative_strategy.h"
#include "models/model.h"
#include "models/model_type.h"

namespace Generators {

// Default: no stats. Speculative strategies override.
SpeculativeStats DecodingStrategy::GetStats() const {
return SpeculativeStats{};
}

// Factory
std::unique_ptr<DecodingStrategy> MakeDecodingStrategy(Generator& generator) {
const auto& model_type = generator.model_->config_->model.type;
if (ModelType::IsTransducer(model_type))
return std::make_unique<TransducerDecodingStrategy>(generator);
if (model_type == "speculative")
return std::make_unique<BaseSpeculativeStrategy>(generator);
return std::make_unique<StandardDecodingStrategy>();
}

} // namespace Generators
32 changes: 32 additions & 0 deletions src/decoding_strategy.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <memory>

namespace Generators {

struct Generator;
struct SpeculativeStats;

// Base interface for per-token generation dispatch. Chosen once at Generator
// construction based on the model type.
struct DecodingStrategy {
virtual ~DecodingStrategy() = default;

// Drives one user-visible "generate next token" step, committing exactly one
// token to the search sequence per call. Speculative variants compute several
// tokens per round internally but emit them one-per-call.
virtual void Step(Generator& generator) = 0;

// Default: no stats. Speculative strategies override.
virtual SpeculativeStats GetStats() const;

// Drop any per-round buffered state so a rewind/restart resumes cleanly.
// Speculative strategy needs override to clear its pending-token buffer.
virtual void Reset() {}
};

// Picks the right strategy after state_ and search_ are set up.
std::unique_ptr<DecodingStrategy> MakeDecodingStrategy(Generator& generator);

} // namespace Generators
Loading