Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion conf/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ embedder:
model_name: jinaai/jina-embeddings-v3
base_url: http://vllm:8000/v1
api_key: EMPTY
max_model_len: 8192
max_model_len: 2047 # truncate_prompt_tokens = this - 1; keep below the model's context boundary (vllm#29496)
timeout: 120.0 # per-request HTTP timeout (s); raise for slow remote endpoints
batch_size: 32 # chunks per embedding request (big docs are split into batches)
embed_concurrency: 4 # max embedding requests in flight at once
Expand Down
5 changes: 4 additions & 1 deletion openrag/core/config/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ class EmbedderConfig(ConfigMixin):
model_name: str = "jinaai/jina-embeddings-v3"
base_url: str = "http://vllm:8000/v1"
api_key: str = Field(default="EMPTY", repr=False)
max_model_len: int = 8192
# 2047 (just below the 2048 boundary): the embedder sends
# truncate_prompt_tokens = max_model_len - 1, avoiding the Qwen3-Embedding
# context-boundary hang (vllm-project/vllm#29496).
max_model_len: int = Field(default=2047, gt=0)
# Constrained > 0: a bad env var should fail at config load, not silently
# degrade into surprising runtime behavior (VLLMEmbedder rewrites a
# non-positive batch_size/embed_concurrency to 1 and would pass a <= 0
Expand Down
13 changes: 13 additions & 0 deletions openrag/di/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,24 @@ def factory(_name: str = "default"):
def _wire_named_component_factories(self, settings: Settings) -> None:
"""Wire Phase 14 named inference factories from ``settings.models``."""
models = settings.models
embed_defaults = settings.embedder

def _embedder_extra_kwargs(cfg: Any) -> dict[str, Any]:
"""Backfill max_model_len/embed_concurrency from static settings when the
endpoint's ``extra`` omits them (an explicit ``extra`` value wins)."""
defaults: dict[str, Any] = {}
if "max_model_len" not in cfg.extra:
defaults["max_model_len"] = embed_defaults.max_model_len
if "embed_concurrency" not in cfg.extra:
defaults["embed_concurrency"] = embed_defaults.embed_concurrency
return defaults

self.embedder_factory, self._embedder_cache = make_component_factory(
registry=embedder_registry,
config_section=models.embedder,
default_impl="vllm",
client_caches=self._client_caches,
extra_kwargs_fn=_embedder_extra_kwargs,
)
self.reranker_factory, self._reranker_cache = make_component_factory(
registry=reranker_registry,
Expand Down
17 changes: 16 additions & 1 deletion openrag/services/inference/vllm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,17 @@ def __init__(
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
self._client = httpx.AsyncClient(timeout=timeout, headers=headers)
logger.bind(
model=self._model,
max_model_len=self._max_model_len,
batch_size=self._batch_size,
embed_concurrency=self._embed_concurrency,
).debug("VLLMEmbedder ready")
if self._max_model_len is None:
logger.bind(model=self._model).warning(
"VLLMEmbedder built without max_model_len — truncate_prompt_tokens disabled; "
"pooling models can hang/400 on context-boundary inputs (vllm#29496)."
)

async def embed(self, texts: list[str]) -> list[list[float]]:
"""Embed *texts*, splitting large inputs into bounded-concurrent batches.
Expand Down Expand Up @@ -243,7 +254,11 @@ async def _run(batch: list[str]) -> list[list[float]]:
*tasks,
desc=f"Embedding {len(texts)} chunks ({len(batches)} batches of {self._batch_size})",
)
except BaseException:
except BaseException as exc:
done = sum(1 for task in tasks if task.done() and not task.cancelled() and task.exception() is None)
logger.bind(batches_done=done, n_batches=len(batches), error=repr(exc)).warning(
"Embedding failed after {d}/{b} batches", d=done, b=len(batches)
)
for task in tasks:
if not task.done():
task.cancel()
Expand Down
9 changes: 9 additions & 0 deletions openrag/services/workers/indexer_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,15 @@ def factory(name: str = "default") -> Any:
return entry[1]
impl_kwargs = {key: value for key, value in model_cfg.extra.items() if key != "implementation"}
impl = model_cfg.extra.get("implementation", "vllm")
# Backfill max_model_len/embed_concurrency from static settings when the
# endpoint's `extra` omits them — otherwise truncate_prompt_tokens is off
# and pooling models hang/400 on boundary inputs (vllm#29496). Explicit
# `extra` wins. Mirrors the API container's _embedder_extra_kwargs.
embed_defaults = getattr(cfg, "embedder", None)
for default_key in ("max_model_len", "embed_concurrency"):
default = getattr(embed_defaults, default_key, None)
if default is not None:
impl_kwargs.setdefault(default_key, default)
instance = embedder_registry.create(
impl,
endpoint=model_cfg.endpoint,
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/di/test_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,13 +548,17 @@ def test_container_wires_named_factories_from_models_config(self):
llm = c.llm_factory("chat-a")
vlm = c.vlm_factory("vision-a")

# max_model_len / embed_concurrency aren't carried by the endpoint config,
# so the embedder factory backfills them from the static ``embedder`` settings.
assert embedder.kwargs == {
"endpoint": "http://embedder:8000/v1",
"model_name": "embed-model",
"batch_size": 16,
"timeout": 12.5,
"api_key": "embed-key",
"dimension": 384,
"max_model_len": 2047,
"embed_concurrency": 4,
}
assert reranker.kwargs["endpoint"] == "http://reranker:8000"
assert reranker.kwargs["model_name"] == "rank-model"
Expand All @@ -564,6 +568,14 @@ def test_container_wires_named_factories_from_models_config(self):
assert vlm.kwargs["endpoint"] == "http://vlm:8000/v1"
assert vlm.kwargs["max_tokens"] == 256

def test_embedder_extra_overrides_backfilled_max_model_len(self):
"""A per-endpoint ``extra.max_model_len`` wins over the settings default."""
settings = _settings_with_named_models()
settings.models.embedder["embed-a"].extra["max_model_len"] = 4096
c = ServiceContainer(settings)

assert c.embedder_factory("embed-a").kwargs["max_model_len"] == 4096

def test_named_factories_cache_by_endpoint_name(self):
"""Repeated factory calls for the same endpoint return one client."""
c = ServiceContainer(_settings_with_named_models())
Expand Down
38 changes: 38 additions & 0 deletions tests/unit/services/workers/test_indexer_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,44 @@ def __init__(self, **kwargs):
embedder_registry._registry.pop("live-probe-embedder", None)


def test_embedder_factory_backfills_max_model_len_from_settings() -> None:
"""A named endpoint whose `extra` omits max_model_len inherits it from the
static embedder settings (an explicit per-endpoint value still wins)."""
from core.embeddings import embedder_registry
from services.workers.indexer_pool import _build_embedder_factory

class ProbeEmbedder:
def __init__(self, **kwargs):
self.kwargs = kwargs

embedder_registry.register("backfill-probe-embedder")(ProbeEmbedder)
try:
registry: dict = {
"no-extra": ModelEndpointConfig(
endpoint="http://embed.example/v1",
model_name="embed-model",
extra={"implementation": "backfill-probe-embedder", "api_key": "k"},
),
"explicit": ModelEndpointConfig(
endpoint="http://embed.example/v1",
model_name="embed-model",
extra={"implementation": "backfill-probe-embedder", "max_model_len": 4096},
),
}
cfg = SimpleNamespace(
models=SimpleNamespace(embedder=registry),
embedder=SimpleNamespace(max_model_len=2047, embed_concurrency=4),
)

factory = _build_embedder_factory(cfg)
backfilled = factory("no-extra")
assert backfilled.kwargs["max_model_len"] == 2047
assert backfilled.kwargs["embed_concurrency"] == 4
assert factory("explicit").kwargs["max_model_len"] == 4096 # per-endpoint extra wins
finally:
embedder_registry._registry.pop("backfill-probe-embedder", None)


def test_embedder_factory_rebuilds_on_api_key_rotation() -> None:
from core.embeddings import embedder_registry
from services.workers.indexer_pool import _build_embedder_factory
Expand Down
Loading