Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ struct Config {
float top_p{}; // If set to float >0 and <1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
float temperature{1.0f}; // Temperature to control during generation. Default is 1.0.
bool early_stopping{true}; // Whether to stop the beam search when at least num_beams sentences are finished per batch or not.
int no_repeat_ngram_size{}; // Unused param
int no_repeat_ngram_size{}; // If > 0, no n-gram of this size may repeat in the generated sequence. 0 disables.
float diversity_penalty{}; // Unused param
float length_penalty{1.0f}; // Exponential penalty to the length that is used with beam-based generation. length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter sequences.
bool past_present_share_buffer{}; // The past/present kv tensors are shared and allocated once to max_length (cuda only)
Expand Down
1 change: 1 addition & 0 deletions src/engine/request.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ void Request::GenerateNextTokens(DeviceSpan<float> logits) {
auto& search_params = search_->params_->search;
search_->ApplyMinLength(search_params.min_length);
search_->ApplyRepetitionPenalty(search_params.repetition_penalty);
search_->ApplyNoRepeatNgram(search_params.no_repeat_ngram_size);

if (!search_params.do_sample || search_params.top_k == 1 || search_params.temperature == 0) {
search_->SelectTop();
Expand Down
1 change: 1 addition & 0 deletions src/generators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,7 @@ void Generator::GenerateNextToken() {
auto& search = search_->params_->search;
search_->ApplyMinLength(search.min_length);
search_->ApplyRepetitionPenalty(search.repetition_penalty);
search_->ApplyNoRepeatNgram(search.no_repeat_ngram_size);

if (g_log.enabled && g_log.generate_next_token) {
auto& stream = Log("generate_next_token");
Expand Down
39 changes: 39 additions & 0 deletions src/search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -580,4 +580,43 @@ void Search_Cpu::ApplyRepetitionPenalty(float penalty) {
}
}

void Search_Cpu::ApplyNoRepeatNgram(int ngram_size) {
if (ngram_size <= 0)
return;

const int sequence_length = sequences_.GetSequenceLength();
// Need at least one complete n-gram in history before anything can be banned.
if (sequence_length < ngram_size)
return;

const int prefix_length = ngram_size - 1;
const int batch_beam_size = params_->BatchBeamSize();
for (int i = 0; i < batch_beam_size; i++) {
std::span<float> const beam_token_scores = GetScores(i);
std::span<const int32_t> const sequence = sequences_.GetSequence(i).CopyDeviceToCpu();

// The prefix we are about to extend: the trailing (ngram_size - 1) tokens.
std::span<const int32_t> const target_prefix = sequence.subspan(sequence_length - prefix_length, prefix_length);

// Scan every historical n-gram. Its first (ngram_size - 1) tokens form a prefix
// and its last token is what followed. If the prefix matches the trailing prefix,
// ban that following token so the same n-gram cannot repeat.
const int last_start = sequence_length - ngram_size;
for (int start = 0; start <= last_start; start++) {
bool matches = true;
for (int j = 0; j < prefix_length; j++) {
if (sequence[start + j] != target_prefix[j]) {
matches = false;
break;
}
}
if (matches) {
const int32_t banned_token = sequence[start + prefix_length];
if (banned_token >= 0 && banned_token < params_->config.model.vocab_size)
beam_token_scores[banned_token] = std::numeric_limits<float>::lowest();
}
}
}
}

} // namespace Generators
4 changes: 4 additions & 0 deletions src/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ struct Search : LeakChecked<Search> {
// Scoring features
virtual void ApplyMinLength(int min_length) = 0;
virtual void ApplyRepetitionPenalty(float penalty) = 0;
// Bans tokens that would complete an already-seen n-gram of size ngram_size.
// Default no-op so backends that don't implement it (e.g. CUDA) are unaffected.
virtual void ApplyNoRepeatNgram(int /*ngram_size*/) {}

// Set user input tokens
virtual void AppendTokens(DeviceSpan<int32_t>& next_tokens) { assert(false); };
Expand All @@ -50,6 +53,7 @@ struct Search_Cpu : Search {

void ApplyMinLength(int min_length) override;
void ApplyRepetitionPenalty(float penalty) override;
void ApplyNoRepeatNgram(int ngram_size) override;

std::span<float> GetScores(int batch_beam_index);

Expand Down
46 changes: 46 additions & 0 deletions test/sampling_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,52 @@ TEST(SamplingTests, TopKExceedingVocabSizeIsRejected) {
EXPECT_THROW(OgaGenerator::Create(*model, *params), std::runtime_error);
}

// End-to-end check of the no_repeat_ngram_size search option.
// The prompt already holds the 3-gram (5, 6, 7) and ends in (5, 6), while the
// injected logits make token 7 the greedy favourite. With blocking disabled the
// generator must pick 7; with no_repeat_ngram_size=3 it must skip 7 and take
// the runner-up, token 42.
TEST(SamplingTests, NoRepeatNgramCorrectnessCpu) {
  const int batch_size = 1;
  const int vocab_size = 1000;  // Must match tiny-random-gpt2-fp32 model's actual vocab

  auto config = OgaConfig::Create(MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
  config->ClearProviders();
  auto model = OgaModel::Create(*config);

  // Flat baseline of 5.0 everywhere, with two deliberate peaks.
  std::vector<float> logits_cpu(vocab_size * batch_size, 5.0f);
  logits_cpu[42] = 8.0f;  // runner-up: the best token that is still allowed
  logits_cpu[7] = 10.0f;  // top score, but it would complete the repeated 3-gram (5, 6, 7)
  std::array<int64_t, 2> shape = {static_cast<int64_t>(batch_size), static_cast<int64_t>(vocab_size)};

  // History contains (5, 6, 7) once and currently ends with the prefix (5, 6).
  std::vector<int32_t> prefill_tokens = {5, 6, 7, 8, 5, 6};

  // Runs a single greedy step with the given n-gram size and returns the chosen token.
  auto greedy_pick = [&](int blocked_ngram) -> int32_t {
    auto params = OgaGeneratorParams::Create(*model);
    params->SetSearchOption("max_length", 32);
    params->SetSearchOptionBool("do_sample", false);  // greedy
    params->SetSearchOption("batch_size", batch_size);
    params->SetSearchOption("repetition_penalty", 1.0);  // keep the repetition penalty neutral
    params->SetSearchOption("no_repeat_ngram_size", static_cast<double>(blocked_ngram));

    auto generator = OgaGenerator::Create(*model, *params);
    generator->AppendTokens(prefill_tokens.data(), static_cast<int>(prefill_tokens.size()));
    generator->SetLogits(*OgaTensor::Create(logits_cpu.data(), shape));
    generator->GenerateNextToken();
    return generator->GetNextTokens()[0];
  };

  // Blocking off: the highest logit (token 7) wins.
  EXPECT_EQ(greedy_pick(0), 7)
      << "Greedy should pick the highest-scoring token when n-gram blocking is off.";

  // Blocking on with size 3: token 7 is masked, so the next-best token (42) wins.
  EXPECT_EQ(greedy_pick(3), 42)
      << "Token 7 should be banned (repeats 3-gram 5,6,7); expected fallback to token 42.";
}

#if USE_CUDA
TEST(SamplingTests, BatchedSamplingTopPCuda) {
std::vector<int32_t> input_ids{0, 1, 2, 3};
Expand Down
Loading