Skip to content

Commit 5f302dd

Browse files
committed
[UPDATE] Update
[ghstack-poisoned]
2 parents ab7c8bc + 6d66a4d commit 5f302dd

9 files changed

Lines changed: 521 additions & 75 deletions

File tree

extension/llm/runner/llm_session.h

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,11 @@ struct LLMServingCapacity {
6161
// sessions would copy the whole model); raise only on a backend proven to
6262
// share packed weights.
6363
int32_t max_physical_sessions_without_weight_duplication = 1;
64-
// Planned bytes one session adds (KV + activations), for memory-budget
65-
// admission. 0 = unknown; the server skips the memory clamp.
64+
// Planned bytes one session adds (KV + activations). Reported for a FUTURE
65+
// memory-budget admission policy; NOT yet enforced -- admission is currently
66+
// by session COUNT only (--max-sessions). Over-provisioning therefore fails
67+
// at the first execute (cudaMalloc) of the over-committed session, not at
68+
// admit time. 0 = unknown.
6669
int64_t estimated_bytes_per_session = 0;
6770
};
6871

@@ -79,14 +82,28 @@ class ET_EXPERIMENTAL LLMSession {
7982
/// `initial_sampling` (optional): the sampling config for the FIRST generated
8083
/// token, for backends that sample during prefill (e.g. in-graph sampling).
8184
/// Pass it so the first token uses the request's sampling instead of a stale
82-
/// default. Backends that only sample in decode_one() ignore it.
85+
/// default. Backends that only sample in decode_one() ignore it. NOTE:
86+
/// because the first token is sampled here, it does NOT pass through
87+
/// decode_one()'s logit processors -- a grammar/tool mask that must constrain
88+
/// the opening token is not applied to it (a known limitation for
89+
/// grammar-constrained serving).
90+
///
91+
/// ERROR CONTRACT: an error may be returned AFTER backend state has already
92+
/// mutated. On any error from prefill_tokens()/decode_one(), the session is
93+
/// POISONED -- position() may no longer agree with the resident KV. The
94+
/// caller must call reset() (and only proceed once it returns Ok) before any
95+
/// further prefill/decode; it must NOT retry the failed call. The serving
96+
/// worker enforces this (marks the session dirty and forces a reset next
97+
/// request).
8398
virtual ::executorch::runtime::Error prefill_tokens(
8499
std::vector<uint64_t> tokens,
85100
const SamplingConfig* initial_sampling = nullptr) = 0;
86101

87102
/// Decode one token from the pending state; looping reproduces a full
88103
/// generation while returning exact sampled token ids. A single decode_one()
89104
/// runs one forward pass and is not interruptible mid-call (see stop()).
105+
/// On error the session is poisoned -- see the error contract on
106+
/// prefill_tokens() (reset() before any further use; never retry).
90107
virtual ::executorch::runtime::Result<DecodeResult> decode_one(
91108
const SamplingConfig& sampling) = 0;
92109

extension/llm/server/python/README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,14 @@ Session capacity is determined by the worker/engine — a single worker hosts ma
160160
isolated sessions on one weight load — so `--num-runners` accepts 1; extra worker
161161
processes would each carry their own copy of the weights.
162162

163+
The **generic `text_llm_worker` is scratch-only (V1)**: `TextLLMEngine::serving_capacity()`
164+
reports a conservative capacity of 1, so `max_named = max(0, capacity-1) = 0` — the default
165+
`server.py` serves only the anonymous scratch session (no named `session_id`s, no
166+
warm resume). The named-session / warm-resume / token-ID machinery is exercised
167+
by a model-specific worker whose engine reports capacity > 1 (the Qwen3.5-MoE CUDA
168+
worker). This is intentional; the generic worker stays minimal until a backend is
169+
proven to host multiple physical sessions without duplicating weights.
170+
163171
Cancellation is best-effort: a worker request runs to completion and is not
164172
interruptible mid-generation in V1, so `runner.stop()` means "the control plane
165173
stops consuming and the worker finishes the current request" rather than a hard

extension/llm/server/python/chat_template.py

Lines changed: 80 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,38 @@
2525

2626
_DEFAULT_SPECIAL_TOKENS = ["<|im_end|>", "<|endoftext|>", "<|eot_id|>", "<|end|>"]
2727

28+
# Chat turn terminators eligible to be used as generation stop strings. This is a
29+
# deliberate allowlist of end-of-turn / end-of-text tokens -- NOT the tokenizer's
30+
# full special-token set. Structural/tool delimiters (e.g. <tool_call>) must reach
31+
# the tool parser, so they are intentionally excluded: using them as hard stops
32+
# would truncate a tool call before it is ever parsed.
33+
_TURN_TERMINATORS = (
34+
"<|im_end|>",
35+
"<|endoftext|>",
36+
"<|eot_id|>",
37+
"<|end|>",
38+
"<|end_of_text|>",
39+
"<end_of_turn>",
40+
"</s>",
41+
)
42+
43+
44+
def _content_text(content) -> str:
45+
"""Best-effort text for the ChatML fallback: a str as-is, or the concatenated
46+
text parts of an OpenAI list-content message (non-text parts dropped). Avoids
47+
rendering a Python repr of structured content. None -> empty string."""
48+
if isinstance(content, str):
49+
return content
50+
if isinstance(content, list):
51+
out = []
52+
for part in content:
53+
if isinstance(part, dict) and part.get("type") == "text":
54+
out.append(str(part.get("text", "")))
55+
elif isinstance(part, str):
56+
out.append(part)
57+
return "".join(out)
58+
return str(content or "")
59+
2860

2961
def _decode_tool_call_arguments(messages: list[dict[str, Any]]) -> None:
3062
"""In-place: parse each tool call's ``function.arguments`` from a JSON string
@@ -120,25 +152,64 @@ def count_tokens(self, prompt: str) -> Optional[int]:
120152
return len(self._hf.encode(prompt, add_special_tokens=False))
121153
return None
122154

123-
def special_tokens(self) -> list[str]:
124-
"""Special-token strings whose appearance ends the visible content.
155+
def turn_stop_sequences(self) -> list[str]:
    """Return the strings used as generation stops: turn terminators only.

    The result is the tokenizer's EOS token (if it is a non-blank string)
    followed by every entry of ``_TURN_TERMINATORS`` that the tokenizer
    registers as a special token, in allowlist order and without duplicates.
    It is deliberately NOT the full special-token set: structural/tool
    delimiters such as <tool_call> may be registered as special, and stopping
    on them would cut a tool call before the parser sees it. Whitespace-only
    tokens are ignored. Request-level `stop` strings are handled elsewhere and
    are unaffected.

    Without an HF tokenizer the default special-token list is returned. With
    one, the result may be empty (no eos_token and none of the known
    terminators registered); end-of-turn detection then relies on the
    worker's EOS-by-token-id check (e.g. the Qwen engine adds <|im_end|> to
    eos_ids), and this string set is only a backstop.
    """
    tokenizer = self._hf
    if tokenizer is None:
        return list(_DEFAULT_SPECIAL_TOKENS)

    # Non-blank special-token strings the tokenizer actually registers.
    registered = set()
    for tok in getattr(tokenizer, "all_special_tokens", []) or []:
        if isinstance(tok, str) and tok.strip():
            registered.add(tok)

    stops: list[str] = []
    eos = getattr(tokenizer, "eos_token", None)
    if isinstance(eos, str) and eos.strip():
        stops.append(eos)

    # Allowlisted terminators, skipping any already present (i.e. the EOS).
    for terminator in _TURN_TERMINATORS:
        if terminator not in registered:
            continue
        if terminator in stops:
            continue
        stops.append(terminator)
    return stops
125187

126-
From the HF tokenizer when available (model-accurate), else a default set
127-
covering common chat models.
188+
def special_tokens(self) -> list[str]:
    """Return ALL special-token strings, for final cleanup of visible content.

    This broad set exists to scrub any special token that leaked into the
    already-parsed visible output. It is distinct from turn_stop_sequences()
    and must NOT be used as generation stops or for pre-parse truncation,
    since that would halt or cut a tool call at a structural delimiter.
    Whitespace-only tokens are dropped so a stray ' ' token cannot truncate
    content at the first double space.

    Falls back to the default special-token list when no HF tokenizer is set.
    """
    if self._hf is None:
        return list(_DEFAULT_SPECIAL_TOKENS)
    candidates = list(getattr(self._hf, "all_special_tokens", []) or [])
    kept = []
    for tok in candidates:
        if isinstance(tok, str) and tok.strip():
            kept.append(tok)
    return kept
133201

134202
@staticmethod
def _fallback(messages: list[ChatMessage]) -> str:
    # TEXT-ONLY approximation of ChatML, used when no tokenizer template is
    # available. It cannot reproduce reasoning controls (enable_thinking) or
    # structured tool/multimodal turns: only text content is rendered, and
    # assistant `tool_calls` / tool-role `tool_call_id` are dropped. Pass a
    # real --hf-tokenizer for model-correct formatting of those conversations.
    turns = [
        f"<|im_start|>{msg.role}\n{_content_text(msg.content)}<|im_end|>"
        for msg in messages
    ]
    # Open the assistant turn so generation continues from it.
    turns.append("<|im_start|>assistant\n")
    return "\n".join(turns)

extension/llm/server/python/serving_chat.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,16 @@ def __init__(
6161
# Detector CLASS; a fresh instance is created per request so streaming
6262
# state is never shared across concurrent requests.
6363
self._tool_detector_cls = tool_detector_cls
64-
# Special tokens (e.g. <|im_end|>) the runner decodes to text; we cut the
65-
# visible content at the first one so they don't leak into responses.
66-
self._stops = template.special_tokens()
64+
# Two distinct sets (see chat_template):
65+
# * _stops: NARROW turn terminators (e.g. <|im_end|>) used as generation
66+
# stops AND for pre-parse truncation (_options/_collect_until_stop/
67+
# _truncate_raw/_clean). Excludes structural/tool delimiters so a
68+
# <tool_call> is never halted or cut before _extract_tools sees it.
69+
# * _content_specials: BROAD all-special-tokens set, used ONLY by
70+
# _strip_specials for final cleanup of the already-parsed visible
71+
# content, so a stray special token can't leak to the user.
72+
self._stops = template.turn_stop_sequences()
73+
self._content_specials = template.special_tokens()
6774

6875
@staticmethod
6976
def _tool_schemas(req: ChatCompletionRequest) -> dict[str, dict]:
@@ -80,7 +87,9 @@ def _tool_schemas(req: ChatCompletionRequest) -> dict[str, dict]:
8087
return schemas
8188

8289
def _strip_specials(self, text: str) -> str:
    # Final scrub of already-parsed visible content: cut at the earliest
    # occurrence of ANY special token (the broad set, not the narrow
    # generation-stop set). Text without a special token is returned as is.
    cut = _earliest_stop(text, self._content_specials)
    if cut is None:
        return text
    return text[:cut]
8594

8695
@staticmethod
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""Tests for HermesDetector (Hermes/Qwen JSON <tool_call> format).
8+
9+
Covers the explicit all-or-nothing malformed-call policy and the no-markup-leak
10+
guarantee: an undefined/malformed/truncated call degrades to the leading text
11+
with the <tool_call> markup stripped, never surfaced to the client.
12+
"""
13+
14+
import json
15+
16+
from executorch.extension.llm.server.python.tool_parsers import HermesDetector
17+
18+
_TOOLS = {
19+
"get_weather": {"type": "object", "properties": {"city": {"type": "string"}}},
20+
"echo": {"type": "object", "properties": {"text": {"type": "string"}}},
21+
}
22+
23+
24+
def _parse(text, tools=_TOOLS):
    # Use a fresh detector per call so no state carries over between tests.
    detector = HermesDetector()
    return detector.detect_and_parse(text, tools)
26+
27+
28+
def test_basic_call():
    # One well-formed block -> exactly one call with decoded arguments.
    raw = (
        '<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'
    )
    result = _parse(raw)
    assert len(result.calls) == 1
    call = result.calls[0]
    assert call.name == "get_weather"
    assert json.loads(call.arguments) == {"city": "Paris"}
35+
36+
37+
def test_multiple_calls_still_parse():
    # Two adjacent well-formed blocks are both returned, in source order.
    raw = (
        '<tool_call>{"name": "echo", "arguments": {"text": "a"}}</tool_call>'
        '<tool_call>{"name": "echo", "arguments": {"text": "b"}}</tool_call>'
    )
    result = _parse(raw)
    texts = []
    for call in result.calls:
        texts.append(json.loads(call.arguments)["text"])
    assert texts == ["a", "b"]
44+
45+
46+
def test_no_tool_call_is_passthrough():
    # Plain text with no markup comes back untouched and with no calls.
    result = _parse("just some text")
    assert not result.calls
    assert result.normal_text == "just some text"
49+
50+
51+
def test_malformed_block_with_valid_sibling_degrades_no_leak():
    # All-or-nothing policy: a single malformed block degrades the WHOLE
    # response -- the valid sibling is NOT emitted on its own -- and none of
    # the <tool_call> markup may leak into the visible text.
    raw = (
        'lead<tool_call>{"name": "echo", "arguments": {"text": "ok"}}</tool_call>'
        "<tool_call>{bad json}</tool_call>"
    )
    result = _parse(raw)
    assert not result.calls
    visible = result.normal_text
    assert "<tool_call>" not in visible
    assert visible == "lead"
62+
63+
64+
def test_unclosed_marker_degrades_no_leak():
    # An opening marker that is never closed yields no calls and only the
    # leading text, with the markup stripped.
    raw = 'lead<tool_call>{"name": "echo", "arguments": {"text": "x"}}'
    result = _parse(raw)
    assert not result.calls
    visible = result.normal_text
    assert "<tool_call>" not in visible
    assert visible == "lead"
70+
71+
72+
def test_string_value_containing_close_marker_not_truncated():
    # A literal </tool_call> inside a JSON string value must not end the
    # captured JSON early (raw_decode parses the whole object regardless).
    raw = (
        '<tool_call>{"name": "echo", "arguments": '
        '{"text": "a </tool_call> b"}}</tool_call>'
    )
    result = _parse(raw)
    assert len(result.calls) == 1
    decoded = json.loads(result.calls[0].arguments)
    assert decoded == {"text": "a </tool_call> b"}
82+
83+
84+
def test_arguments_null_falls_back_to_parameters():
    # With "arguments": null, the detector takes the "parameters" object.
    raw = (
        '<tool_call>{"name": "echo", "arguments": null, '
        '"parameters": {"text": "p"}}</tool_call>'
    )
    result = _parse(raw)
    first_call = result.calls[0]
    assert json.loads(first_call.arguments) == {"text": "p"}
91+
92+
93+
def test_undefined_tool_degrades_to_full_text():
    # A WELL-FORMED call naming a tool that is not defined degrades the whole
    # response to visible text (unchanged policy: surface the model's intent,
    # never a partial set). Unlike the malformed/truncated case, the markup is
    # kept rather than stripped.
    raw = 'hi<tool_call>{"name": "nope", "arguments": {}}</tool_call>'
    result = _parse(raw)
    assert not result.calls
    assert "<tool_call>" in result.normal_text  # full text, markup visible

extension/llm/server/python/tests/test_qwen_tool_parser.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,95 @@ def test_untyped_param_falls_back_to_json_guess():
124124
)
125125
r = _parse(text, tools)
126126
assert json.loads(r.calls[0].arguments) == {"n": 42, "items": [1, 2]}
127+
128+
129+
_TYPED = {
130+
"code_tool": {"type": "object", "properties": {"code": {"type": "string"}}},
131+
"calc": {
132+
"type": "object",
133+
"properties": {
134+
"n": {"type": "integer"},
135+
"x": {"type": "number"},
136+
"flag": {"type": "boolean"},
137+
},
138+
},
139+
}
140+
141+
142+
def test_param_value_with_literal_parameter_close():
    # A literal </parameter> inside the value is kept, not treated as the end.
    raw = "<function=code_tool><parameter=code>a </parameter> b</parameter></function>"
    result = _parse(raw, _TYPED)
    decoded = json.loads(result.calls[0].arguments)
    assert decoded == {"code": "a </parameter> b"}
147+
148+
149+
def test_param_value_with_function_markup():
    # <function=...> markup inside a value stays in the value; it must not
    # split the call in two.
    raw = (
        "<function=code_tool><parameter=code>x = <function=foo></parameter></function>"
    )
    result = _parse(raw, _TYPED)
    assert len(result.calls) == 1
    decoded = json.loads(result.calls[0].arguments)
    assert decoded == {"code": "x = <function=foo>"}
157+
158+
159+
def test_declared_integer_with_float_string_kept_raw():
    # "10.0" for a declared integer is kept as the raw string.
    raw = "<function=calc><parameter=n>10.0</parameter></function>"
    decoded = json.loads(_parse(raw, _TYPED).calls[0].arguments)
    value = decoded["n"]
    assert isinstance(value, str)  # not float 10.0
    assert value == "10.0"
163+
164+
165+
def test_declared_boolean_with_one_kept_raw():
    # "1" for a declared boolean is kept as the raw string.
    raw = "<function=calc><parameter=flag>1</parameter></function>"
    decoded = json.loads(_parse(raw, _TYPED).calls[0].arguments)
    value = decoded["flag"]
    assert isinstance(value, str)  # not int 1
    assert value == "1"
169+
170+
171+
def test_declared_integer_with_underscores_kept_raw():
    # "1_000" for a declared integer is kept as the raw string.
    raw = "<function=calc><parameter=n>1_000</parameter></function>"
    decoded = json.loads(_parse(raw, _TYPED).calls[0].arguments)
    value = decoded["n"]
    assert isinstance(value, str)  # not int 1000
    assert value == "1_000"
175+
176+
177+
def _reject_bare_constant(c):
178+
# json.loads parse_constant hook: fires only for bare NaN/Infinity/-Infinity.
179+
raise AssertionError(f"emitted bare non-finite constant: {c}")
180+
181+
182+
def test_declared_number_non_finite_never_emitted():
    # Non-finite spellings for a declared number stay raw strings.
    non_finite = ("NaN", "Infinity", "-Infinity", "1e999")
    for bad in non_finite:
        raw = f"<function=calc><parameter=x>{bad}</parameter></function>"
        emitted = _parse(raw, _TYPED).calls[0].arguments
        # Strict-client safe: the emitted JSON holds no bare NaN/Infinity
        # constant (the hook raises if the decoder meets one).
        json.loads(emitted, parse_constant=_reject_bare_constant)
        decoded = json.loads(emitted)
        assert decoded["x"] == bad  # kept as the raw string
189+
190+
191+
def test_multiple_valid_calls_still_parse():
    # Two back-to-back well-formed calls are both parsed, in source order.
    raw = (
        "<function=add><parameter=a>1</parameter><parameter=b>2</parameter></function>"
        "<function=add><parameter=a>3</parameter><parameter=b>4</parameter></function>"
    )
    result = _parse(raw)
    decoded = [json.loads(call.arguments) for call in result.calls]
    expected = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]
    assert decoded == expected
201+
202+
203+
def test_truncated_call_degrades_without_leaking_markup():
    # A call cut off by max_tokens (closing tags missing) must NOT leak the
    # partial <function=...> markup; only the leading text survives, mirroring
    # the Hermes detector.
    raw = "Sure! <function=get_weather><parameter=city>Paris"
    result = _parse(raw, _TYPED)
    assert not result.calls
    visible = result.normal_text
    assert "<function=" not in visible
    assert visible == "Sure!"
211+
212+
213+
def test_truncated_tool_call_wrapper_no_leak():
    # Same truncation, inside an unclosed <tool_call> wrapper: neither marker
    # may leak into the visible text.
    raw = "ok <tool_call>\n<function=get_weather><parameter=city>Par"
    result = _parse(raw, _TYPED)
    assert not result.calls
    visible = result.normal_text
    assert "<tool_call>" not in visible and "<function=" not in visible
    assert visible == "ok"

0 commit comments

Comments
 (0)