Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1293,8 +1293,18 @@ struct Search_Element : JSON::Element {
}
}

// Routes a JSON array key of the search section to the element that parses it.
// Only the two token-id lists are accepted; any other array key is rejected
// with JSON::unknown_value_error.
Element& OnArray(std::string_view name) override {
  const bool applies_every_step = (name == "suppress_tokens");
  if (!applies_every_step && name != "begin_suppress_tokens")
    throw JSON::unknown_value_error{};
  return applies_every_step ? suppress_tokens_ : begin_suppress_tokens_;
}

private:
Config::Search& v_;
Int_Array_Element suppress_tokens_{v_.suppress_tokens};
Int_Array_Element begin_suppress_tokens_{v_.begin_suppress_tokens};
};

struct DynamicBatching_Element : JSON::Element {
Expand Down Expand Up @@ -1374,6 +1384,16 @@ void SetSearchBool(Config::Search& search, std::string_view name, bool value) {
}
}

// Overwrites one of the token-id lists in `search` with a copy of `tokens`.
//   name   - "suppress_tokens" or "begin_suppress_tokens".
//   tokens - ids to store; copied, so the caller's buffer need not outlive the call.
// Throws std::runtime_error when `name` is not one of the two supported options.
void SetSearchTokensArray(Config::Search& search, std::string_view name, std::span<const int32_t> tokens) {
  std::vector<int>* destination = nullptr;
  if (name == "suppress_tokens")
    destination = &search.suppress_tokens;
  else if (name == "begin_suppress_tokens")
    destination = &search.begin_suppress_tokens;

  if (destination == nullptr)
    throw std::runtime_error("Unknown search tokens array option: " + std::string(name));

  destination->assign(tokens.begin(), tokens.end());
}

// Empties the provider list held in the decoder's session options.
void ClearProviders(Config& config) {
config.model.decoder.session_options.providers.clear();
}
Expand Down
5 changes: 5 additions & 0 deletions src/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,10 @@ struct Config {
int random_seed{-1}; // -1 = Seed with random device, otherwise use value to seed RNG
std::optional<size_t> chunk_size; // Chunk size for prefill chunking during context processing. If present, chunking is enabled with the chunk size > 0.
float blank_penalty{}; // Penalty applied to blank token logits in CTC/RNNT decoding. Default 0 means no penalty.
// Token ids whose logits are set to -inf at every decoding step.
std::vector<int> suppress_tokens;
// Token ids whose logits are set to -inf only at the first generated step.
std::vector<int> begin_suppress_tokens;
} search;

struct Engine {
Expand Down Expand Up @@ -454,6 +458,7 @@ struct Config {

void SetSearchNumber(Config::Search& search, std::string_view name, double value);
void SetSearchBool(Config::Search& search, std::string_view name, bool value);
// Replaces the token-id list selected by `name` ("suppress_tokens" or "begin_suppress_tokens")
// with a copy of `tokens`. Throws std::runtime_error for any other name.
void SetSearchTokensArray(Config::Search& search, std::string_view name, std::span<const int32_t> tokens);
void ClearProviders(Config& config);
void SetProviderOption(Config& config, std::string_view provider_name, std::string_view option_name, std::string_view option_value);
void OverlayConfig(Config& config, std::string_view json);
Expand Down
11 changes: 11 additions & 0 deletions src/cuda/search_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,4 +270,15 @@ void Search_Cuda::ApplyRepetitionPenalty(float penalty) {
params_->search.max_length, GetSequenceLength(), penalty, GetStream());
}

// Suppresses the given token ids in the device-side scores.
// Each id inside [0, vocab_size) gets its own LaunchSetScoreProcessor call on this
// search's stream; ids outside that range are skipped. An empty list is a no-op.
// NOTE(review): presumably the kernel writes the value at `token_id` in each of the
// BatchBeamSize() rows - confirm against the kernel definition.
void Search_Cuda::ApplySuppressTokens(const std::vector<int>& suppress_tokens) {
  if (suppress_tokens.empty())
    return;

  constexpr float kSuppressedScore = std::numeric_limits<float>::lowest();
  const int vocab_size = params_->config.model.vocab_size;
  for (const int token_id : suppress_tokens) {
    const bool in_vocab = token_id >= 0 && token_id < vocab_size;
    if (!in_vocab)
      continue;
    cuda::LaunchSetScoreProcessor(GetScores().data(), params_->BatchBeamSize(), vocab_size, token_id, kSuppressedScore, GetStream());
  }
}

} // namespace Generators
1 change: 1 addition & 0 deletions src/cuda/search_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ struct Search_Cuda : Search {

void ApplyMinLength(int min_length) override;
void ApplyRepetitionPenalty(float penalty) override;
void ApplySuppressTokens(const std::vector<int>& suppress_tokens) override;

std::span<float> GetScores(int batch_beam_index);
std::span<float> GetScores();
Expand Down
8 changes: 8 additions & 0 deletions src/engine/request.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,14 @@ void Request::GenerateNextTokens(DeviceSpan<float> logits) {
auto& search_params = search_->params_->search;
search_->ApplyMinLength(search_params.min_length);
search_->ApplyRepetitionPenalty(search_params.repetition_penalty);
search_->ApplySuppressTokens(search_params.suppress_tokens);
if (!search_params.begin_suppress_tokens.empty()) {
// begin_suppress_tokens are only suppressed at the first generated step (current length == prompt length).
if (begin_suppress_length_ < 0)
begin_suppress_length_ = search_->GetSequenceLength();
if (search_->GetSequenceLength() == begin_suppress_length_)
search_->ApplySuppressTokens(search_params.begin_suppress_tokens);
}

if (!search_params.do_sample || search_params.top_k == 1 || search_params.temperature == 0) {
search_->SelectTop();
Expand Down
2 changes: 2 additions & 0 deletions src/engine/request.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ struct Request : std::enable_shared_from_this<Request>,
std::vector<int32_t> prefill_input_ids_;
int64_t seen_sequence_length_{};
int64_t processed_sequence_length_{};
// Sequence length at the first generated step; used for begin_suppress_tokens. -1 = not yet set.
int64_t begin_suppress_length_{-1};
std::shared_ptr<GeneratorParams> params_;
std::unique_ptr<Search> search_;
std::weak_ptr<Engine> engine_;
Expand Down
8 changes: 8 additions & 0 deletions src/generators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,14 @@ void Generator::GenerateNextToken() {
auto& search = search_->params_->search;
search_->ApplyMinLength(search.min_length);
search_->ApplyRepetitionPenalty(search.repetition_penalty);
search_->ApplySuppressTokens(search.suppress_tokens);
if (!search.begin_suppress_tokens.empty()) {
// begin_suppress_tokens are only suppressed at the first generated step (current length == prompt length).
if (begin_suppress_length_ < 0)
begin_suppress_length_ = search_->GetSequenceLength();
if (search_->GetSequenceLength() == begin_suppress_length_)
search_->ApplySuppressTokens(search.begin_suppress_tokens);
}

if (g_log.enabled && g_log.generate_next_token) {
auto& stream = Log("generate_next_token");
Expand Down
2 changes: 2 additions & 0 deletions src/generators.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ struct Generator : LeakChecked<Generator> {
// Non-null when the model is a transducer (RNNT, TDT); points into state_.
TransducerState* transducer_state_{nullptr};
int phi3_rope_threshold_{}; // 0 means no ROPE rewind needed
// Sequence length at the first generated step; used for begin_suppress_tokens. -1 = not yet set.
int begin_suppress_length_{-1};
enum class SamplingMethod { kGreedy,
kTopK,
kTopP,
Expand Down
4 changes: 4 additions & 0 deletions src/ort_genai.h
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,10 @@ struct OgaGeneratorParams : OgaAbstract {
OgaCheckResult(OgaGeneratorParamsSetSearchBool(this, name, value));
}

// Sets a token-id array search option (e.g. "suppress_tokens") by forwarding
// `tokens_count` ids starting at `tokens` to OgaGeneratorParamsSetSearchTokensArray.
// The returned OgaResult is passed to OgaCheckResult.
// NOTE(review): presumably OgaCheckResult throws on a non-null result - confirm.
void SetSearchOptionTokensArray(const char* name, const int32_t* tokens, size_t tokens_count) {
OgaCheckResult(OgaGeneratorParamsSetSearchTokensArray(this, name, tokens, tokens_count));
}

// Forwards the guidance `type` and `data` to OgaGeneratorParamsSetGuidance.
// `enable_ff_tokens` defaults to false. The returned OgaResult is passed to OgaCheckResult.
void SetGuidance(const char* type, const char* data, bool enable_ff_tokens = false) {
OgaCheckResult(OgaGeneratorParamsSetGuidance(this, type, data, enable_ff_tokens));
}
Expand Down
7 changes: 7 additions & 0 deletions src/ort_genai_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,13 @@ OgaResult* OGA_API_CALL OgaGeneratorParamsSetSearchBool(OgaGeneratorParams* para
OGA_CATCH
}

// C API entry point: stores `tokens_count` ids read from `tokens` into the search
// option called `name` on `params`. Returns nullptr on success.
// NOTE(review): failures are presumably turned into an OgaResult by OGA_CATCH - confirm.
OgaResult* OGA_API_CALL OgaGeneratorParamsSetSearchTokensArray(OgaGeneratorParams* params, const char* name, const int32_t* tokens, size_t tokens_count) {
  OGA_TRY
  // Non-owning view over the caller's buffer; the callee copies the ids out of it.
  const std::span<const int32_t> token_ids(tokens, tokens_count);
  Generators::SetSearchTokensArray(params->search, name, token_ids);
  return nullptr;
  OGA_CATCH
}

OgaResult* OGA_API_CALL OgaGeneratorParamsSetGuidance(OgaGeneratorParams* params, const char* type, const char* data, bool enable_ff_tokens) {
OGA_TRY
params->SetGuidance(type, data, enable_ff_tokens);
Expand Down
10 changes: 10 additions & 0 deletions src/ort_genai_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,16 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetSearchNumber(OgaGenerato
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetSearchBool(OgaGeneratorParams* params, const char* name, bool value);

/**
* \brief Set an array of token ids for a search parameter. Supported names are "suppress_tokens"
*        (suppressed at every decoding step) and "begin_suppress_tokens" (suppressed only at the
*        first generated step). Any previous value of the parameter is replaced.
* \param[in] params The generator params to set.
* \param[in] name The name of the search parameter. An unrecognized name is reported as an error.
* \param[in] tokens The array of token ids. The ids are copied, so the array need not outlive the call.
* \param[in] tokens_count The number of token ids in the array.
* \return OgaResult containing the error message if setting the generator params failed.
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetSearchTokensArray(OgaGeneratorParams* params, const char* name, const int32_t* tokens, size_t tokens_count);

/**
* \brief Sets the guidance type and data for the Generator params
* \param[in] params The generator params to set the guidance on
Expand Down
17 changes: 17 additions & 0 deletions src/python/py/models/builders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,14 @@ def make_tied_embeddings_init(self, config):
# matmul_nbits_quantizer.py has a different naming for default quantization, so lm_head.MatMul.weight_Q{}G{} does not match.
self.shared_embeddings = self.int8_lm_head or self.extra_options.get("int4_algo_config", "default") in {"rtn", "k_quant"}

def add_suppress_tokens_to_search_config(self, search_config, suppress_tokens, begin_suppress_tokens):
    """Copy suppress-token lists into the ``search`` section of a GenAI config.

    Added for HF generation parity. ``suppress_tokens`` are suppressed at every
    decoding step; ``begin_suppress_tokens`` only at the first generated step.

    Args:
        search_config: Dict that is updated in place.
        suppress_tokens: Iterable of token ids, or None.
        begin_suppress_tokens: Iterable of token ids, or None.

    A key is written only when its value is present and non-empty. Ids are stored
    as a new list of built-in ``int`` so the result is always JSON serializable,
    even if the caller passes an array or integer-like scalars.
    """
    candidates = (
        ("suppress_tokens", suppress_tokens),
        ("begin_suppress_tokens", begin_suppress_tokens),
    )
    for key, tokens in candidates:
        # Compare against None explicitly: truth-testing an array-like raises.
        if tokens is None:
            continue
        token_ids = [int(token) for token in tokens]
        if token_ids:
            search_config[key] = token_ids

def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir):
# Create config with attributes from config.json and generation_config.json (if latter file exists)
config = AutoConfig.from_pretrained(
Expand All @@ -549,6 +557,8 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir):
"temperature": 1.0,
"top_k": 50,
"top_p": 1.0,
"suppress_tokens": None,
"begin_suppress_tokens": None,
}
for key, default_val in defaults.items():
val = getattr(gen_config, key)
Expand Down Expand Up @@ -631,6 +641,13 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir):
},
}

# Suppress tokens from generation_config.json (HF generation parity).
self.add_suppress_tokens_to_search_config(
genai_config["search"],
getattr(config, "suppress_tokens", None),
getattr(config, "begin_suppress_tokens", None),
)

if self.ep == "trt-rtx" and self.window_size is not None and self.window_size > 0:
# Compute layer indices that use sliding window attention
layer_idxs = [
Expand Down
9 changes: 9 additions & 0 deletions src/python/py/models/builders/whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,10 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
self.pad_token_id = config.pad_token_id
self.vocab_size = config.vocab_size

# Suppress tokens for HF generation parity (Whisper suppresses non-speech/special tokens).
self.suppress_tokens = config.suppress_tokens if hasattr(config, "suppress_tokens") else None
self.begin_suppress_tokens = config.begin_suppress_tokens if hasattr(config, "begin_suppress_tokens") else None

self.hf_token = self.decoder.hf_token
self.hf_remote = self.decoder.hf_remote
self.context_length = self.decoder.context_length
Expand Down Expand Up @@ -989,5 +993,10 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir):
},
}

# Suppress tokens from the model config (HF generation parity).
self.add_suppress_tokens_to_search_config(
genai_config["search"], self.suppress_tokens, self.begin_suppress_tokens
)

with open(os.path.join(out_dir, "genai_config.json"), "w") as f:
json.dump(genai_config, f, indent=4)
7 changes: 6 additions & 1 deletion src/python/python.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,13 @@ struct PyGeneratorParams {
params_->SetSearchOptionBool(name.c_str(), entry.second.cast<bool>());
} else if (pybind11::isinstance<pybind11::int_>(entry.second)) {
params_->SetSearchOption(name.c_str(), entry.second.cast<int>());
} else if (pybind11::isinstance<pybind11::list>(entry.second) || pybind11::isinstance<pybind11::tuple>(entry.second)) {
std::vector<int32_t> tokens;
for (auto item : entry.second)
tokens.push_back(pybind11::cast<int32_t>(item));
params_->SetSearchOptionTokensArray(name.c_str(), tokens.data(), tokens.size());
} else
throw std::runtime_error("Unknown search option type, can be float/bool/int:" + name);
throw std::runtime_error("Unknown search option type, can be float/bool/int/list[int]:" + name);
}
}

Expand Down
15 changes: 15 additions & 0 deletions src/search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -576,4 +576,19 @@ void Search_Cpu::ApplyRepetitionPenalty(float penalty) {
}
}

void Search_Cpu::ApplySuppressTokens(const std::vector<int>& suppress_tokens) {
if (suppress_tokens.empty())
return;

const int batch_beam_size = params_->BatchBeamSize();
const int vocab_size = params_->config.model.vocab_size;
for (int i = 0; i < batch_beam_size; i++) {
std::span<float> const beam_token_scores = GetScores(i);
for (auto token_id : suppress_tokens) {
if (token_id >= 0 && token_id < vocab_size)
beam_token_scores[token_id] = std::numeric_limits<float>::lowest();
}
}
}

} // namespace Generators
2 changes: 2 additions & 0 deletions src/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ struct Search : LeakChecked<Search> {
// Scoring features
virtual void ApplyMinLength(int min_length) = 0;
virtual void ApplyRepetitionPenalty(float penalty) = 0;
// Suppresses the listed token ids in the current scores. The CPU implementation writes
// std::numeric_limits<float>::lowest() for each id in [0, vocab_size) on every batch/beam
// entry and skips ids outside that range.
virtual void ApplySuppressTokens(const std::vector<int>& suppress_tokens) = 0;

// Set user input tokens
virtual void AppendTokens(DeviceSpan<int32_t>& next_tokens) { assert(false); };
Expand All @@ -50,6 +51,7 @@ struct Search_Cpu : Search {

void ApplyMinLength(int min_length) override;
void ApplyRepetitionPenalty(float penalty) override;
void ApplySuppressTokens(const std::vector<int>& suppress_tokens) override;

std::span<float> GetScores(int batch_beam_index);

Expand Down
47 changes: 47 additions & 0 deletions test/python/test_onnxruntime_genai_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,53 @@ def test_greedy_search(test_data_path, relative_model_path):
assert int(generator.token_count()) == len(generator.get_sequence(0))


def test_suppress_tokens(test_data_path):
    """Greedy search must skip ids listed in ``suppress_tokens``.

    Token 1 holds the highest logit and token 3 the second highest. Both are
    suppressed at every step, so the selected token is 2.
    """
    model_dir = Path(test_data_path) / "models" / "hf-internal-testing" / "tiny-random-gpt2-fp32"

    config = og.Config(os.fspath(model_dir))
    config.overlay('{ "model": { "vocab_size": 5 } }')
    model = og.Model(config)

    # The list value exercises the token-array branch of set_search_options.
    search_options = {
        "do_sample": False,
        "max_length": 10,
        "batch_size": 1,
        "suppress_tokens": [1, 3],
    }
    params = og.GeneratorParams(model)
    params.set_search_options(**search_options)

    generator = og.Generator(model, params)
    # Keep the array bound to a name so it stays alive through generation.
    logits = np.array([[0.1, 0.9, 0.5, 0.7, 0.2]], dtype=np.float32)
    generator.set_logits(logits)
    generator.generate_next_token()

    chosen_token = generator.get_next_tokens()[0]
    assert chosen_token == 2


def test_begin_suppress_tokens(test_data_path):
    """Greedy search must skip ``begin_suppress_tokens`` at the first generated step.

    Token 1 holds the highest logit but is suppressed at the begin step, so the
    next highest token (3) is selected.
    """
    model_dir = Path(test_data_path) / "models" / "hf-internal-testing" / "tiny-random-gpt2-fp32"

    config = og.Config(os.fspath(model_dir))
    config.overlay('{ "model": { "vocab_size": 5 } }')
    model = og.Model(config)

    search_options = {
        "do_sample": False,
        "max_length": 10,
        "batch_size": 1,
        "begin_suppress_tokens": [1],
    }
    params = og.GeneratorParams(model)
    params.set_search_options(**search_options)

    generator = og.Generator(model, params)
    # Keep the array bound to a name so it stays alive through generation.
    logits = np.array([[0.1, 0.9, 0.5, 0.7, 0.2]], dtype=np.float32)
    generator.set_logits(logits)
    generator.generate_next_token()

    chosen_token = generator.get_next_tokens()[0]
    assert chosen_token == 3


@pytest.mark.parametrize(
"relative_model_path",
(
Expand Down
44 changes: 44 additions & 0 deletions test/sampling_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,50 @@ TEST(SamplingTests, BatchedSamplingTopPCpu) {
EXPECT_TRUE(0 == std::memcmp(expected_output.data(), next_tokens.data(), expected_output.size() * sizeof(int32_t)));
}

TEST(SamplingTests, SuppressTokensCpu) {
  // Tokens 1 and 3 carry the two highest logits. With both listed in
  // suppress_tokens (set through the config overlay), greedy selection must
  // land on token 2, the best remaining candidate.
  auto config = OgaConfig::Create(MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
  config->Overlay(R"({ "model": { "vocab_size" : 5 }, "search": { "suppress_tokens": [1, 3] } })");
  auto model = OgaModel::Create(*config);

  auto params = OgaGeneratorParams::Create(*model);
  params->SetSearchOption("max_length", 10);
  params->SetSearchOption("batch_size", 1);
  auto generator = OgaGenerator::Create(*model, *params);

  // One row of five logits, matching the overridden vocab_size.
  std::vector<float> logits_cpu = {0.1f, 0.9f, 0.5f, 0.7f, 0.2f};
  auto logits_tensor = OgaTensor::Create(logits_cpu.data(), std::array<int64_t, 2>{1LL, 5LL});
  generator->SetLogits(*logits_tensor);
  generator->GenerateNextToken();

  auto selected = generator->GetNextTokens();
  EXPECT_EQ(selected[0], 2);
}

TEST(SamplingTests, BeginSuppressTokensCpu) {
  // begin_suppress_tokens only apply to the first generated step. Token 1 has
  // the highest logit but is suppressed on that step, so the runner-up
  // (token 3) must be selected.
  auto config = OgaConfig::Create(MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
  config->Overlay(R"({ "model": { "vocab_size" : 5 }, "search": { "begin_suppress_tokens": [1] } })");
  auto model = OgaModel::Create(*config);

  auto params = OgaGeneratorParams::Create(*model);
  params->SetSearchOption("max_length", 10);
  params->SetSearchOption("batch_size", 1);
  auto generator = OgaGenerator::Create(*model, *params);

  // One row of five logits, matching the overridden vocab_size.
  std::vector<float> logits_cpu = {0.1f, 0.9f, 0.5f, 0.7f, 0.2f};
  auto logits_tensor = OgaTensor::Create(logits_cpu.data(), std::array<int64_t, 2>{1LL, 5LL});
  generator->SetLogits(*logits_tensor);
  generator->GenerateNextToken();

  auto selected = generator->GetNextTokens();
  EXPECT_EQ(selected[0], 3);
}

TEST(SamplingTests, BatchedSamplingTopKCpu) {
std::vector<int32_t> input_ids{0, 1, 2, 3};
std::vector<float> logits_cpu{2.0f, 1.5f, 1.25f, 0.25f, 0.25f,
Expand Down
Loading