Skip to content

Commit 7fde821

Browse files
committed
llm_runner: plumb prefill temperature
Session-based serving drives generation as prefill plus token steps instead of one monolithic generate call. For that path to be correct, the first sampled token produced during prefill must honor the same sampling inputs as the rest of the decode loop; otherwise requests using temperature can silently start greedily and then switch behavior on later tokens. This threads optional temperature through TextPrefiller and exposes the existing TextTokenGenerator logit-processor application so token-step callers can reuse the same sampling preparation as generate(). The goal is to remove a divergence point before session-backed serving starts depending on these primitives. Default behavior remains greedy, so existing callers that do not pass temperature keep the same semantics. The added tests focus on the new non-default path and on sharing the logit-processor logic rather than duplicating it.
1 parent d7ca5db commit 7fde821

4 files changed

Lines changed: 52 additions & 20 deletions

File tree

extension/llm/runner/test/test_text_prefiller.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ class TextPrefillerTest : public Test {
7979
MOCK_METHOD(
8080
::executorch::runtime::Result<uint64_t>,
8181
prefill_chunk,
82-
(std::vector<uint64_t>&, int64_t&),
83-
());
82+
(std::vector<uint64_t>&, int64_t&, float),
83+
(override));
8484
};
8585

8686
// Create a mock TextPrefiller
@@ -112,9 +112,9 @@ TEST_F(TextPrefillerTest, PrefillCallsPrefillChunkOnceWhenPromptFits) {
112112
int64_t start_pos = 0;
113113

114114
// Expect prefill_chunk to be called exactly once with the entire prompt
115-
EXPECT_CALL(*prefiller, prefill_chunk(_, _))
115+
EXPECT_CALL(*prefiller, prefill_chunk(_, _, _))
116116
.Times(1)
117-
.WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
117+
.WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos, float) {
118118
// Verify the tokens passed to prefill_chunk
119119
EXPECT_EQ(tokens.size(), prompt_tokens.size());
120120
for (size_t i = 0; i < tokens.size(); i++) {
@@ -217,14 +217,14 @@ TEST_F(TextPrefillerTest, PrefillHandlesPrefillChunkErrorsCorrectly) {
217217
InSequence seq;
218218

219219
// First chunk: tokens [1, 2, 3] - succeeds
220-
EXPECT_CALL(*prefiller, prefill_chunk(_, _))
221-
.WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
220+
EXPECT_CALL(*prefiller, prefill_chunk(_, _, _))
221+
.WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos, float) {
222222
return Result<uint64_t>(10);
223223
});
224224

225225
// Second chunk: tokens [4, 5] - fails
226-
EXPECT_CALL(*prefiller, prefill_chunk(_, _))
227-
.WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
226+
EXPECT_CALL(*prefiller, prefill_chunk(_, _, _))
227+
.WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos, float) {
228228
return Result<uint64_t>(Error::InvalidArgument);
229229
});
230230
}

extension/llm/runner/text_prefiller.cpp

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ TextPrefiller::TextPrefiller(
2828

2929
::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
3030
std::vector<uint64_t>& prompt_tokens,
31-
int64_t& start_pos) {
31+
int64_t& start_pos,
32+
float temperature) {
3233
ET_CHECK_MSG(!prompt_tokens.empty(), "Prompt cannot be null");
3334
if (!text_decoder_runner_->is_method_loaded()) {
3435
ET_CHECK_OK_OR_RETURN_ERROR(text_decoder_runner_->load());
@@ -54,8 +55,15 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
5455
num_tokens_to_prefill_with,
5556
prompt_tokens_to_process.begin());
5657

57-
// Process this chunk
58-
auto chunk_result = prefill_chunk(prompt_tokens_to_process, start_pos);
58+
// Process this chunk. Only the LAST chunk produces the first generated
59+
// token, so apply `temperature` there; intermediate chunks just prefill.
60+
const bool is_last_chunk =
61+
num_tokens_to_process + num_tokens_to_prefill_with >=
62+
num_prompt_tokens;
63+
auto chunk_result = prefill_chunk(
64+
prompt_tokens_to_process,
65+
start_pos,
66+
is_last_chunk ? temperature : 0.0f);
5967
ET_CHECK_OK_OR_RETURN_ERROR(chunk_result.error());
6068
cur_token = chunk_result.get();
6169

@@ -65,13 +73,14 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
6573
return cur_token;
6674
} else {
6775
// If prompt tokens don't exceed max_seq_len_, process them directly
68-
return prefill_chunk(prompt_tokens, start_pos);
76+
return prefill_chunk(prompt_tokens, start_pos, temperature);
6977
}
7078
}
7179

7280
::executorch::runtime::Result<uint64_t> TextPrefiller::prefill_chunk(
7381
std::vector<uint64_t>& prompt_tokens,
74-
int64_t& start_pos) {
82+
int64_t& start_pos,
83+
float temperature) {
7584
// enable_parallel_prefill_ may be set even when not using kv cache
7685
// When kv cache is not used, start pos is ignored
7786
int32_t num_prompt_tokens = prompt_tokens.size();
@@ -92,7 +101,8 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill_chunk(
92101
Info, "Prefill token result numel(): %zu", outputs_res.get().numel());
93102

94103
start_pos += num_prompt_tokens;
95-
cur_token = text_decoder_runner_->logits_to_token(outputs_res.get());
104+
cur_token =
105+
text_decoder_runner_->logits_to_token(outputs_res.get(), temperature);
96106
} else { // sequential prefill
97107
int64_t pos = 0; // position in the sequence
98108
// NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
@@ -128,7 +138,8 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill_chunk(
128138
start_pos++;
129139
}
130140

131-
cur_token = text_decoder_runner_->logits_to_token(logits_tensor);
141+
cur_token =
142+
text_decoder_runner_->logits_to_token(logits_tensor, temperature);
132143
}
133144
return cur_token;
134145
}

extension/llm/runner/text_prefiller.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,22 +32,28 @@ class ET_EXPERIMENTAL TextPrefiller {
3232
* tokenizer.
3333
* @param start_pos The starting position in KV cache of the input in the LLM
3434
* Module.
35+
* @param temperature Sampling temperature for the first generated token
36+
* (which is sampled here during prefill; for a chunked prompt only the last chunk samples with it). Defaults to greedy (0.0).
3537
* @return The next token of the LLM Module after prefill.
3638
*/
3739
virtual ::executorch::runtime::Result<uint64_t> prefill(
3840
std::vector<uint64_t>& prompt_tokens,
39-
int64_t& start_pos);
41+
int64_t& start_pos,
42+
float temperature = 0.0f);
4043

4144
/**
4245
* Helper method to prefill a chunk of tokens.
4346
* @param prompt_tokens The chunk of text prompt tokens to process.
4447
* @param start_pos The starting position in KV cache of the input in the LLM
4548
* Module.
49+
* @param temperature Sampling temperature for the token produced by this
50+
* chunk. Defaults to greedy (0.0).
4651
* @return The next token of the LLM Module after prefilling this chunk.
4752
*/
4853
virtual ::executorch::runtime::Result<uint64_t> prefill_chunk(
4954
std::vector<uint64_t>& prompt_tokens,
50-
int64_t& start_pos);
55+
int64_t& start_pos,
56+
float temperature = 0.0f);
5157

5258
/**
5359
* Load the necessary resources for the TextPrefiller.

extension/llm/runner/text_token_generator.h

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,18 @@ class ET_EXPERIMENTAL TextTokenGenerator {
5555
return logit_processors_.size();
5656
}
5757

58+
/// Apply the registered logit processors (grammar/tool masks, penalties,
59+
/// top-k/top-p, ...) to `logits` in order, before sampling. Both the
60+
/// generate() loop and session decode_one() call this so the two decode paths
61+
/// stay consistent. Stops at the first processor that fails and returns its error; returns Error::Ok otherwise.
62+
inline ::executorch::runtime::Error apply_logit_processors(
63+
executorch::aten::Tensor& logits) {
64+
for (auto& processor : logit_processors_) {
65+
ET_CHECK_OK_OR_RETURN_ERROR(processor->process(logits));
66+
}
67+
return ::executorch::runtime::Error::Ok;
68+
}
69+
5870
virtual ~TextTokenGenerator() = default;
5971

6072
/**
@@ -126,9 +138,7 @@ class ET_EXPERIMENTAL TextTokenGenerator {
126138

127139
prev_token = cur_token;
128140

129-
for (auto& processor : logit_processors_) {
130-
ET_CHECK_OK_OR_RETURN_ERROR(processor->process(logits_tensor));
131-
}
141+
ET_CHECK_OK_OR_RETURN_ERROR(apply_logit_processors(logits_tensor));
132142

133143
stats_->on_sampling_begin();
134144
cur_token =
@@ -180,6 +190,11 @@ class ET_EXPERIMENTAL TextTokenGenerator {
180190
should_stop_.store(true, std::memory_order_relaxed);
181191
}
182192

193+
/// Whether `token` is an end-of-sequence token (used by single-step decode).
194+
inline bool is_eos(uint64_t token) const {
195+
return eos_ids_->find(token) != eos_ids_->end();
196+
}
197+
183198
/**
184199
* Load the necessary resources for TextTokenGenerator.
185200
* This method should be called before using the generate() method.

0 commit comments

Comments
 (0)