Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,35 +1205,41 @@ async def _probe_provider_key(provider: str, key: str) -> dict:

started = time.perf_counter()
try:
timeout = httpx.Timeout(5.0)
timeout = httpx.Timeout(10.0, connect=5.0)
async with httpx.AsyncClient(timeout=timeout) as client:
if provider == "anthropic":
r = await client.post(
"https://api.anthropic.com/v1/messages",
r = await client.get(
"https://api.anthropic.com/v1/models",
headers={
"x-api-key": key,
"anthropic-version": "2023-06-01",
"content-type": "application/json",
},
json={
"model": "claude-haiku-4-5-20251001",
"max_tokens": 1,
"messages": [{"role": "user", "content": "ping"}],
},
)
status = "ok" if r.status_code in {200, 400} else "invalid_key" if r.status_code == 401 else "unreachable"
status = "ok" if r.status_code == 200 else "invalid_key" if r.status_code in {401, 403} else "unreachable"
elif provider == "openai":
r = await client.get(
"https://api.openai.com/v1/models",
headers={"Authorization": f"Bearer {key}"},
)
status = "ok" if r.status_code == 200 else "invalid_key" if r.status_code == 401 else "unreachable"
status = "ok" if r.status_code == 200 else "invalid_key" if r.status_code in {401, 403} else "unreachable"
elif provider == "groq":
r = await client.get(
"https://api.groq.com/openai/v1/models",
headers={"Authorization": f"Bearer {key}"},
)
status = "ok" if r.status_code == 200 else "invalid_key" if r.status_code == 401 else "unreachable"
status = "ok" if r.status_code == 200 else "invalid_key" if r.status_code in {401, 403} else "unreachable"
elif provider == "deepseek":
r = await client.get(
"https://api.deepseek.com/v1/models",
headers={"Authorization": f"Bearer {key}"},
)
status = "ok" if r.status_code == 200 else "invalid_key" if r.status_code in {401, 403} else "unreachable"
elif provider == "nvidia":
r = await client.get(
"https://integrate.api.nvidia.com/v1/models",
headers={"Authorization": f"Bearer {key}"},
)
status = "ok" if r.status_code == 200 else "invalid_key" if r.status_code in {401, 403} else "unreachable"
else:
status = "unchecked"
except Exception:
Expand All @@ -1247,14 +1253,15 @@ async def validate_settings():
from llm import _ENV_NAMES, _KEY_NAMES

cfg = get_settings()
probeable = {"anthropic", "openai", "groq", "deepseek", "nvidia"}
providers = ["anthropic", "openai", "groq", *[p for p in _KEY_NAMES if p not in {"anthropic", "openai", "groq"}]]

async def one(provider: str):
key_name = _KEY_NAMES.get(provider, "")
key = str(cfg.get(key_name) or os.environ.get(_ENV_NAMES.get(provider, ""), "") or "").strip()
if not key:
return provider, {"status": "not_configured", "latency_ms": 0}
if provider not in {"anthropic", "openai", "groq"}:
if provider not in probeable:
return provider, {"status": "unchecked", "latency_ms": 0}
return provider, await _probe_provider_key(provider, key)

Expand Down
77 changes: 77 additions & 0 deletions backend/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,83 @@ def test_validate_returns_dict(self):
self.assertIsInstance(resp.json(), dict)


class TestProviderKeyProbe(unittest.IsolatedAsyncioTestCase):
    """Exercise ``_probe_provider_key`` with ``httpx.AsyncClient`` stubbed out.

    The tests check which URL and headers the probe sends for each provider
    and how an HTTP status code is mapped onto the ``status`` field of the
    returned result. Any POST issued through the stub fails the test.
    """

    async def _probe(self, provider, status_code):
        """Run the probe for *provider*, answering every GET with *status_code*.

        Returns a ``(result, captured)`` pair: ``result`` is whatever
        ``_probe_provider_key`` returned, and ``captured`` holds the
        ``method``, ``url`` and ``headers`` of the GET the stub received.
        """
        from main import _probe_provider_key

        request_log = {}

        class _StubResponse:
            """Minimal response object exposing only ``status_code``."""

            def __init__(self, code):
                self.status_code = code

        class _StubClient:
            """Async-context-manager stand-in for ``httpx.AsyncClient``."""

            def __init__(self, *_args, **_kwargs):
                pass

            async def __aenter__(self):
                return self

            async def __aexit__(self, *_exc_info):
                # Falsy return: exceptions raised inside the block propagate.
                return False

            async def get(self, url, headers=None):
                # Record the request so the tests can inspect it afterwards.
                request_log.update(
                    method="GET",
                    url=url,
                    headers=dict(headers or {}),
                )
                return _StubResponse(status_code)

            async def post(self, *_args, **_kwargs):
                raise AssertionError("probe must not POST — that costs tokens")

        with mock.patch("httpx.AsyncClient", _StubClient):
            outcome = await _probe_provider_key(provider, "k-test-123")
        return outcome, request_log

    async def test_anthropic_uses_models_endpoint_with_x_api_key(self):
        outcome, request = await self._probe("anthropic", 200)
        self.assertEqual(outcome["status"], "ok")
        self.assertEqual(request["url"], "https://api.anthropic.com/v1/models")
        sent_headers = request["headers"]
        self.assertEqual(sent_headers.get("x-api-key"), "k-test-123")
        self.assertIn("anthropic-version", sent_headers)

    async def test_openai_uses_models_endpoint_with_bearer(self):
        outcome, request = await self._probe("openai", 200)
        self.assertEqual(outcome["status"], "ok")
        self.assertEqual(request["url"], "https://api.openai.com/v1/models")
        sent_headers = request["headers"]
        self.assertEqual(sent_headers.get("Authorization"), "Bearer k-test-123")

    async def test_groq_uses_openai_compat_models_endpoint(self):
        outcome, request = await self._probe("groq", 200)
        self.assertEqual(outcome["status"], "ok")
        self.assertEqual(request["url"], "https://api.groq.com/openai/v1/models")

    async def test_deepseek_probe(self):
        outcome, request = await self._probe("deepseek", 200)
        self.assertEqual(outcome["status"], "ok")
        self.assertEqual(request["url"], "https://api.deepseek.com/v1/models")

    async def test_nvidia_probe(self):
        outcome, request = await self._probe("nvidia", 200)
        self.assertEqual(outcome["status"], "ok")
        self.assertEqual(
            request["url"], "https://integrate.api.nvidia.com/v1/models"
        )

    async def test_401_maps_to_invalid_key(self):
        outcome, _request = await self._probe("anthropic", 401)
        self.assertEqual(outcome["status"], "invalid_key")

    async def test_403_maps_to_invalid_key(self):
        outcome, _request = await self._probe("openai", 403)
        self.assertEqual(outcome["status"], "invalid_key")

    async def test_5xx_maps_to_unreachable(self):
        outcome, _request = await self._probe("groq", 503)
        self.assertEqual(outcome["status"], "unreachable")


class TestFollowupsEndpoint(unittest.TestCase):
def test_due_followups_returns_list(self):
resp = get("/api/v1/followups/due")
Expand Down