Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ init:
develop:
uv run fastapi dev --port 9000

test:
uv run python -m pytest .

start_db:
mongod --dbpath db/data --logpath db/logs/mongodb.log

Expand Down
2 changes: 1 addition & 1 deletion server/models/laf.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from datetime import datetime
from html import unescape
from enum import Enum
from html import unescape
from typing import Annotated, TypedDict

from pydantic import (
Expand Down
232 changes: 232 additions & 0 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
"""Unit tests for server.helpers.auth."""

import asyncio
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch

import jwt
import pytest
from fastapi import HTTPException, Request, status

from server.helpers.auth import (
BlacklistedTokenException,
create_access_token,
required_auth,
simple_auth_check,
validate_token,
)


def _request_with_cookie(token: str | None) -> MagicMock:
request = MagicMock(spec=Request)
cookies: dict[str, str] = {}
if token is not None:
cookies["authToken"] = token
request.cookies.get = MagicMock(
side_effect=lambda key, default=None: cookies.get(key, default)
)
return request


def test_validate_token_blacklisted_raises() -> None:
async def _run() -> None:
request = _request_with_cookie("tok")
with patch(
"server.helpers.auth.is_token_blacklisted",
new_callable=AsyncMock,
return_value=True,
):
with pytest.raises(BlacklistedTokenException):
await validate_token(request, "tok")

asyncio.run(_run())


def test_validate_token_decodes_and_returns_payload() -> None:
async def _run() -> None:
request = _request_with_cookie("tok")
expected = {"sub": "user1"}
with (
patch(
"server.helpers.auth.is_token_blacklisted",
new_callable=AsyncMock,
return_value=False,
),
patch(
"server.helpers.auth.jwt.decode", return_value=expected
) as mock_decode,
):
with patch("server.helpers.auth.settings") as mock_settings:
mock_settings.SECRET_KEY = "secret"
mock_settings.ALGORITHM = "HS256"
result = await validate_token(request, "tok")
assert result == expected
mock_decode.assert_called_once_with("tok", "secret", algorithms=["HS256"])

asyncio.run(_run())


def test_create_access_token_with_expires_delta() -> None:
async def _run() -> str:
fixed_now = datetime(2025, 1, 15, 12, 0, 0, tzinfo=timezone.utc)
delta = timedelta(hours=2)
with patch("server.helpers.auth.datetime") as mock_dt:
mock_dt.now.return_value = fixed_now
mock_dt.timedelta = timedelta
mock_dt.timezone = timezone
with patch("server.helpers.auth.settings") as mock_settings:
mock_settings.SECRET_KEY = "x" * 32
mock_settings.ALGORITHM = "HS256"
return await create_access_token({"sub": "u1"}, expires_delta=delta)

token = asyncio.run(_run())
payload = jwt.decode(
token,
"x" * 32,
algorithms=["HS256"],
options={"verify_exp": False},
)
assert payload["sub"] == "u1"
fixed_now = datetime(2025, 1, 15, 12, 0, 0, tzinfo=timezone.utc)
assert payload["exp"] == int((fixed_now + timedelta(hours=2)).timestamp())


def test_create_access_token_default_expiry_about_15_minutes() -> None:
async def _run() -> str:
fixed_now = datetime(2025, 6, 1, 10, 30, 0, tzinfo=timezone.utc)
with patch("server.helpers.auth.datetime") as mock_dt:
mock_dt.now.return_value = fixed_now
mock_dt.timedelta = timedelta
mock_dt.timezone = timezone
with patch("server.helpers.auth.settings") as mock_settings:
mock_settings.SECRET_KEY = "x" * 32
mock_settings.ALGORITHM = "HS256"
return await create_access_token({"sub": "u2"})

token = asyncio.run(_run())
payload = jwt.decode(
token,
"x" * 32,
algorithms=["HS256"],
options={"verify_exp": False},
)
fixed_now = datetime(2025, 6, 1, 10, 30, 0, tzinfo=timezone.utc)
expected_exp = int((fixed_now + timedelta(minutes=15)).timestamp())
assert payload["exp"] == expected_exp


def test_simple_auth_check_no_cookie() -> None:
async def _run() -> None:
request = _request_with_cookie(None)
ok, msg, payload = await simple_auth_check(request)
assert ok is False
assert msg == "No token found"
assert payload is None

asyncio.run(_run())


def test_simple_auth_check_success() -> None:
async def _run() -> None:
request = _request_with_cookie("good")
with patch(
"server.helpers.auth.validate_token",
new_callable=AsyncMock,
return_value={"sub": "x"},
):
ok, msg, payload = await simple_auth_check(request)
assert ok is True
assert msg == ""
assert payload == {"sub": "x"}

asyncio.run(_run())


def test_simple_auth_check_expired_token() -> None:
async def _run() -> None:
request = _request_with_cookie("expired")
with patch(
"server.helpers.auth.validate_token",
new_callable=AsyncMock,
side_effect=jwt.ExpiredSignatureError(),
):
ok, msg, payload = await simple_auth_check(request)
assert ok is False
assert msg == "Token expired"
assert payload is None

asyncio.run(_run())


def test_simple_auth_check_invalid_token() -> None:
async def _run() -> None:
request = _request_with_cookie("bad")
with patch(
"server.helpers.auth.validate_token",
new_callable=AsyncMock,
side_effect=jwt.InvalidTokenError(),
):
ok, msg, payload = await simple_auth_check(request)
assert ok is False
assert msg == "Invalid token"
assert payload is None

asyncio.run(_run())


def test_simple_auth_check_blacklisted() -> None:
async def _run() -> None:
request = _request_with_cookie("bl")
with patch(
"server.helpers.auth.validate_token",
new_callable=AsyncMock,
side_effect=BlacklistedTokenException(),
):
ok, msg, payload = await simple_auth_check(request)
assert ok is False
assert msg == "User logged out"
assert payload is None

asyncio.run(_run())


def test_required_auth_returns_payload() -> None:
async def _run() -> None:
request = _request_with_cookie("ok")
with patch(
"server.helpers.auth.simple_auth_check",
new_callable=AsyncMock,
return_value=(True, "", {"sub": "admin"}),
):
result = await required_auth(request)
assert result == {"sub": "admin"}

asyncio.run(_run())


@pytest.mark.parametrize(
"authenticated,message,payload",
[
(False, "No token found", None),
(True, "", None),
],
ids=["unauthenticated", "authenticated_but_no_payload"],
)
def test_required_auth_raises_401(
authenticated: bool,
message: str,
payload: dict | None,
) -> None:
async def _run() -> None:
request = _request_with_cookie("x")
with patch(
"server.helpers.auth.simple_auth_check",
new_callable=AsyncMock,
return_value=(authenticated, message, payload),
):
with pytest.raises(HTTPException) as exc_info:
await required_auth(request)
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
assert exc_info.value.detail == message

asyncio.run(_run())
55 changes: 55 additions & 0 deletions tests/test_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""Unit tests for server.helpers.cache."""

from fastapi import Request

from server.helpers.cache import cache_key_exclude_request


def _http_request() -> Request:
"""Minimal Starlette/FastAPI Request for isinstance checks."""
return Request(
{
"type": "http",
"asgi": {"version": "3.0", "spec_version": "2.4"},
"http_version": "1.1",
"method": "GET",
"scheme": "http",
"path": "/",
"raw_path": b"/",
"root_path": "",
"query_string": b"",
"headers": [],
"client": ("127.0.0.1", 12345),
"server": ("127.0.0.1", 80),
}
)


def sample_fn(*_args, **_kwargs) -> None:
"""Module-level function for cache key naming."""
pass


def test_cache_key_exclude_request_strips_request_from_args() -> None:
req = _http_request()
key = cache_key_exclude_request(sample_fn, 1, req, "keep")
expected = f"{sample_fn.__module__}.{sample_fn.__name__}:(1, 'keep'):()"
assert key == expected


def test_cache_key_exclude_request_strips_request_from_kwargs() -> None:
req = _http_request()
key = cache_key_exclude_request(sample_fn, 10, r=req, z=1, a=2)
expected = f"{sample_fn.__module__}.{sample_fn.__name__}:(10,):(('a', 2), ('z', 1))"
assert key == expected


def test_cache_key_exclude_request_kwargs_sorted_deterministically() -> None:
key_zy = cache_key_exclude_request(sample_fn, 0, z=1, y=2)
key_yz = cache_key_exclude_request(sample_fn, 0, y=2, z=1)
assert key_zy == key_yz


def test_cache_key_exclude_request_includes_module_and_function_name() -> None:
key = cache_key_exclude_request(sample_fn, 42)
assert key.startswith(f"{sample_fn.__module__}.{sample_fn.__name__}:")
Loading
Loading