Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions backend/app/database/face_clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,3 +349,65 @@ def db_get_images_by_cluster_id(
return images
finally:
conn.close()


def db_get_images_by_face_clusters(
cluster_ids: List[str], # TEXT UUIDs — NOT integers
match_mode: str = "match_any", # "match_any" | "match_all"
) -> List[Dict]:
"""
Return images containing the requested face cluster identities,
ranked by how many of those identities appear in each image.
"""
if not cluster_ids:
return []

placeholders = ", ".join("?" * len(cluster_ids))
params: list = list(cluster_ids)

base_sql = f"""
SELECT
i.id AS image_id,
i.path AS image_path,
i.thumbnailPath AS thumbnail_path,
i.metadata,
COUNT(DISTINCT f.cluster_id) AS match_count
FROM images i
INNER JOIN faces f ON i.id = f.image_id
WHERE f.cluster_id IN ({placeholders})
GROUP BY i.id, i.path, i.thumbnailPath, i.metadata
{{having}}
ORDER BY match_count DESC
"""

if match_mode == "match_all":
having = "HAVING COUNT(DISTINCT f.cluster_id) = ?"
params.append(len(cluster_ids))
else:
having = ""

sql = base_sql.format(having=having)

import json

conn = sqlite3.connect(DATABASE_PATH)
try:
cursor = conn.cursor()
cursor.execute(sql, params)
rows = cursor.fetchall()
results = []
for row in rows:
image_id, image_path, thumbnail_path, metadata_raw, match_count = row
metadata = json.loads(metadata_raw) if metadata_raw else None
results.append(
{
"image_id": image_id,
"image_path": image_path,
"thumbnail_path": thumbnail_path,
"metadata": metadata,
"match_count": match_count,
}
)
return results
finally:
conn.close()
67 changes: 67 additions & 0 deletions backend/app/routes/face_clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
db_update_cluster,
db_get_all_clusters_with_face_counts,
db_get_images_by_cluster_id,
db_get_images_by_face_clusters,
)
from app.utils.face_clusters import cluster_util_face_clusters_sync
from app.schemas.face_clusters import (
Expand All @@ -25,6 +26,10 @@
GetClusterImagesResponse,
GetClusterImagesData,
ImageInCluster,
MultiPersonSearchRequest,
MultiPersonSearchResponse,
MultiPersonSearchData,
MultiPersonSearchImage,
)
from app.schemas.images import FaceSearchRequest, InputType
from app.utils.faceSearch import perform_face_search
Expand Down Expand Up @@ -347,3 +352,65 @@ def trigger_global_reclustering():
message=f"Global reclustering failed: {str(e)}",
).model_dump(),
)


@router.post(
"/multi-search",
response_model=MultiPersonSearchResponse,
responses={code: {"model": ErrorResponse} for code in [400, 404, 500]},
)
def search_images_by_multiple_faces(body: MultiPersonSearchRequest):
"""Search for images containing multiple face identities, ranked by match count."""
try:
if not body.cluster_ids:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ErrorResponse(
success=False,
error="Validation Error",
message="cluster_ids cannot be empty.",
).model_dump(),
)
if body.match_mode not in ("match_any", "match_all"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ErrorResponse(
success=False,
error="Validation Error",
message="match_mode must be 'match_any' or 'match_all'.",
).model_dump(),
)

rows = db_get_images_by_face_clusters(body.cluster_ids, body.match_mode)

images = [
MultiPersonSearchImage(
id=row["image_id"],
path=row["image_path"],
thumbnailPath=row["thumbnail_path"],
metadata=row["metadata"],
match_count=row["match_count"],
)
for row in rows
]

return MultiPersonSearchResponse(
success=True,
message=f"Found {len(images)} image(s) matching the selected people.",
data=MultiPersonSearchData(
images=images,
total=len(images),
match_mode=body.match_mode,
),
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ErrorResponse(
success=False,
error="Internal server error",
message=f"Multi-person search failed: {str(e)}",
).model_dump(),
)
26 changes: 26 additions & 0 deletions backend/app/schemas/face_clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,29 @@ class GlobalReclusterResponse(BaseModel):
message: Optional[str] = None
error: Optional[str] = None
data: Optional[GlobalReclusterData] = None


class MultiPersonSearchRequest(BaseModel):
cluster_ids: List[str]
match_mode: str = "match_any"


class MultiPersonSearchImage(BaseModel):
id: str
path: str
thumbnailPath: Optional[str] = None
metadata: Optional[Dict[str, Any]] = None
match_count: int


class MultiPersonSearchData(BaseModel):
images: List[MultiPersonSearchImage]
total: int
match_mode: str


class MultiPersonSearchResponse(BaseModel):
success: bool
message: Optional[str] = None
error: Optional[str] = None
data: Optional[MultiPersonSearchData] = None
Loading
Loading