Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions backend/app/adapters/inbound/api/v1/routes_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
AcceptInvitationRequest,
InviteMemberRequest,
LoginRequest,
MFAEnableRequest,
MFAVerifyRequest,
PatchMemberRequest,
RefreshTokenRequest,
RegisterRequest,
Expand All @@ -16,6 +18,7 @@
)
from app.application.dto.responses import (
InviteMemberResponse,
MFAEnrollResponse,
MFAPartialResponse,
MeResponse,
MembersListResponse,
Expand Down Expand Up @@ -94,6 +97,53 @@ async def login(
)


@router.post("/mfa/enroll", response_model=MFAEnrollResponse)
async def mfa_enroll(
claims: dict = Depends(get_current_claims),
iam_service=Depends(Container.get_iam_service),
) -> MFAEnrollResponse:
result = await iam_service.start_mfa_enrollment(user_id=UUID(str(claims["sub"])))
return MFAEnrollResponse(
secret=result["secret"], provisioning_uri=result["provisioning_uri"]
)


@router.post("/mfa/enable", status_code=status.HTTP_204_NO_CONTENT)
async def mfa_enable(
payload: MFAEnableRequest,
claims: dict = Depends(get_current_claims),
iam_service=Depends(Container.get_iam_service),
) -> None:
try:
await iam_service.enable_mfa(
user_id=UUID(str(claims["sub"])), code=payload.code
)
except ValueError as exc:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)
) from exc


@router.post("/mfa/verify", response_model=TokenPairResponse)
async def mfa_verify(
payload: MFAVerifyRequest,
iam_service=Depends(Container.get_iam_service),
) -> TokenPairResponse:
try:
result = await iam_service.verify_mfa(
mfa_attempt_token=payload.mfa_attempt_token, code=payload.code
)
except ValueError as exc:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc)
) from exc
return TokenPairResponse(
access_token=result["access_token"],
refresh_token=result["refresh_token"],
token_type=result.get("token_type", "bearer"),
)


@router.post("/switch-org", response_model=SwitchOrgResponse)
async def switch_org(
payload: SwitchOrgRequest,
Expand Down
21 changes: 19 additions & 2 deletions backend/app/adapters/inbound/api/v1/routes_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,28 @@
from fastapi.responses import StreamingResponse

from app.application.dto.requests import ChatRequest
from app.infrastructure.config import settings
from app.infrastructure.di.container import Container
from app.infrastructure.security.auth import get_current_claims
from app.infrastructure.security.rate_limit import rate_limit

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/chats", tags=["chats"])


@router.post("/{session_id}/ask")
@router.post(
"/{session_id}/ask",
dependencies=[
Depends(
rate_limit(
scope="chat",
limit=settings.chat_rate_limit_per_min,
window=settings.rate_limit_window_seconds,
)
)
],
)
async def chat(
session_id: UUID,
payload: ChatRequest,
Expand All @@ -38,4 +51,8 @@ async def event_generator():
logger.exception("Chat streaming failed")
raise

return StreamingResponse(event_generator(), media_type="text/plain")
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
)
47 changes: 45 additions & 2 deletions backend/app/adapters/inbound/api/v1/routes_documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,11 @@
get_document_repository,
get_ingest_document_service,
get_object_storage,
get_usage_service,
get_vector_store,
)
from app.infrastructure.security.auth import get_current_claims
from app.infrastructure.security.rate_limit import rate_limit

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -121,6 +124,15 @@ def _infer_source_type(filename: str | None, content_type: str | None) -> str:
response_model=DocumentResponse,
status_code=status.HTTP_202_ACCEPTED,
summary="Upload a document for asynchronous chunking",
dependencies=[
Depends(
rate_limit(
scope="upload",
limit=settings.upload_rate_limit_per_min,
window=settings.rate_limit_window_seconds,
)
)
],
)
async def upload_document(
file: UploadFile = File(..., description="The document file to ingest"),
Expand Down Expand Up @@ -172,6 +184,16 @@ async def upload_document(
status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE, detail=str(exc)
) from exc

try:
await get_usage_service().record(
org_id=UUID(org_id_str),
event_type="upload",
document_count=1,
user_id=UUID(user_id_str) if user_id_str else None,
)
except Exception:
logger.exception("Failed to record upload usage for org %s", org_id_str)

return _to_response(document)


Expand Down Expand Up @@ -271,26 +293,47 @@ async def list_document_chunks(
async def delete_document(
document_id: UUID, claims: dict = Depends(get_current_claims)
):
"""Delete a document everywhere it lives.

The chain purges the three stores the document touches so the AI truly
forgets it: vectors in Pinecone (so it can no longer surface in answers),
the Postgres row (cascading its chunks), and the raw bytes in object
storage. Vectors are purged first; external-store failures are logged but
do not block removal of the application record.
"""
import asyncio

org_id_str = claims.get("org_id")
if not org_id_str:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="No active organization"
)
org_id = UUID(org_id_str)
repo = get_document_repository()
doc = await repo.get_by_id(document_id=document_id, org_id=UUID(org_id_str))
doc = await repo.get_by_id(document_id=document_id, org_id=org_id)
if doc is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Document not found"
)

await repo.delete(document_id=document_id, org_id=UUID(org_id_str))
# 1) Purge vectors first so a stale document can never answer a question.
try:
await get_vector_store().delete_by_document_id(
document_id=str(document_id), org_id=str(org_id)
)
except Exception:
logger.exception("Failed to purge vectors for document %s", document_id)

# 2) Delete the DB row (cascades chunks via FK ondelete=CASCADE).
await repo.delete(document_id=document_id, org_id=org_id)

# 3) Remove the stored bytes (best effort).
storage = get_object_storage()
try:
await asyncio.to_thread(storage.delete_object, key=doc.storage_url)
except Exception:
logger.exception("Failed to delete stored bytes for document %s", document_id)

return Response(status_code=status.HTTP_204_NO_CONTENT)


Expand Down
28 changes: 28 additions & 0 deletions backend/app/adapters/inbound/api/v1/routes_usage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from __future__ import annotations

import logging
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, status

from app.application.dto.responses import UsageResponse
from app.infrastructure.di.container import Container
from app.infrastructure.security.auth import get_current_claims

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/usage", tags=["usage"])


@router.get("", response_model=UsageResponse, summary="Current-month usage for org")
async def get_usage(
claims: dict = Depends(get_current_claims),
usage_service=Depends(Container.get_usage_service),
) -> UsageResponse:
org_id_str = claims.get("org_id")
if not org_id_str:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="No active organization"
)
result = await usage_service.get_monthly_usage(org_id=UUID(org_id_str))
return UsageResponse(**result)
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""add users.mfa_secret

Revision ID: 20260608_0006
Revises: 20260506_0005
"""

from __future__ import annotations

from alembic import op
import sqlalchemy as sa

revision = "20260608_0006"
down_revision = "20260506_0005"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.add_column(
"users",
sa.Column("mfa_secret", sa.String(length=64), nullable=True),
)


def downgrade() -> None:
op.drop_column("users", "mfa_secret")
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""usage_logs table

Revision ID: 20260608_0007
Revises: 20260608_0006
"""

from __future__ import annotations

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

revision = "20260608_0007"
down_revision = "20260608_0006"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.create_table(
"usage_logs",
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("org_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=True),
sa.Column("event_type", sa.String(length=32), nullable=False),
sa.Column("token_count", sa.Integer(), nullable=False, server_default="0"),
sa.Column("document_count", sa.Integer(), nullable=False, server_default="0"),
sa.Column(
"metadata",
postgresql.JSONB(astext_type=sa.Text()),
nullable=False,
server_default=sa.text("'{}'::jsonb"),
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.ForeignKeyConstraint(["org_id"], ["organizations.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="SET NULL"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("ix_usage_logs_org_id", "usage_logs", ["org_id"], unique=False)
op.create_index(
"ix_usage_logs_event_type", "usage_logs", ["event_type"], unique=False
)
op.create_index(
"ix_usage_logs_created_at", "usage_logs", ["created_at"], unique=False
)
op.create_index(
"ix_usage_logs_org_created",
"usage_logs",
["org_id", "created_at"],
unique=False,
)


def downgrade() -> None:
op.drop_index("ix_usage_logs_org_created", table_name="usage_logs")
op.drop_index("ix_usage_logs_created_at", table_name="usage_logs")
op.drop_index("ix_usage_logs_event_type", table_name="usage_logs")
op.drop_index("ix_usage_logs_org_id", table_name="usage_logs")
op.drop_table("usage_logs")
48 changes: 46 additions & 2 deletions backend/app/adapters/outbound/db/sqlalchemy_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,16 @@
from typing import Any
import uuid

from sqlalchemy import DateTime, ForeignKey, Index, String, Text, UniqueConstraint, func
from sqlalchemy import (
DateTime,
ForeignKey,
Index,
Integer,
String,
Text,
UniqueConstraint,
func,
)
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

Expand All @@ -26,6 +35,7 @@ class UserORM(Base):
is_active: Mapped[bool] = mapped_column(nullable=False, default=True)
is_platform_admin: Mapped[bool] = mapped_column(nullable=False, default=False)
mfa_enabled: Mapped[bool] = mapped_column(nullable=False, default=False)
mfa_secret: Mapped[str | None] = mapped_column(String(64), nullable=True)
metadata_json: Mapped[dict] = mapped_column(
"metadata", JSONB, default=dict, nullable=False
)
Expand Down Expand Up @@ -104,7 +114,11 @@ class OrganizationMembershipORM(Base):
nullable=False,
index=True,
)
user: Mapped["UserORM"] = relationship("UserORM", lazy="joined")
# Disambiguate: this table has two FKs to users.id (user_id and
# invited_by_user_id); the member relationship follows user_id.
user: Mapped["UserORM"] = relationship(
"UserORM", lazy="joined", foreign_keys=[user_id]
)
org: Mapped["OrganizationORM"] = relationship(
"OrganizationORM", back_populates="members"
)
Expand Down Expand Up @@ -325,3 +339,33 @@ class ChatMessageORM(Base):
)

session: Mapped[ChatSessionORM] = relationship(back_populates="messages")


class UsageLogORM(Base):
__tablename__ = "usage_logs"

id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=uuid4
)
org_id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True),
ForeignKey("organizations.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
user_id: Mapped[uuid.UUID | None] = mapped_column(
UUID(as_uuid=True),
ForeignKey("users.id", ondelete="SET NULL"),
nullable=True,
)
event_type: Mapped[str] = mapped_column(String(32), nullable=False, index=True)
token_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
document_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
metadata_json: Mapped[dict] = mapped_column(
"metadata", JSONB, default=dict, nullable=False
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), index=True
)

__table_args__ = (Index("ix_usage_logs_org_created", "org_id", "created_at"),)
Loading
Loading