Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 48 additions & 26 deletions src/ai/backend/common/jwt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,70 +5,92 @@
Hive Router. It uses the X-BackendAI-Token custom header to avoid conflicts with
existing Bearer token usage in appproxy.

Supports both HS256 (symmetric, per-user secret keys) and RS256 (asymmetric,
RSA key pairs) signing algorithms, with JWKS utilities for distributed key
management.

Key components:
- JWTSigner: Generates JWT tokens from authenticated user context (webserver)
- JWTValidator: Validates JWT tokens and extracts user claims (manager)
- JWTConfig: Configuration for JWT authentication
- JWTClaims: Dataclass representing JWT payload claims
- JWKSKeySet: Public key set indexed by key ID for RS256 validation
- JWKSFetcher: Async JWKS endpoint fetcher with TTL caching
- Key utilities: RSA key generation, loading, serialization, and JWK conversion

Example usage in webserver:
Example usage (HS256):
from ai.backend.common.jwt import JWTSigner, JWTConfig, JWTUserContext

config = JWTConfig(secret_key=os.environ["JWT_SECRET_KEY"])
config = JWTConfig()
signer = JWTSigner(config)

user_context = JWTUserContext(
user_id=user_uuid,
access_key=access_key,
role="user",
domain_name="default",
is_admin=False,
is_superadmin=False,
)
token = signer.generate_token(user_context)

# Add to request headers
headers["X-BackendAI-Token"] = token

Example usage in manager:
from ai.backend.common.jwt import JWTValidator, JWTConfig

config = JWTConfig(secret_key=os.environ["JWT_SECRET_KEY"])
validator = JWTValidator(config)
token = signer.generate_token(user_context, secret_key)

token = request.headers.get("X-BackendAI-Token")
claims = validator.validate_token(token)
Example usage (RS256):
from ai.backend.common.jwt import JWTSigner, JWTConfig, JWTUserContext
from ai.backend.common.jwt.keys import load_private_key

# Use claims for authentication
user_id = claims.sub
access_key = claims.access_key
config = JWTConfig(algorithm="RS256")
signer = JWTSigner(config)
private_key = load_private_key(Path("/path/to/private.pem"))
token = signer.generate_token(user_context, private_key=private_key, kid="key-1")
"""

from .config import JWTConfig
from .exceptions import (
from ai.backend.common.jwt.config import JWTAlgorithm, JWTConfig
from ai.backend.common.jwt.exceptions import (
JWKSError,
JWKSFetchError,
JWKSKeyNotFoundError,
JWTDecodeError,
JWTError,
JWTExpiredError,
JWTInvalidClaimsError,
JWTInvalidSignatureError,
)
from .signer import JWTSigner
from .types import JWTClaims, JWTUserContext
from .validator import JWTValidator
from ai.backend.common.jwt.jwks import JWKSFetcher, JWKSKeySet
from ai.backend.common.jwt.keys import (
generate_rsa_key_pair,
load_private_key,
load_public_key,
private_key_to_pem,
public_key_to_jwk,
public_key_to_pem,
)
from ai.backend.common.jwt.signer import JWTSigner
from ai.backend.common.jwt.types import JWTClaims, JWTUserContext
from ai.backend.common.jwt.validator import JWTValidator

__all__ = [
# Configuration
"JWTAlgorithm",
"JWTConfig",
# Types
"JWTClaims",
"JWTUserContext",
# Core classes
"JWTSigner",
"JWTValidator",
# JWKS
"JWKSKeySet",
"JWKSFetcher",
# Key management
"generate_rsa_key_pair",
"load_private_key",
"load_public_key",
"private_key_to_pem",
"public_key_to_pem",
"public_key_to_jwk",
# Exceptions
"JWTError",
"JWTExpiredError",
"JWTInvalidSignatureError",
"JWTInvalidClaimsError",
"JWTDecodeError",
"JWKSError",
"JWKSFetchError",
"JWKSKeyNotFoundError",
]
33 changes: 27 additions & 6 deletions src/ai/backend/common/jwt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,45 +3,66 @@
from __future__ import annotations

from datetime import timedelta
from enum import StrEnum
from pathlib import Path

from pydantic import Field

from ai.backend.common.config import BaseConfigSchema


class JWTAlgorithm(StrEnum):
"""Supported JWT signing algorithms."""

HS256 = "HS256"
RS256 = "RS256"


class JWTConfig(BaseConfigSchema):
"""
Configuration for JWT-based authentication in GraphQL Federation.

This configuration must be consistent between webserver (which generates tokens)
and manager (which validates tokens).

Note: JWT tokens are signed using per-user secret keys (from keypair table),
not a shared system secret key. This maintains the same security model as HMAC authentication.
Supports both HS256 (symmetric, per-user secret keys) and RS256 (asymmetric,
RSA key pairs) signing algorithms.

JWT tokens are transmitted via X-BackendAI-Token HTTP header.

Attributes:
enabled: Whether JWT authentication is enabled
algorithm: JWT signing algorithm (must be HS256)
algorithm: JWT signing algorithm (HS256 or RS256)
token_expiration_seconds: Token validity duration in seconds
private_key_path: Path to PEM-encoded RSA private key (RS256 only)
public_key_path: Path to PEM-encoded RSA public key (RS256 only)
"""

enabled: bool = Field(
default=True,
description="Enable JWT authentication for GraphQL Federation requests",
)

algorithm: str = Field(
default="HS256",
description="JWT signing algorithm (only HS256 is supported)",
algorithm: JWTAlgorithm = Field(
default=JWTAlgorithm.HS256,
description="JWT signing algorithm (HS256 or RS256)",
)

token_expiration_seconds: int = Field(
default=900, # 15 minutes
description="JWT token expiration time in seconds (default: 900 = 15 minutes)",
)

private_key_path: Path | None = Field(
default=None,
description="Path to PEM-encoded RSA private key file (required for RS256 signing)",
)

public_key_path: Path | None = Field(
default=None,
description="Path to PEM-encoded RSA public key file (required for RS256 validation)",
)

@property
def token_expiration(self) -> timedelta:
"""
Expand Down
34 changes: 34 additions & 0 deletions src/ai/backend/common/jwt/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,37 @@ class JWTDecodeError(JWTError, web.HTTPUnauthorized):

error_type = "https://api.backend.ai/probs/jwt-decode-error"
error_title = "Failed to decode JWT token."


class JWKSError(JWTError):
"""
Base exception for JWKS-related errors.

All JWKS-specific exceptions inherit from this base class.
"""

error_type = "https://api.backend.ai/probs/jwks-error"
error_title = "JWKS error."


class JWKSFetchError(JWKSError, web.HTTPUnauthorized):
"""
Failed to fetch JWKS from the remote endpoint.

Raised when the JWKS endpoint is unreachable or returns invalid data.
"""

error_type = "https://api.backend.ai/probs/jwks-fetch-error"
error_title = "Failed to fetch JWKS."


class JWKSKeyNotFoundError(JWKSError, web.HTTPUnauthorized):
"""
Key ID (kid) not found in the JWKS key set.

Raised when a token references a key ID that is not present
in the available JWKS key set.
"""

error_type = "https://api.backend.ai/probs/jwks-key-not-found"
error_title = "Key ID not found in JWKS."
Loading
Loading