Commits (27)
- 68f0201 lmi-native configuration, fallbacks, retries (sidnarayanan, Apr 17, 2026)
- 51d3bb9 clean up docstrings (sidnarayanan, Apr 17, 2026)
- 5804864 drop router (sidnarayanan, Apr 17, 2026)
- 9689698 add dispatch/retry tests (sidnarayanan, Apr 17, 2026)
- 0c93218 update refusal fallback test (sidnarayanan, Apr 17, 2026)
- 2715531 update more tests (sidnarayanan, Apr 17, 2026)
- 232e1cb make BadRequestError fallbackable (sidnarayanan, Apr 17, 2026)
- d7bf950 agents accept LLMConfig (sidnarayanan, Apr 17, 2026)
- edbe8e5 update cassettes (sidnarayanan, Apr 17, 2026)
- be876fd drop deepseek tests (sidnarayanan, Apr 17, 2026)
- 4609fc2 regen cassettes (sidnarayanan, Apr 18, 2026)
- 02b829e ModelSpec.responses_api (sidnarayanan, Apr 18, 2026)
- ca77c21 update cassettes (sidnarayanan, Apr 18, 2026)
- f57642b fix anthropic n>1 test (sidnarayanan, Apr 18, 2026)
- 8de5cf6 rename llm_model->llm_config in tests (sidnarayanan, Apr 18, 2026)
- 7fcc321 regen another cassette (sidnarayanan, Apr 18, 2026)
- 6a12634 fix refurb (sidnarayanan, Apr 19, 2026)
- c993938 fix refurb (sidnarayanan, Apr 19, 2026)
- 9547c20 bump to litellm==1.83.0 (sidnarayanan, Apr 19, 2026)
- 741d709 uv lock (sidnarayanan, Apr 19, 2026)
- cd06f99 push response validation down into lmi (sidnarayanan, Apr 19, 2026)
- 9834861 drop max_tokens=4096 default (sidnarayanan, Apr 19, 2026)
- 6d8d567 fix refurb (sidnarayanan, Apr 20, 2026)
- b0f2b54 add retry logging (sidnarayanan, Apr 20, 2026)
- 223a249 verbose logging (sidnarayanan, Apr 20, 2026)
- 3f87402 pre-dispatch log (sidnarayanan, Apr 20, 2026)
- f3fb059 update litellm pin (sidnarayanan, Apr 21, 2026)
2 changes: 1 addition & 1 deletion packages/lmi/pyproject.toml
@@ -27,7 +27,7 @@ dependencies = [
"coredis>=3.0.1", # Lower pin v3 for Redis class, pin v3.0.1 for fix in https://github.com/alisaifee/coredis/commit/7e9b1a1b384cd97725cab479d2ce091e3b0823d2
"fhaviary>=0.14.0", # For multi-image support
"limits[async-redis]>=4.8", # Specify 'async-redis' since that's what we use. Lower pin for RedisBridge.key_prefix.
"litellm>=1.81.10,<=1.82.6", # Lower pin for MAX_CALLBACKS refactor from https://github.com/BerriAI/litellm/pull/20781, upper pin for supply chain attack
"litellm>=1.83.10", # Lower pin for MAX_CALLBACKS refactor from https://github.com/BerriAI/litellm/pull/20781, upper pin for supply chain attack
Collaborator:
I think rebase atop main as this has been changed a bit

"openai>=2", # Pin to keep recent
"orjson", # Required by litellm for Responses API
"pydantic~=2.0,>=2.10.1",
279 changes: 279 additions & 0 deletions packages/lmi/src/lmi/config.py
@@ -0,0 +1,279 @@
"""Configuration types for LMI.

An `LLMConfig` is an ordered chain of `ModelSpec` entries. `models[0]` is the
primary model; `models[1:]` are fallbacks tried in order when the primary
fails in ways that another model might handle.

`LLMConfig.from_legacy_dict` accepts the dict-shaped configuration
(`{model_list, fallbacks, router_kwargs}`) that mirrors litellm's Router layout.
"""

from __future__ import annotations

from collections.abc import Awaitable, Callable
from typing import Annotated, Any

import litellm
from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, SecretStr

from lmi.constants import DEFAULT_VERTEX_SAFETY_SETTINGS
from lmi.types import LLMResult

ResponseValidator = Callable[[LLMResult], Awaitable[None] | None]

_DEFAULT_TEMPERATURE = 1.0
_OPENAI_ONLY_PARAMS = frozenset({"logprobs", "top_logprobs"})

# Per-call retry kwargs that LiteLLM honors via its own internal retry loop. LMI
# owns retries through `_run_with_fallbacks` + `ModelSpec.max_retries`, so these
# must never reach `litellm.acompletion`/`litellm.aresponses` regardless of how
# they ended up in `ModelSpec.extra_params`.
Collaborator:
Suggested change:
- # they ended up in `ModelSpec.extra_params`.
+ # they ended up in our `ModelSpec.extra_params`.

Felt we should clarify that ModelSpec is our thing

_LITELLM_RETRY_KWARGS = frozenset({"num_retries", "max_retries"})


class ModelSpec(BaseModel):
"""One model in an `LLMConfig` chain."""

model_config = ConfigDict(extra="forbid")

name: str = Field(
description=(
"LiteLLM model string, e.g. 'gpt-4o-mini' or 'claude-3-5-sonnet-20241022'."
Collaborator:
But don't LiteLLM model strings start with a provider prefix, like openai/gpt-4o-mini? Maybe clarify this.
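A quick illustration of the point (not part of this PR's diff): litellm accepts both bare and provider-prefixed model strings, and `litellm.get_llm_provider` resolves the provider either way. The return values shown are indicative and depend on the installed litellm version.

```python
import litellm

# Bare name: litellm infers the provider from its model map.
model, provider, _api_key, _api_base = litellm.get_llm_provider("gpt-4o-mini")
# provider == "openai", model == "gpt-4o-mini"

# Prefixed name: the provider is explicit and the prefix is stripped off.
model, provider, _, _ = litellm.get_llm_provider("openai/gpt-4o-mini")
# provider == "openai", model == "gpt-4o-mini"
```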

),
)
api_base: str | None = None
api_key: SecretStr | None = None
timeout: float = Field(default=60.0, description="Per-request timeout in seconds.")
max_retries: int = Field(
default=3,
description=(
"Retries against this model before falling over to the next entry"
" in the chain."
),
)
extra_params: dict[str, Any] = Field(
default_factory=dict,
description=(
"Pass-through kwargs for litellm.acompletion / litellm.aresponses,"
" e.g. temperature, max_tokens, safety_settings, vertex_project."
),
)
responses_api: bool = Field(
default=False,
description=(
"If True, dispatch this model via OpenAI's stateful Responses API"
" (`litellm.aresponses`) instead of the Chat Completions API"
" (`litellm.acompletion`)."
),
)

@classmethod
def from_name(cls, name: str, **overrides: Any) -> ModelSpec:
"""Build a `ModelSpec` with provider-aware defaults for `extra_params`.

Applies: Gemini default safety settings; `temperature` / `max_tokens`
defaults; and silent drop of `logprobs` / `top_logprobs` for non-OpenAI
providers (which don't support them). Explicit values in `overrides`
always win over the defaults.
Comment on lines +74 to +77
Copilot AI (Apr 18, 2026):
ModelSpec.from_name() docstring says it will “silently drop” logprobs / top_logprobs for non-OpenAI providers, but the implementation raises a ValueError instead. Please align the docstring with the actual behavior (or change the behavior) so callers know whether to expect coercion or an exception.

Suggested change:
- Applies: Gemini default safety settings; `temperature` / `max_tokens`
- defaults; and silent drop of `logprobs` / `top_logprobs` for non-OpenAI
- providers (which don't support them). Explicit values in `overrides`
- always win over the defaults.
+ Applies: Gemini default safety settings and `temperature` /
+ `max_tokens` defaults. `logprobs` and `top_logprobs` are treated as
+ OpenAI-only parameters: passing them for non-OpenAI providers raises
+ `ValueError`. Explicit values in `overrides` always win over the
+ defaults.


`overrides` may set any `ModelSpec` field, plus request-shape kwargs
(`temperature`, `max_tokens`, `n`, `logprobs`, `top_logprobs`,
`safety_settings`, ...) which are merged into `extra_params`.
"""
is_openai = _is_openai_provider(name)
Collaborator:
Consider using litellm for this instead of a DIY'd helper; it can detect this using litellm.get_llm_provider.


extra: dict[str, Any] = {}
if "gemini" in name:
extra["safety_settings"] = DEFAULT_VERTEX_SAFETY_SETTINGS
extra["temperature"] = _DEFAULT_TEMPERATURE

spec_field_overrides: dict[str, Any] = {}
extra_overrides: dict[str, Any] = {}
for key, value in overrides.items():
if key in cls.model_fields:
spec_field_overrides[key] = value
elif key in _OPENAI_ONLY_PARAMS and not is_openai:
Collaborator:
Gemini supports these too, can you have Gemini work with these too?
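One way to act on this (a sketch only, not code from this PR): drive the check off litellm's capability metadata rather than a hard-coded OpenAI-only set. `litellm.get_supported_openai_params` is an existing litellm helper; whether it reports `logprobs` for a given Gemini model depends on the litellm version.

```python
from litellm import get_supported_openai_params


def _supports_param(model_name: str, param: str) -> bool:
    """Ask litellm whether this model/provider accepts an OpenAI-style param."""
    supported = get_supported_openai_params(model=model_name) or []
    return param in supported


# In from_name(), the hard-coded branch could then become:
#     elif key in _OPENAI_ONLY_PARAMS and not _supports_param(name, key):
#         raise ValueError(...)
```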

raise ValueError(
f"{key!r} is only supported on OpenAI models; got {name!r}."
)
else:
extra_overrides[key] = value

spec_field_overrides.setdefault("name", name)
merged_extra = (
extra | extra_overrides | spec_field_overrides.pop("extra_params", {})
)
return cls(extra_params=merged_extra, **spec_field_overrides)
Comment on lines +90 to +106
Collaborator:
Should we do this validation in a field/model_validator, so it applies to all ModelSpec, not just this classmethod?
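A sketch of what that could look like (hypothetical, not in the PR): a Pydantic `model_validator` on `ModelSpec`, so every construction path, not just `from_name`, rejects unsupported params.

```python
from pydantic import model_validator


class ModelSpec(BaseModel):
    ...  # fields as in the diff above

    @model_validator(mode="after")
    def _check_openai_only_params(self) -> "ModelSpec":
        # Reject logprobs/top_logprobs for providers that don't accept them,
        # regardless of how the spec was constructed.
        if not _is_openai_provider(self.name):
            unsupported = _OPENAI_ONLY_PARAMS & self.extra_params.keys()
            if unsupported:
                raise ValueError(
                    f"{sorted(unsupported)} only supported on OpenAI models;"
                    f" got {self.name!r}."
                )
        return self
```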


def to_litellm_kwargs(self) -> dict[str, Any]:
"""Flatten into kwargs for `litellm.acompletion` / `litellm.aresponses`."""
sanitized_extra = {
k: v for k, v in self.extra_params.items() if k not in _LITELLM_RETRY_KWARGS
}
out: dict[str, Any] = {
"model": self.name,
"timeout": self.timeout,
} | sanitized_extra
if self.api_base is not None:
out["api_base"] = self.api_base
if self.api_key is not None:
out["api_key"] = self.api_key.get_secret_value()
return out


class LLMConfig(BaseModel):
"""Ordered model chain: `models[0]` is primary, `models[1:]` are fallbacks."""

model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

models: list[ModelSpec] = Field(
min_length=1,
description=(
"Ordered list of models. The first entry is the primary; subsequent"
" entries are tried in order when earlier models fail in ways that"
" another model might handle (context overflow, content policy,"
" model-unavailable, or exhausted retries)."
),
)
response_validator: ResponseValidator | None = Field(
default=None,
exclude=True,
description=(
"Optional callable invoked on each successful `LLMResult`. Raises"
" any exception to reject the response; we convert that into"
" `ResponseValidationError` and let the retry/fallback loop"
" handle it."
),
)

@classmethod
def coerce(cls, v: Any) -> LLMConfig:
"""Accept an `LLMConfig` or any dict shape LMI knows about.

Supported inputs:

- an `LLMConfig` instance (passes through)
- a dict with `"models"` — the typed-dict form of `LLMConfig`
- a dict with `"model_list"` — the legacy litellm-Router shape; see
`from_legacy_dict`
- a dict with `"name"` — a bare model name plus flat request-shape
kwargs (e.g. `temperature`, `max_tokens`); built via
`ModelSpec.from_name`
"""
if isinstance(v, cls):
return v
if not isinstance(v, dict):
raise TypeError(f"Cannot build an LLMConfig from {type(v).__name__}")
if "models" in v:
return cls.model_validate(v)
if "model_list" in v:
return cls.from_legacy_dict(v)
if "name" in v:
kwargs = dict(v)
name = kwargs.pop("name")
return cls(models=[ModelSpec.from_name(name, **kwargs)])
raise ValueError(
"Can't infer LLMConfig shape from dict; expected 'models',"
" 'model_list', or 'name' key"
)

def with_extra_params(self, **params: Any) -> LLMConfig:
"""Return a copy where every `ModelSpec.extra_params` has `params` merged in.

Useful for chain-wide request-shape additions like stop sequences: the
caller doesn't have to rebuild each spec individually, and the original
`LLMConfig` is left untouched.
"""
return self.model_copy(
update={
"models": [
m.model_copy(update={"extra_params": m.extra_params | params})
for m in self.models
]
}
)

@classmethod
def from_legacy_dict(cls, legacy: dict[str, Any]) -> LLMConfig:
"""Build an `LLMConfig` from the legacy dict-shaped configuration.

The legacy shape is `{model_list: [{model_name, litellm_params}, ...],
fallbacks: [{primary_name: [fallback_names, ...]}, ...], router_kwargs: {...}}`.
The `fallbacks` list is flattened into the ordering of `models`; any
entries in `model_list` not reached by the primary's fallback chain are
appended at the end.
"""
model_list = legacy.get("model_list") or []
if not model_list:
raise ValueError("Legacy config has empty or missing model_list")

fallback_map: dict[str, list[str]] = {}
for entry in legacy.get("fallbacks") or []:
fallback_map.update(entry)

params_by_name: dict[str, dict[str, Any]] = {
m["model_name"]: dict(m.get("litellm_params", {})) for m in model_list
}

primary_name = model_list[0]["model_name"]
ordered: list[str] = [primary_name, *fallback_map.get(primary_name, [])]
for name in params_by_name:
if name not in ordered:
ordered.append(name)

router_kwargs = legacy.get("router_kwargs") or {}
default_timeout = router_kwargs.get("timeout", 60.0)
default_retries = router_kwargs.get("num_retries", 3)
Comment on lines +225 to +226
Collaborator:
Any chance we can not inject the `60.0` and `3` defaults here? The less injection we have, the better.
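One possible shape for this (a sketch, not the PR's code): fall back to `ModelSpec`'s own declared field defaults instead of re-stating the literals.

```python
router_kwargs = legacy.get("router_kwargs") or {}
# Reuse ModelSpec's declared defaults rather than injecting 60.0 / 3 again here.
default_timeout = router_kwargs.get(
    "timeout", ModelSpec.model_fields["timeout"].default
)
default_retries = router_kwargs.get(
    "num_retries", ModelSpec.model_fields["max_retries"].default
)
```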


return cls(
models=[
_spec_from_legacy_params(
params_by_name.get(name, {}),
default_timeout=default_timeout,
default_retries=default_retries,
)
for name in ordered
]
Comment on lines +214 to +236
Copilot AI (Apr 18, 2026):
LLMConfig.from_legacy_dict() can pass an empty dict into _spec_from_legacy_params() when fallbacks references a model_name that isn’t present in model_list, which will then fail with a KeyError on params['model']. Consider validating that all referenced fallback model names exist (and raising a clear ValueError listing the missing names) before constructing the ordered chain.
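A sketch of the suggested guard, placed before the specs are built; names reuse the locals from the diff above.

```python
# Fail fast if `fallbacks` references a model_name that model_list doesn't define.
missing = [name for name in ordered if name not in params_by_name]
if missing:
    raise ValueError(
        f"fallbacks reference model names missing from model_list: {missing}"
    )
```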

)


def _is_openai_provider(name: str) -> bool:
try:
return "openai" in litellm.get_llm_provider(name)
except litellm.BadRequestError:
return False


_RESERVED_LEGACY_PARAMS = frozenset({
"model",
"api_base",
"api_key",
"timeout",
"max_retries",
})


def _spec_from_legacy_params(
Collaborator:
Can we make this a classmethod of ModelSpec? Then we can get Self as the return type
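Roughly what that could look like (a sketch; `typing.Self` needs Python 3.11+, otherwise use `typing_extensions.Self`):

```python
from typing import Self


class ModelSpec(BaseModel):
    ...

    @classmethod
    def _from_legacy_params(
        cls,
        params: dict[str, Any],
        *,
        default_timeout: float,
        default_retries: int,
    ) -> Self:
        # Same body as the module-level helper, just typed against the class.
        ...
```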

params: dict[str, Any],
*,
default_timeout: float,
default_retries: int,
) -> ModelSpec:
api_key = params.get("api_key")
return ModelSpec(
name=params["model"],
api_base=params.get("api_base"),
api_key=SecretStr(api_key) if api_key is not None else None,
timeout=params.get("timeout", default_timeout),
max_retries=params.get("max_retries", default_retries),
extra_params={
k: v for k, v in params.items() if k not in _RESERVED_LEGACY_PARAMS
},
)


# Pydantic field annotation that accepts any input `LLMConfig.coerce` supports.
# Use in place of a bare `LLMConfig` when you want the model to accept both
# typed instances and the dict shapes LMI recognises, without writing a
# `@field_validator` on every class.
LLMConfigField = Annotated[LLMConfig, BeforeValidator(LLMConfig.coerce)]
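A small usage sketch tying the pieces above together. The import path follows the package layout shown in this diff; model names and request kwargs are illustrative only.

```python
from lmi.config import LLMConfig, ModelSpec

# Explicit chain: primary model plus one fallback.
config = LLMConfig(
    models=[
        ModelSpec.from_name("gpt-4o-mini", temperature=0.2, max_tokens=512),
        ModelSpec.from_name("claude-3-5-sonnet-20241022", timeout=120.0),
    ]
)

# The same chain, built from the legacy litellm-Router-shaped dict.
legacy = {
    "model_list": [
        {"model_name": "gpt-4o-mini", "litellm_params": {"model": "gpt-4o-mini"}},
        {
            "model_name": "claude-3-5-sonnet-20241022",
            "litellm_params": {"model": "claude-3-5-sonnet-20241022"},
        },
    ],
    "fallbacks": [{"gpt-4o-mini": ["claude-3-5-sonnet-20241022"]}],
}
config_from_legacy = LLMConfig.coerce(legacy)

# Chain-wide request-shape additions, e.g. stop sequences, without rebuilding specs.
config_with_stop = config.with_extra_params(stop=["\nObservation:"])
```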
3 changes: 0 additions & 3 deletions packages/lmi/src/lmi/constants.py
@@ -1,8 +1,5 @@
-import os
from sys import version_info

-USE_RESPONSES_API = os.environ.get("USE_RESPONSES_API", "").lower() in {"1", "true"}
Collaborator:
Nice this is awesome


# Estimate from OpenAI's FAQ
# https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
CHARACTERS_PER_TOKEN_ASSUMPTION: float = 4.0
44 changes: 44 additions & 0 deletions packages/lmi/src/lmi/exceptions.py
@@ -1,2 +1,46 @@
from typing import Any


class JSONSchemaValidationError(ValueError):
"""Raised when the completion does not match the specified schema."""


class ModelRefusalError(RuntimeError):
"""Raised when an LLM declines to complete a request (e.g. content filter).

Carries the raw provider response so callers that choose to handle the
refusal (rather than fall back) can still inspect it.
"""

def __init__(
self,
message: str,
*,
model: str,
finish_reason: str | None,
response: Any = None,
) -> None:
super().__init__(message)
self.model = model
self.finish_reason = finish_reason
self.response = response


class ResponseValidationError(RuntimeError):
"""Raised when an `LLMConfig.response_validator` rejects an `LLMResult`.

Treated as transient by the retry/fallback loop so the validator gets a
fresh attempt at the same model (up to `ModelSpec.max_retries`) before
advancing to the next model.
"""


class AllModelsExhaustedError(RuntimeError):
"""Raised when every model in an `LLMConfig.models` chain has failed or been skipped."""

def __init__(self, last_exc: BaseException | None = None) -> None:
super().__init__(
"All models in the LLMConfig chain failed."
+ (f" Last error: {last_exc!r}" if last_exc is not None else "")
)
self.last_exc = last_exc