Skip to content

Commit 8dee99d

Browse files
author
Mateusz
committed
Fix constant-time auth token validation
1 parent d98bce0 commit 8dee99d

2 files changed

Lines changed: 90 additions & 21 deletions

File tree

src/core/security/middleware.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import json
77
import logging
88
import math
9+
import secrets
910
import time
1011
from collections.abc import Awaitable, Callable
1112
from dataclasses import dataclass
@@ -28,6 +29,29 @@
2829
logger = logging.getLogger(__name__)
2930

3031

32+
def _constant_time_equals(candidate: str | None, expected: str) -> bool:
33+
"""Compare secrets without leaking matching-prefix timing."""
34+
35+
return bool(
36+
candidate
37+
and isinstance(expected, str)
38+
and secrets.compare_digest(candidate, expected)
39+
)
40+
41+
42+
def _constant_time_member(candidate: str | None, expected_values: set[str]) -> bool:
43+
"""Check membership using constant-time comparisons for each expected value."""
44+
45+
if not candidate:
46+
return False
47+
48+
matched = False
49+
for expected in expected_values:
50+
if isinstance(expected, str) and secrets.compare_digest(candidate, expected):
51+
matched = True
52+
return matched
53+
54+
3155
@dataclass
3256
class _BruteForceRecord:
3357
"""Track failed attempts and blocking metadata for a client IP."""
@@ -219,7 +243,7 @@ async def dispatch(
219243
all_valid_keys: set[str] = self.valid_keys | app_state_keys
220244

221245
method = request.method
222-
if not api_key or api_key not in all_valid_keys:
246+
if not _constant_time_member(api_key, all_valid_keys):
223247
logger.warning(
224248
"Invalid or missing API key for %s %s from client %s",
225249
method,
@@ -525,7 +549,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
525549
f"API Key authentication is enabled key_count={len(all_valid_keys)}"
526550
)
527551
method = scope.get("method", "UNKNOWN")
528-
if not api_key or api_key not in all_valid_keys:
552+
if not _constant_time_member(api_key, all_valid_keys):
529553
logger.warning(
530554
"Invalid or missing API key for %s %s from client %s",
531555
method,
@@ -659,7 +683,7 @@ async def dispatch(
659683
method = request.method
660684

661685
# Validate the token
662-
if not token or token != self.valid_token:
686+
if not _constant_time_equals(token, self.valid_token):
663687
logger.warning(
664688
"Invalid or missing auth token for %s %s from client %s",
665689
method,
@@ -763,7 +787,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
763787
method = scope.get("method", "UNKNOWN")
764788

765789
# Validate the token
766-
if not token or token != self.valid_token:
790+
if not _constant_time_equals(token, self.valid_token):
767791
logger.warning(
768792
"Invalid or missing auth token for %s %s from client %s",
769793
method,

tests/unit/core/test_authentication_di.py

Lines changed: 62 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
refactored to use proper dependency injection instead of direct app.state access.
66
"""
77

8-
import json
9-
import os
10-
from unittest.mock import AsyncMock, MagicMock, patch
8+
import json
9+
import os
10+
from typing import Any, cast
11+
from unittest.mock import AsyncMock, MagicMock, patch
1112

1213
import pytest
1314
from fastapi import FastAPI, Request, Response
@@ -134,8 +135,8 @@ async def test_valid_query_key(self, api_key_middleware, mock_request):
134135
call_next.assert_called_once_with(mock_request)
135136
assert response == "next_response"
136137

137-
async def test_invalid_key(self, api_key_middleware, mock_request):
138-
"""Test that an invalid API key is rejected."""
138+
async def test_invalid_key(self, api_key_middleware, mock_request):
139+
"""Test that an invalid API key is rejected."""
139140
# Setup
140141
mock_request.headers = {"Authorization": "Bearer invalid-key"}
141142
call_next = AsyncMock(return_value="next_response")
@@ -146,9 +147,31 @@ async def test_invalid_key(self, api_key_middleware, mock_request):
146147
# Verify
147148
call_next.assert_not_called()
148149
assert response.status_code == 401
149-
assert (
150-
response.body == f'{{"detail":"{HTTP_401_UNAUTHORIZED_MESSAGE}"}}'.encode()
151-
)
150+
assert (
151+
response.body == f'{{"detail":"{HTTP_401_UNAUTHORIZED_MESSAGE}"}}'.encode()
152+
)
153+
154+
async def test_valid_bearer_key_uses_constant_time_comparison(
155+
self, api_key_middleware, mock_request, monkeypatch
156+
):
157+
"""API key validation should avoid short-circuit string membership."""
158+
159+
compared_values: list[tuple[str, str]] = []
160+
161+
def compare_digest(candidate: str, expected: str) -> bool:
162+
compared_values.append((candidate, expected))
163+
return candidate == expected
164+
165+
monkeypatch.setattr(
166+
"src.core.security.middleware.secrets.compare_digest", compare_digest
167+
)
168+
mock_request.headers = {"Authorization": "Bearer test-key"}
169+
call_next = AsyncMock(return_value="next_response")
170+
171+
response = await api_key_middleware.dispatch(mock_request, call_next)
172+
173+
assert response == "next_response"
174+
assert ("test-key", "test-key") in compared_values
152175

153176
async def test_missing_key(self, api_key_middleware, mock_request):
154177
"""Test that a missing API key is rejected."""
@@ -278,7 +301,7 @@ async def _attempt(header_value: str) -> Response:
278301
blocked = await _attempt("bad-3")
279302
assert blocked.status_code == 429
280303
assert blocked.headers.get("Retry-After") == "10"
281-
payload = json.loads(blocked.body.decode())
304+
payload = json.loads(bytes(blocked.body).decode())
282305
assert payload["retry_after_seconds"] == 10
283306

284307
# After the wait expires, another invalid attempt is allowed
@@ -289,7 +312,7 @@ async def _attempt(header_value: str) -> Response:
289312
blocked_again = await _attempt("bad-5")
290313
assert blocked_again.status_code == 429
291314
assert blocked_again.headers.get("Retry-After") == "20"
292-
payload = json.loads(blocked_again.body.decode())
315+
payload = json.loads(bytes(blocked_again.body).decode())
293316
assert payload["retry_after_seconds"] == 20
294317

295318
# Provide a valid key to reset the tracker
@@ -340,7 +363,7 @@ def advance(self, seconds: float) -> None:
340363

341364
disable_flag = {"value": False}
342365

343-
def get_setting(key: str, default=None):
366+
def get_setting(key: str, default: Any = None) -> Any:
344367
if key == "disable_auth":
345368
return disable_flag["value"]
346369
if key == "client_api_key":
@@ -355,7 +378,7 @@ def get_setting(key: str, default=None):
355378

356379
app_state_service = MagicMock(spec=IApplicationState)
357380
app_state_service.get_setting.side_effect = get_setting
358-
middleware.app_state_service = app_state_service
381+
cast(Any, middleware).app_state_service = app_state_service
359382
mock_request.app.state.service_provider.get_service.return_value = (
360383
app_state_service
361384
)
@@ -405,8 +428,8 @@ async def test_valid_token(self, auth_token_middleware, mock_request):
405428
call_next.assert_called_once_with(mock_request)
406429
assert response == "next_response"
407430

408-
async def test_invalid_token(self, auth_token_middleware, mock_request):
409-
"""Test that an invalid auth token is rejected."""
431+
async def test_invalid_token(self, auth_token_middleware, mock_request):
432+
"""Test that an invalid auth token is rejected."""
410433
# Setup
411434
mock_request.headers = {"X-Auth-Token": "invalid-token"}
412435
call_next = AsyncMock(return_value="next_response")
@@ -417,9 +440,31 @@ async def test_invalid_token(self, auth_token_middleware, mock_request):
417440
# Verify
418441
call_next.assert_not_called()
419442
assert response.status_code == 401
420-
assert (
421-
response.body == f'{{"detail":"{HTTP_401_UNAUTHORIZED_MESSAGE}"}}'.encode()
422-
)
443+
assert (
444+
response.body == f'{{"detail":"{HTTP_401_UNAUTHORIZED_MESSAGE}"}}'.encode()
445+
)
446+
447+
async def test_valid_token_uses_constant_time_comparison(
448+
self, auth_token_middleware, mock_request, monkeypatch
449+
):
450+
"""Auth token validation should avoid short-circuit string equality."""
451+
452+
compared_values: list[tuple[str, str]] = []
453+
454+
def compare_digest(candidate: str, expected: str) -> bool:
455+
compared_values.append((candidate, expected))
456+
return candidate == expected
457+
458+
monkeypatch.setattr(
459+
"src.core.security.middleware.secrets.compare_digest", compare_digest
460+
)
461+
mock_request.headers = {"X-Auth-Token": "test-token"}
462+
call_next = AsyncMock(return_value="next_response")
463+
464+
response = await auth_token_middleware.dispatch(mock_request, call_next)
465+
466+
assert response == "next_response"
467+
assert compared_values == [("test-token", "test-token")]
423468

424469
async def test_missing_token(self, auth_token_middleware, mock_request):
425470
"""Test that a missing auth token is rejected."""

0 commit comments

Comments
 (0)