Skip to content

Commit ee1ea74

Browse files
committed
llm_runner: add streaming text boundary helpers
The serving stack streams decoded text incrementally, but token boundaries do not guarantee user-visible text boundaries. A chunk can end in the middle of a UTF-8 sequence, or immediately before bytes that complete a configured stop string. Handling that ad hoc in each server path would make streaming correctness easy to regress. This adds small, model-agnostic runner helpers for computing UTF-8-safe and stop-string-safe prefixes before emitting text. Keeping the logic in the runner utility layer gives both generic and model-specific workers one tested implementation instead of duplicating fragile string handling. The change is intentionally narrow: it introduces pure helper functions and focused unit coverage only. It does not change runner generation behavior on its own.
1 parent d7ca5db commit ee1ea74

2 files changed

Lines changed: 204 additions & 0 deletions

File tree

extension/llm/runner/test/test_util.cpp

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ namespace {
1818
using ::executorch::aten::ScalarType;
1919
using ::executorch::extension::make_tensor_ptr;
2020
using ::executorch::extension::llm::convert_to_bfloat16;
21+
using ::executorch::extension::llm::stop_safe_prefix_len;
22+
using ::executorch::extension::llm::utf8_complete_prefix_len;
2123

2224
class ConvertToBFloat16Test : public ::testing::Test {
2325
protected:
@@ -63,4 +65,94 @@ TEST_F(ConvertToBFloat16Test, RejectsNonFloatTensor) {
6365
EXPECT_EQ(result.error(), ::executorch::runtime::Error::InvalidArgument);
6466
}
6567

68+
TEST(Utf8CompletePrefixLenTest, HandlesAsciiAndMultiByteBoundaries) {
69+
EXPECT_EQ(utf8_complete_prefix_len(""), 0u);
70+
EXPECT_EQ(utf8_complete_prefix_len("ascii"), 5u);
71+
72+
// Complete multi-byte characters are fully consumed.
73+
EXPECT_EQ(utf8_complete_prefix_len("\xc3\xa9"), 2u); // é (2-byte)
74+
EXPECT_EQ(utf8_complete_prefix_len("\xe2\x82\xac"), 3u); // € (3-byte)
75+
EXPECT_EQ(utf8_complete_prefix_len("\xf0\x9f\x98\x80"), 4u); // 😀 (4-byte)
76+
77+
// A character split across the end is held back (not counted).
78+
EXPECT_EQ(utf8_complete_prefix_len("\xc3"), 0u); // 1/2 of é
79+
EXPECT_EQ(utf8_complete_prefix_len("\xe2\x82"), 0u); // 2/3 of €
80+
EXPECT_EQ(utf8_complete_prefix_len("\xf0\x9f\x98"), 0u); // 3/4 of 😀
81+
82+
// A complete prefix followed by a split character keeps the complete part.
83+
EXPECT_EQ(utf8_complete_prefix_len("hi\xe2\x82"), 2u);
84+
EXPECT_EQ(utf8_complete_prefix_len("\xe2\x82\xac\xf0\x9f"), 3u);
85+
86+
// An invalid lead byte counts as length 1 (emitted, not stalled).
87+
EXPECT_EQ(utf8_complete_prefix_len("\x80"), 1u);
88+
}
89+
90+
TEST(StopSafePrefixLenTest, NoStopsEmitsEverything) {
91+
bool hit = true;
92+
EXPECT_EQ(stop_safe_prefix_len("hello world", {}, hit), 11u);
93+
EXPECT_FALSE(hit);
94+
}
95+
96+
TEST(StopSafePrefixLenTest, SingleByteStopMissEmitsEverything) {
97+
bool hit = true;
98+
const std::string text = "caf\xc3\xa9";
99+
EXPECT_EQ(stop_safe_prefix_len(text, {"Z"}, hit), text.size());
100+
EXPECT_FALSE(hit);
101+
}
102+
103+
TEST(StopSafePrefixLenTest, EmptyStopsDoNotHoldBack) {
104+
bool hit = true;
105+
EXPECT_EQ(stop_safe_prefix_len("hello", {""}, hit), 5u);
106+
EXPECT_FALSE(hit);
107+
}
108+
109+
TEST(StopSafePrefixLenTest, StopFoundReturnsEarliestOffsetAndExcludesIt) {
110+
bool hit = false;
111+
// "STOP" begins at offset 6; emit "Hello " (6 bytes), drop the stop and rest.
112+
EXPECT_EQ(stop_safe_prefix_len("Hello STOP there", {"STOP"}, hit), 6u);
113+
EXPECT_TRUE(hit);
114+
// Earliest of several wins.
115+
hit = false;
116+
EXPECT_EQ(stop_safe_prefix_len("aXbY", {"Y", "X"}, hit), 1u);
117+
EXPECT_TRUE(hit);
118+
}
119+
120+
TEST(StopSafePrefixLenTest, EarliestStopWinsEvenWhenLongerStopSetsHoldBack) {
121+
bool hit = false;
122+
EXPECT_EQ(stop_safe_prefix_len("abcXtail", {"LONGSTOP", "X"}, hit), 3u);
123+
EXPECT_TRUE(hit);
124+
}
125+
126+
TEST(StopSafePrefixLenTest, HoldsBackPossiblePartialStopTail) {
127+
bool hit = false;
128+
// No full stop yet, but the trailing "ST" could become "STOP": hold back
129+
// len("STOP")-1 == 3 bytes, so of "hi ST" (5 bytes) only "hi" (2) is safe.
130+
EXPECT_EQ(stop_safe_prefix_len("hi ST", {"STOP"}, hit), 2u);
131+
EXPECT_FALSE(hit);
132+
}
133+
134+
TEST(StopSafePrefixLenTest, HoldBackSnapsToUtf8Boundary) {
135+
bool hit = false;
136+
// "ab" + "€"(3 bytes). Stop "XX" => hold back 1 byte, which would land inside
137+
// the euro sign; snap down so the multi-byte char isn't split.
138+
const std::string text = "ab\xe2\x82\xac";
139+
const size_t safe = stop_safe_prefix_len(text, {"XX"}, hit);
140+
EXPECT_FALSE(hit);
141+
EXPECT_EQ(safe, 2u); // only "ab"; the € is held whole
142+
}
143+
144+
TEST(StopSafePrefixLenTest, HoldBackWithIncompleteUtf8TailSnapsToBoundary) {
145+
bool hit = false;
146+
const std::string text = "ab\xe2\x82";
147+
EXPECT_EQ(stop_safe_prefix_len(text, {"XX"}, hit), 2u);
148+
EXPECT_FALSE(hit);
149+
}
150+
151+
TEST(StopSafePrefixLenTest, HoldZeroDoesNotEmitDanglingUtf8LeadByte) {
152+
bool hit = false;
153+
const std::string text = "ab\xc3";
154+
EXPECT_EQ(stop_safe_prefix_len(text, {"Z"}, hit), 2u);
155+
EXPECT_FALSE(hit);
156+
}
157+
66158
} // namespace

extension/llm/runner/util.h

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
#include <executorch/runtime/platform/compiler.h>
1414
#include <stdio.h>
1515
#include <time.h>
16+
#include <algorithm>
1617
#include <cctype>
18+
#include <string>
1719
#include <vector>
1820
#if defined(__linux__) || defined(__ANDROID__) || defined(__unix__)
1921
#include <sys/resource.h>
@@ -81,6 +83,116 @@ ET_EXPERIMENTAL void inline safe_printf(const char* piece) {
8183
printf("%s", piece);
8284
}
8385

86+
// Length of the longest prefix of `s` that does not end in the middle of a
87+
// UTF-8 multi-byte sequence. A byte-level tokenizer can emit a token that is
88+
// only part of a character (e.g. one byte of a 3-byte CJK codepoint or emoji),
89+
// so a caller streaming text must hold the incomplete tail until it completes
90+
// rather than decode the partial bytes. An invalid lead byte counts as length 1
91+
// (emitted, so the caller can replace it) rather than stalling output.
92+
ET_EXPERIMENTAL size_t inline utf8_complete_prefix_len(const std::string& s) {
93+
size_t i = 0;
94+
const size_t n = s.size();
95+
while (i < n) {
96+
const unsigned char c = static_cast<unsigned char>(s[i]);
97+
size_t len;
98+
if (c < 0x80) {
99+
len = 1;
100+
} else if ((c >> 5) == 0x6) {
101+
len = 2;
102+
} else if ((c >> 4) == 0xE) {
103+
len = 3;
104+
} else if ((c >> 3) == 0x1E) {
105+
len = 4;
106+
} else {
107+
len = 1; // invalid lead byte; emit it and let the caller replace it
108+
}
109+
if (i + len > n) {
110+
break; // incomplete trailing sequence: hold it for more bytes
111+
}
112+
i += len;
113+
}
114+
return i;
115+
}
116+
117+
ET_EXPERIMENTAL size_t inline utf8_safe_prefix_len(
118+
const std::string& s,
119+
size_t len) {
120+
len = std::min(len, s.size());
121+
if (len == 0) {
122+
return 0;
123+
}
124+
const auto* data = reinterpret_cast<const unsigned char*>(s.data());
125+
size_t i = len;
126+
while (i > 0 && (data[i - 1] & 0xC0) == 0x80) {
127+
--i;
128+
}
129+
if (i == 0) {
130+
return 0;
131+
}
132+
const size_t lead_pos = i - 1;
133+
const unsigned char lead = data[lead_pos];
134+
size_t need = 0;
135+
if (lead < 0x80) {
136+
need = 1;
137+
} else if ((lead & 0xE0) == 0xC0) {
138+
need = 2;
139+
} else if ((lead & 0xF0) == 0xE0) {
140+
need = 3;
141+
} else if ((lead & 0xF8) == 0xF0) {
142+
need = 4;
143+
} else {
144+
return lead_pos;
145+
}
146+
return len - lead_pos == need ? len : lead_pos;
147+
}
148+
149+
// How many leading bytes of `text` a streaming consumer may safely emit given a
150+
// set of `stops` strings, and whether a stop was hit (`stop_hit`).
151+
// * If any stop occurs, returns the byte offset of the EARLIEST occurrence
152+
// and
153+
// sets stop_hit=true — text before it is safe; the stop and everything
154+
// after are dropped (the stop is excluded from output).
155+
// * Otherwise returns the length minus the longest possible partial-stop tail
156+
// (max(len(stop))-1 bytes), snapped DOWN to a UTF-8 boundary so a
157+
// multi-byte character is never split; stop_hit=false. Holding back that
158+
// conservative tail lets a stop that straddles the next piece still be
159+
// caught without suffix-prefix matching each stop.
160+
// `text` is expected to be complete-UTF-8 (e.g. the assembled output of
161+
// utf8_complete_prefix_len) and stops are expected to be real text, so a found
162+
// stop offset cannot split a UTF-8 character. Empty `stops` => emit everything,
163+
// no hold-back.
164+
ET_EXPERIMENTAL size_t inline stop_safe_prefix_len(
165+
const std::string& text,
166+
const std::vector<std::string>& stops,
167+
bool& stop_hit) {
168+
stop_hit = false;
169+
if (stops.empty()) {
170+
return text.size();
171+
}
172+
size_t earliest = std::string::npos;
173+
size_t max_len = 0;
174+
for (const auto& s : stops) {
175+
if (s.empty()) {
176+
continue;
177+
}
178+
max_len = std::max(max_len, s.size());
179+
const size_t p = text.find(s);
180+
if (p != std::string::npos &&
181+
(earliest == std::string::npos || p < earliest)) {
182+
earliest = p;
183+
}
184+
}
185+
if (earliest != std::string::npos) {
186+
stop_hit = true;
187+
return earliest;
188+
}
189+
const size_t hold = max_len > 0 ? max_len - 1 : 0;
190+
if (text.size() <= hold) {
191+
return 0;
192+
}
193+
return utf8_safe_prefix_len(text, text.size() - hold);
194+
}
195+
84196
// ----------------------------------------------------------------------------
85197
// utilities: time
86198

0 commit comments

Comments
 (0)