Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion backend/aci/common/schemas/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pydantic import BaseModel, BeforeValidator, ConfigDict, Field

from aci.common.schemas.apikey import APIKeyPublic
from aci.common.schemas.apikey import APIKeyPublic, APIKeyWithSecret

MAX_INSTRUCTION_LENGTH = 5000

Expand Down Expand Up @@ -52,3 +52,9 @@ class AgentPublic(BaseModel):
api_keys: list[APIKeyPublic]

model_config = ConfigDict(from_attributes=True)


class AgentPublicWithAPIKeys(AgentPublic):
"""AgentPublic plus plaintext API keys. Only for creation-time responses."""

api_keys: list[APIKeyWithSecret]
11 changes: 10 additions & 1 deletion backend/aci/common/schemas/apikey.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,20 @@

class APIKeyPublic(BaseModel):
id: UUID
key: str
agent_id: UUID
status: APIKeyStatus

created_at: datetime
updated_at: datetime

model_config = ConfigDict(from_attributes=True)


class APIKeyWithSecret(APIKeyPublic):
"""APIKeyPublic plus the plaintext key.

Only for creation-time responses (project/agent creation), where the caller
must capture the key once. Never use as the response model for read endpoints.
"""

key: str
8 changes: 7 additions & 1 deletion backend/aci/common/schemas/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pydantic import BaseModel, ConfigDict, Field

from aci.common.enums import Visibility
from aci.common.schemas.agent import AgentPublic
from aci.common.schemas.agent import AgentPublic, AgentPublicWithAPIKeys


class ProjectCreate(BaseModel):
Expand Down Expand Up @@ -39,3 +39,9 @@ class ProjectPublic(BaseModel):
agents: list[AgentPublic]

model_config = ConfigDict(from_attributes=True)


class ProjectPublicWithAPIKeys(ProjectPublic):
"""ProjectPublic plus plaintext agent API keys. Only for the creation response."""

agents: list[AgentPublicWithAPIKeys]
8 changes: 2 additions & 6 deletions backend/aci/common/schemas/security_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,17 +235,13 @@ class OAuth2SchemeCredentials(BaseModel):
class OAuth2SchemeCredentialsLimited(BaseModel):
"""Limited OAuth2 credentials to expose to the client directly.

client_id and client_secret are included so that downstream clients can
use the refresh_token to mint tokens for additional resource audiences
that ACI does not directly broker (e.g. SharePoint REST in addition to
the Microsoft Graph access_token ACI manages on their behalf).
SECURITY: client_secret and refresh_token must never be exposed here.
Clients only get the current access_token; token refresh is brokered by ACI.
"""

access_token: str
expires_at: int | None = None
refresh_token: str | None = None
client_id: str | None = None
client_secret: str | None = None


class NoAuthSchemeCredentials(BaseModel, extra="forbid"):
Expand Down
65 changes: 36 additions & 29 deletions backend/aci/server/routes/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,13 @@
ProjectNotFound,
)
from aci.common.logging_setup import get_logger
from aci.common.schemas.agent import AgentCreate, AgentPublic, AgentUpdate
from aci.common.schemas.project import ProjectCreate, ProjectPublic, ProjectUpdate
from aci.common.schemas.agent import AgentCreate, AgentPublic, AgentPublicWithAPIKeys, AgentUpdate
from aci.common.schemas.project import (
ProjectCreate,
ProjectPublic,
ProjectPublicWithAPIKeys,
ProjectUpdate,
)
from aci.server import config, quota_manager
from aci.server import dependencies as deps

Expand All @@ -25,7 +30,7 @@


# TODO: Once member has been introduced change the ACL to require_org_member_with_minimum_role
@router.post("", response_model=ProjectPublic, include_in_schema=True)
@router.post("", response_model=ProjectPublicWithAPIKeys, include_in_schema=True)
async def create_project(
body: ProjectCreate,
# user: Annotated[User, Depends(auth.require_user)],
Expand All @@ -48,17 +53,18 @@ async def create_project(
custom_instructions={},
)
db_session.commit()

logger.info(f"Created project, project_id={project.id}, org_id={body.org_id}")

# Convert to ProjectPublic model to avoid DetachedInstanceError
from aci.common.schemas.project import ProjectPublic
from aci.common.schemas.agent import AgentPublic
from aci.common.schemas.apikey import APIKeyPublic


# Convert to ProjectPublicWithAPIKeys model to avoid DetachedInstanceError.
# This is the only endpoint that returns the plaintext API key: callers must
# capture it here, read endpoints never return it again.
from aci.common.schemas.apikey import APIKeyWithSecret


# Load the agent that was created
agents = crud.projects.get_agents_by_project(db_session, project.id)

# Convert agents to AgentPublic models
agent_publics = []
for agent in agents:
Expand All @@ -67,9 +73,9 @@ async def create_project(
try:
api_key = crud.projects.get_api_key_by_agent_id(db_session, agent.id)
if api_key:
api_key_public = APIKeyPublic(
api_key_public = APIKeyWithSecret(
id=api_key.id,
key=api_key.key, # Include the decrypted API key
key=api_key.key, # Returned once at creation only
agent_id=api_key.agent_id,
status=api_key.status,
created_at=api_key.created_at,
Expand All @@ -79,9 +85,9 @@ async def create_project(
except Exception as e:
logger.error(f"Error loading API key for agent {agent.id}: {e}")
# Continue without API key - this allows the project creation to succeed
# Create AgentPublic model
agent_public = AgentPublic(

# Create AgentPublicWithAPIKeys model
agent_public = AgentPublicWithAPIKeys(
id=agent.id,
project_id=agent.project_id,
name=agent.name,
Expand All @@ -93,9 +99,9 @@ async def create_project(
api_keys=api_key_publics,
)
agent_publics.append(agent_public)
# Create ProjectPublic model
project_public = ProjectPublic(

# Create ProjectPublicWithAPIKeys model
project_public = ProjectPublicWithAPIKeys(
id=project.id,
org_id=project.org_id,
name=project.name,
Expand All @@ -109,7 +115,7 @@ async def create_project(
updated_at=project.updated_at,
agents=agent_publics,
)

return project_public


Expand All @@ -128,13 +134,13 @@ async def get_projects(

try:
projects = crud.projects.get_projects_by_org(db_session, org_id)

# Return a response with agents but handle Unicode errors gracefully
simplified_projects = []
for project in projects:
# Load agents for this project
agents = crud.projects.get_agents_by_project(db_session, project.id)

# Convert agents to simple dict format with error handling for API keys
agent_list = []
for agent in agents:
Expand All @@ -143,20 +149,21 @@ async def get_projects(
try:
api_key = crud.projects.get_api_key_by_agent_id(db_session, agent.id)
if api_key:
# NOTE: the plaintext key is intentionally NOT included here;
# it is only returned once by the project creation endpoint.
api_key_dict = {
"id": str(api_key.id),
"key": api_key.key, # Include the decrypted API key
"agent_id": str(api_key.agent_id),
"status": api_key.status.value,
"created_at": api_key.created_at.isoformat() if api_key.created_at else None,
"updated_at": api_key.updated_at.isoformat() if api_key.updated_at else None,
}

api_keys_list.append(api_key_dict)
except Exception as e:
logger.error(f"Error loading API key for agent {agent.id}: {e}")
# Continue without API key - this allows the project to load

agent_dict = {
"id": str(agent.id),
"project_id": str(agent.project_id),
Expand All @@ -169,7 +176,7 @@ async def get_projects(
"api_keys": api_keys_list
}
agent_list.append(agent_dict)

simplified_project = {
"id": str(project.id),
"org_id": project.org_id,
Expand All @@ -185,9 +192,9 @@ async def get_projects(
"agents": agent_list
}
simplified_projects.append(simplified_project)

return simplified_projects

except Exception as e:
logger.error(f"Error getting projects: {e}")
# Return empty list if there's an error, so frontend can still load
Expand Down Expand Up @@ -258,7 +265,7 @@ async def update_project(
return updated_project


@router.post("/{project_id}/agents", response_model=AgentPublic, include_in_schema=True)
@router.post("/{project_id}/agents", response_model=AgentPublicWithAPIKeys, include_in_schema=True)
async def create_agent(
project_id: UUID,
body: AgentCreate,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,11 @@ def test_get_linked_account_with_oauth2_credentials(
linked_account = response.json()
security_credentials: dict[str, str] = linked_account["security_credentials"]
assert security_credentials["access_token"], "OAuth2 credentials should contain access_token"
# NOTE: expires_at and refresh_token are optional, but they exist in this mock linked account
# NOTE: expires_at is optional, but it exists in this mock linked account
assert security_credentials["expires_at"], "OAuth2 credentials should contain expires_at"
assert security_credentials["refresh_token"], "OAuth2 credentials should contain refresh_token"
# SECURITY: refresh_token and client_secret must never be exposed to clients
assert "refresh_token" not in security_credentials
assert "client_secret" not in security_credentials


def test_get_linked_account_with_expired_oauth2_credentials(
Expand Down Expand Up @@ -141,5 +143,7 @@ def test_get_linked_account_with_expired_oauth2_credentials(
assert int(security_credentials["expires_at"]) == (
mock_current_time + int(mock_refresh_response["expires_in"])
)
assert security_credentials["refresh_token"] == mock_refresh_response["refresh_token"]
# SECURITY: refresh_token and client_secret must never be exposed to clients
assert "refresh_token" not in security_credentials
assert "client_secret" not in security_credentials
mock_refresh.assert_called_once()
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from aci.common.db import crud
from aci.common.db.sql_models import Agent, APIKey, App, Project
from aci.common.schemas.agent import AgentCreate, AgentPublic, AgentUpdate
from aci.common.schemas.agent import AgentCreate, AgentPublic, AgentPublicWithAPIKeys, AgentUpdate
from aci.server import config
from aci.server.tests.conftest import DummyUser

Expand All @@ -30,7 +30,7 @@ def test_create_agent(
headers={"Authorization": f"Bearer {dummy_user.access_token}"},
)
assert response.status_code == status.HTTP_200_OK
agent_public = AgentPublic.model_validate(response.json())
agent_public = AgentPublicWithAPIKeys.model_validate(response.json())
assert agent_public.name == body.name
assert agent_public.description == body.description
assert agent_public.project_id == dummy_project_1.id
Expand All @@ -41,7 +41,7 @@ def test_create_agent(
).scalar_one_or_none()

assert agent is not None
assert agent_public.model_dump() == AgentPublic.model_validate(agent).model_dump()
assert agent_public.model_dump() == AgentPublicWithAPIKeys.model_validate(agent).model_dump()

# check api keys
api_key = db_session.execute(
Expand Down