Skip to content

Commit 8917c95

Browse files
authored
Add detections_count to sequence and alert responses (#559)
* feat(schemas): add detection_count to SequenceRead schema * feat(endpoints): add helper function to get sequence count * feat(endpoints): add sequence count to alerts endpoints * feat(endpoints): add sequence count to sequences endpoint * Fix local CI ordering and ruff issues * refactor(services): move detection-count helper to services/ Per review feedback on PR #559: get_detection_counts_by_sequence_ids is an aggregation query, not a CRUD/route helper, so it belongs in services/ where it can be reused outside of endpoint modules. * refactor(alerts): align _serialize_sequence signature across endpoints Per review feedback on PR #559: _serialize_sequence in alerts.py took the full Dict[int, int] of counts and did the lookup inside. Match the sequences.py version (scalar int param, caller does the lookup) so the helper has one shape repo-wide and only does serialization. * refactor: drop redundant int() casts on model id attributes Per review feedback on PR #559: Sequence.id and Alert.id are typed int on the model, so int(sequence.id) / int(alert.id) is redundant. Same for the dict comprehension in get_detection_counts_by_sequence_ids where the row values are already int. The int(alert_id) at alerts.py:46 stays because alert_id there is unpacked from a raw SQL row tuple, not a model attribute. * test(alerts): use strict=True when zipping fixture sequences and counts Per review feedback on PR #559: strict=False silently swallows length mismatches. Both lists have length 3 today, so flipping to strict=True guards against future fixture edits drifting out of sync. * test(sequences): use utcnow() helper in detections-count test Rebase fallout from #574 (datetime.utcnow deprecation): one new test function still called datetime.utcnow() while the datetime import had been removed from the file, causing NameError on collection.
1 parent 1f67522 commit 8917c95

7 files changed

Lines changed: 196 additions & 26 deletions

File tree

src/app/api/api_v1/endpoints/alerts.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from app.schemas.alerts import AlertReadWithSequences
2121
from app.schemas.login import TokenPayload
2222
from app.schemas.sequences import SequenceRead
23+
from app.services.sequence_counts import get_detection_counts_by_sequence_ids
2324
from app.services.telemetry import telemetry_client
2425

2526
router = APIRouter()
@@ -46,10 +47,16 @@ async def _fetch_sequences_by_alert_ids(session: AsyncSession, alert_ids: List[i
4647
return mapping
4748

4849

49-
def _serialize_alert(alert: Alert, sequences: List[Sequence]) -> AlertReadWithSequences:
50+
def _serialize_sequence(sequence: Sequence, detections_count: int = 0) -> SequenceRead:
51+
return SequenceRead(**sequence.model_dump(), detections_count=detections_count)
52+
53+
54+
def _serialize_alert(
55+
alert: Alert, sequences: List[Sequence], detection_counts: Dict[int, int]
56+
) -> AlertReadWithSequences:
5057
return AlertReadWithSequences(
5158
**alert.model_dump(),
52-
sequences=[SequenceRead(**seq.model_dump()) for seq in sequences],
59+
sequences=[_serialize_sequence(sequence, detection_counts.get(sequence.id, 0)) for sequence in sequences],
5360
)
5461

5562

@@ -66,9 +73,11 @@ async def get_alert(
6673
if UserRole.ADMIN not in token_payload.scopes:
6774
verify_org_rights(token_payload.organization_id, alert)
6875

69-
alert_id_int = int(alert.id)
70-
seq_map = await _fetch_sequences_by_alert_ids(session, [alert_id_int])
71-
return _serialize_alert(alert, seq_map.get(alert_id_int, []))
76+
seq_map = await _fetch_sequences_by_alert_ids(session, [alert.id])
77+
detection_counts = await get_detection_counts_by_sequence_ids(
78+
session, [sequence.id for sequence in seq_map.get(alert.id, [])]
79+
)
80+
return _serialize_alert(alert, seq_map.get(alert.id, []), detection_counts)
7281

7382

7483
@router.get(
@@ -81,7 +90,7 @@ async def fetch_alert_sequences(
8190
alerts: AlertCRUD = Depends(get_alert_crud),
8291
session: AsyncSession = Depends(get_session),
8392
token_payload: TokenPayload = Security(get_jwt, scopes=[UserRole.ADMIN, UserRole.AGENT, UserRole.USER]),
84-
) -> List[Sequence]:
93+
) -> List[SequenceRead]:
8594
telemetry_client.capture(token_payload.sub, event="alerts-sequences-get", properties={"alert_id": alert_id})
8695
alert = cast(Alert, await alerts.get(alert_id, strict=True))
8796
if UserRole.ADMIN not in token_payload.scopes:
@@ -93,7 +102,9 @@ async def fetch_alert_sequences(
93102
seq_stmt = seq_stmt.where(AlertSequence.alert_id == alert_id).order_by(order_clause).limit(limit)
94103

95104
res = await session.exec(seq_stmt)
96-
return list(res.all())
105+
sequences = list(res.all())
106+
detection_counts = await get_detection_counts_by_sequence_ids(session, [sequence.id for sequence in sequences])
107+
return [_serialize_sequence(sequence, detection_counts.get(sequence.id, 0)) for sequence in sequences]
97108

98109

99110
@router.get(
@@ -118,9 +129,13 @@ async def fetch_latest_unlabeled_alerts(
118129
)
119130
alerts_res = await session.exec(alerts_stmt)
120131
alerts = alerts_res.unique().all()
121-
alert_ids = [int(alert.id) for alert in alerts]
132+
alert_ids = [alert.id for alert in alerts]
122133
seq_map = await _fetch_sequences_by_alert_ids(session, alert_ids)
123-
return [_serialize_alert(alert, seq_map.get(int(alert.id), [])) for alert in alerts]
134+
detection_counts = await get_detection_counts_by_sequence_ids(
135+
session,
136+
list({sequence.id for sequences in seq_map.values() for sequence in sequences}),
137+
)
138+
return [_serialize_alert(alert, seq_map.get(alert.id, []), detection_counts) for alert in alerts]
124139

125140

126141
@router.get("/all/fromdate", status_code=status.HTTP_200_OK, summary="Fetch all the alerts for a specific date")
@@ -143,9 +158,13 @@ async def fetch_alerts_from_date(
143158
)
144159
alerts_res = await session.exec(alerts_stmt)
145160
alerts = alerts_res.all()
146-
alert_ids = [int(alert.id) for alert in alerts]
161+
alert_ids = [alert.id for alert in alerts]
147162
seq_map = await _fetch_sequences_by_alert_ids(session, alert_ids)
148-
return [_serialize_alert(alert, seq_map.get(int(alert.id), [])) for alert in alerts]
163+
detection_counts = await get_detection_counts_by_sequence_ids(
164+
session,
165+
list({sequence.id for sequences in seq_map.values() for sequence in sequences}),
166+
)
167+
return [_serialize_alert(alert, seq_map.get(alert.id, []), detection_counts) for alert in alerts]
149168

150169

151170
@router.delete("/{alert_id}", status_code=status.HTTP_200_OK, summary="Delete an alert")

src/app/api/api_v1/endpoints/detections.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,7 @@ async def fetch_detections(
554554
) -> List[DetectionRead]:
555555
telemetry_client.capture(token_payload.sub, event="detections-fetch")
556556
if UserRole.ADMIN in token_payload.scopes:
557-
return [DetectionRead(**elt.model_dump()) for elt in await detections.fetch_all()]
557+
return [DetectionRead(**elt.model_dump()) for elt in await detections.fetch_all(order_by="id")]
558558

559559
cameras_list = await cameras.fetch_all(filters=("organization_id", token_payload.organization_id))
560560
camera_ids = [camera.id for camera in cameras_list]

src/app/api/api_v1/endpoints/sequences.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from app.schemas.login import TokenPayload
2323
from app.schemas.sequences import SequenceLabel, SequenceRead
2424
from app.services.overlap import compute_overlap
25+
from app.services.sequence_counts import get_detection_counts_by_sequence_ids
2526
from app.services.storage import s3_service
2627
from app.services.telemetry import telemetry_client
2728

@@ -83,20 +84,26 @@ async def _refresh_alert_state(alert_id: int, session: AsyncSession, alerts: Ale
8384
)
8485

8586

87+
def _serialize_sequence(sequence: Sequence, detections_count: int = 0) -> SequenceRead:
88+
return SequenceRead(**sequence.model_dump(), detections_count=detections_count)
89+
90+
8691
@router.get("/{sequence_id}", status_code=status.HTTP_200_OK, summary="Fetch the information of a specific sequence")
8792
async def get_sequence(
8893
sequence_id: int = Path(..., gt=0),
8994
cameras: CameraCRUD = Depends(get_camera_crud),
9095
sequences: SequenceCRUD = Depends(get_sequence_crud),
96+
session: AsyncSession = Depends(get_session),
9197
token_payload: TokenPayload = Security(get_jwt, scopes=[UserRole.ADMIN, UserRole.AGENT, UserRole.USER]),
92-
) -> Sequence:
98+
) -> SequenceRead:
9399
telemetry_client.capture(token_payload.sub, event="sequences-get", properties={"sequence_id": sequence_id})
94100
sequence = cast(Sequence, await sequences.get(sequence_id, strict=True))
95101

96102
if UserRole.ADMIN not in token_payload.scopes:
97103
await verify_org_rights(token_payload.organization_id, sequence.camera_id, cameras)
98104

99-
return SequenceRead(**sequence.model_dump())
105+
counts = await get_detection_counts_by_sequence_ids(session, [sequence.id])
106+
return _serialize_sequence(sequence, counts.get(sequence.id, 0))
100107

101108

102109
@router.get(
@@ -155,7 +162,8 @@ async def fetch_latest_unlabeled_sequences(
155162
.limit(15)
156163
)
157164
).all()
158-
return [SequenceRead(**elt.model_dump()) for elt in fetched_sequences]
165+
counts = await get_detection_counts_by_sequence_ids(session, [sequence.id for sequence in fetched_sequences])
166+
return [_serialize_sequence(sequence, counts.get(sequence.id, 0)) for sequence in fetched_sequences]
159167

160168

161169
@router.get("/all/fromdate", status_code=status.HTTP_200_OK, summary="Fetch all the sequences for a specific date")
@@ -180,7 +188,8 @@ async def fetch_sequences_from_date(
180188
.offset(offset)
181189
)
182190
).all()
183-
return [SequenceRead(**elt.model_dump()) for elt in fetched_sequences]
191+
counts = await get_detection_counts_by_sequence_ids(session, [sequence.id for sequence in fetched_sequences])
192+
return [_serialize_sequence(sequence, counts.get(sequence.id, 0)) for sequence in fetched_sequences]
184193

185194

186195
@router.delete("/{sequence_id}", status_code=status.HTTP_200_OK, summary="Delete a sequence")

src/app/schemas/sequences.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,4 @@ class SequenceLabel(BaseModel):
2222

2323

2424
class SequenceRead(Sequence):
25-
pass
25+
detections_count: int = 0
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Copyright (C) 2025-2026, Pyronear.
2+
3+
# This program is licensed under the Apache License 2.0.
4+
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.
5+
6+
7+
from typing import Any, Dict, List, cast
8+
9+
from sqlmodel import func, select
10+
from sqlmodel.ext.asyncio.session import AsyncSession
11+
12+
from app.models import Detection
13+
14+
15+
async def get_detection_counts_by_sequence_ids(session: AsyncSession, sequence_ids: List[int]) -> Dict[int, int]:
16+
if not sequence_ids:
17+
return {}
18+
19+
stmt: Any = (
20+
select(cast(Any, Detection.sequence_id), func.count(Detection.id))
21+
.where(cast(Any, Detection.sequence_id).in_(sequence_ids))
22+
.group_by(cast(Any, Detection.sequence_id))
23+
)
24+
res = await session.exec(stmt)
25+
return {sequence_id: detections_count for sequence_id, detections_count in res.all() if sequence_id is not None}

src/tests/endpoints/test_alerts.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,18 @@
1414

1515
from app.core.config import settings
1616
from app.core.time import utcnow
17-
from app.models import Alert, AlertSequence, AnnotationType, Camera, Organization, Pose, Sequence
17+
from app.models import Alert, AlertSequence, AnnotationType, Camera, Detection, Organization, Pose, Sequence
1818
from app.services.overlap import compute_overlap
1919

2020

2121
async def _create_alert_with_sequences(
2222
session: AsyncSession, org_id: int, camera_id: int, lat: float, lon: float
23-
) -> Tuple[Alert, List[int]]:
23+
) -> Tuple[Alert, List[int], List[int]]:
2424
now = utcnow()
25+
pose = (
26+
await session.exec(select(Pose).where(Pose.camera_id == camera_id).order_by(Pose.id)) # type: ignore[attr-defined]
27+
).first()
28+
assert pose is not None
2529
seq_payloads = [
2630
{
2731
"camera_id": camera_id,
@@ -48,6 +52,7 @@ async def _create_alert_with_sequences(
4852
"cone_angle": 3.0,
4953
},
5054
]
55+
detections_count_by_sequence = [2, 1, 0]
5156
sequences: List[Sequence] = []
5257
for idx, payload in enumerate(seq_payloads):
5358
seq = Sequence(
@@ -60,6 +65,20 @@ async def _create_alert_with_sequences(
6065
await session.commit()
6166
for seq in sequences:
6267
await session.refresh(seq)
68+
for sequence, detections_count in zip(sequences, detections_count_by_sequence, strict=True):
69+
for det_idx in range(detections_count):
70+
session.add(
71+
Detection(
72+
camera_id=sequence.camera_id,
73+
pose_id=pose.id,
74+
sequence_id=sequence.id,
75+
bucket_key=f"alert-seq-{sequence.id}-{det_idx}.jpg",
76+
bbox="[(.1,.1,.7,.8,.9)]",
77+
others_bboxes=None,
78+
created_at=now - timedelta(seconds=det_idx),
79+
)
80+
)
81+
await session.commit()
6382

6483
alert = Alert(
6584
organization_id=org_id,
@@ -75,14 +94,15 @@ async def _create_alert_with_sequences(
7594
for seq in sequences:
7695
session.add(AlertSequence(alert_id=alert.id, sequence_id=seq.id))
7796
await session.commit()
78-
return alert, [seq.id for seq in sequences]
97+
return alert, [seq.id for seq in sequences], detections_count_by_sequence
7998

8099

81100
@pytest.mark.asyncio
82101
async def test_get_alert_and_sequences(async_client: AsyncClient, detection_session: AsyncSession):
83-
alert, seq_ids = await _create_alert_with_sequences(
102+
alert, seq_ids, detections_count_by_sequence = await _create_alert_with_sequences(
84103
detection_session, org_id=1, camera_id=1, lat=48.3856355, lon=2.7323256
85104
)
105+
expected_counts = dict(zip(seq_ids, detections_count_by_sequence, strict=False))
86106

87107
auth = pytest.get_token(
88108
pytest.user_table[0]["id"], pytest.user_table[0]["role"].split(), pytest.user_table[0]["organization_id"]
@@ -97,19 +117,23 @@ async def test_get_alert_and_sequences(async_client: AsyncClient, detection_sess
97117
assert payload["started_at"] == alert.started_at.isoformat()
98118
assert payload["last_seen_at"] == alert.last_seen_at.isoformat()
99119
assert {seq["id"] for seq in payload["sequences"]} == set(seq_ids)
120+
assert {seq["id"]: seq["detections_count"] for seq in payload["sequences"]} == expected_counts
100121

101122
resp = await async_client.get(f"/alerts/{alert.id}/sequences?limit=5&desc=true", headers=auth)
102123
assert resp.status_code == 200, resp.text
103124
returned = resp.json()
104125
last_seen_times = [item["last_seen_at"] for item in returned]
105126
assert last_seen_times == sorted(last_seen_times, reverse=True)
127+
assert {sequence["id"]: sequence["detections_count"] for sequence in returned} == expected_counts
128+
assert any(sequence["detections_count"] == 0 for sequence in returned)
106129

107130

108131
@pytest.mark.asyncio
109132
async def test_alerts_unlabeled_latest(async_client: AsyncClient, detection_session: AsyncSession):
110-
alert, seq_ids = await _create_alert_with_sequences(
133+
alert, seq_ids, detections_count_by_sequence = await _create_alert_with_sequences(
111134
detection_session, org_id=1, camera_id=1, lat=48.3856355, lon=2.7323256
112135
)
136+
expected_counts = dict(zip(seq_ids, detections_count_by_sequence, strict=False))
113137

114138
auth = pytest.get_token(
115139
pytest.user_table[0]["id"], pytest.user_table[0]["role"].split(), pytest.user_table[0]["organization_id"]
@@ -124,13 +148,16 @@ async def test_alerts_unlabeled_latest(async_client: AsyncClient, detection_sess
124148
assert returned["started_at"] == alert.started_at.isoformat()
125149
assert returned["last_seen_at"] == alert.last_seen_at.isoformat()
126150
assert {seq["id"] for seq in returned["sequences"]} == set(seq_ids)
151+
assert {seq["id"]: seq["detections_count"] for seq in returned["sequences"]} == expected_counts
152+
assert any(seq["detections_count"] == 0 for seq in returned["sequences"])
127153

128154

129155
@pytest.mark.asyncio
130156
async def test_alerts_from_date(async_client: AsyncClient, detection_session: AsyncSession):
131-
alert, seq_ids = await _create_alert_with_sequences(
157+
alert, seq_ids, detections_count_by_sequence = await _create_alert_with_sequences(
132158
detection_session, org_id=1, camera_id=1, lat=48.3856355, lon=2.7323256
133159
)
160+
expected_counts = dict(zip(seq_ids, detections_count_by_sequence, strict=False))
134161
date_str = alert.started_at.date().isoformat()
135162

136163
auth = pytest.get_token(
@@ -146,6 +173,8 @@ async def test_alerts_from_date(async_client: AsyncClient, detection_session: As
146173
assert started_times == sorted(started_times, reverse=True)
147174
alert_payload = next(item for item in returned if item["id"] == alert.id)
148175
assert {seq["id"] for seq in alert_payload["sequences"]} == set(seq_ids)
176+
assert {seq["id"]: seq["detections_count"] for seq in alert_payload["sequences"]} == expected_counts
177+
assert any(seq["detections_count"] == 0 for seq in alert_payload["sequences"])
149178

150179

151180
@pytest.mark.asyncio

0 commit comments

Comments
 (0)