Skip to content

Commit b4fa415

Browse files
feat: filter alerts/sequences/slack by daily fire risk score (#583)
* feat: add risk-api client and sequence confidence helper
* feat: schedule daily risk-api refresh in app lifespan
* feat: filter alerts and sequences by risk-driven confidence threshold
* feat: skip slack alert when sequence max conf below risk threshold
* feat: query risk-api per date for from_date endpoints
* test: cover risk-driven filtering of alerts and sequences
* chore: apply ruff fixes
* fix: clamp risk-refresh hour to 0..23 to avoid retry loop
* feat: scope from_date risk lookup to caller organization, normalize fwi class casing
* chore: rename risk-api env vars to RISK_API_URL/LOGIN/PWD
* feat: add max_conf column to sequences with backfill migration
* feat: maintain sequence max_conf at ingest with atomic update
* refactor: read max_conf from sequence row instead of parsing detections
* test: seed max_conf directly on test sequences
* fix: portable max_conf bump and validate fwi thresholds in [0,1]
* feat: add risk_score query param to override fwi class on alerts endpoints
* refactor: collapse risk filter helpers into one and use literal type for risk_score
* fix: restore refresh() cache-replace on empty list and harden payload parser
* feat: push risk filter into SQL WHERE for exact pagination
* refactor: replace fwi conf settings with FWI_MIN_CONF dict in risk module
* test: pagination on /sequences/all/fromdate keeps page full when filter applies
* fix: drop max_conf clause collapse and route pagination test through override
* fix: compute sequence max_conf from primary bbox only, ignore sibling detections
* chore: apply ruff format
* fix: silence mypy on case() and chained where() over join()
* fix: annotate case() result as Any to satisfy mypy and reformat with ruff
* test: cover _seconds_until_next_utc_hour and risk_score override on /sequences/*
* chore: address risk filter review comments
* fix: satisfy mypy for risk filter queries
* test: cover risk refresh lifecycle
* test: cover sequence risk filter endpoints
* test: cover risk service http paths
* test: cover alerts/fromdate risk filter and mixed-seq alert
* test: parametrize keep-all assertion across moderate/high/very_high/extreme
* test: pin fail-open on null max_conf and unknown cameras
* test: parametrize alerts no-filter override across moderate/high/very_high/extreme
1 parent 8917c95 commit b4fa415

19 files changed

Lines changed: 1539 additions & 49 deletions

.env.example

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@ POSTHOG_KEY=
2323
SUPPORT_EMAIL=
2424
TELEGRAM_TOKEN=
2525

26+
# Risk API (daily fire-weather index per camera)
27+
RISK_API_URL=
28+
RISK_API_LOGIN=
29+
RISK_API_PWD=
30+
RISK_REFRESH_HOUR_UTC=4
31+
2632
# Production-only
2733
ACME_EMAIL=
2834
BACKEND_HOST=

docker-compose.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ services:
6565
- S3_PROXY_URL=${S3_PROXY_URL}
6666
- SERVER_NAME=${SERVER_NAME}
6767
- PLATFORM_URL=${PLATFORM_URL:-https://platform.pyronear.org}
68+
- RISK_API_URL=${RISK_API_URL}
69+
- RISK_API_LOGIN=${RISK_API_LOGIN}
70+
- RISK_API_PWD=${RISK_API_PWD}
71+
- RISK_REFRESH_HOUR_UTC=${RISK_REFRESH_HOUR_UTC:-4}
6872
volumes:
6973
- ./src/:/app/
7074
command: "sh -c 'alembic upgrade head && python app/db.py && uvicorn app.main:app --reload --host 0.0.0.0 --port 5050 --proxy-headers'"

src/app/api/api_v1/endpoints/alerts.py

Lines changed: 70 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,20 @@
99

1010
from fastapi import APIRouter, Depends, HTTPException, Path, Query, Security, status
1111
from sqlalchemy import asc, desc
12+
from sqlalchemy.sql import ColumnElement
1213
from sqlmodel import delete, func, select
1314
from sqlmodel.ext.asyncio.session import AsyncSession
1415

1516
from app.api.dependencies import get_alert_crud, get_jwt
1617
from app.core.time import utcnow
1718
from app.crud import AlertCRUD
1819
from app.db import get_session
19-
from app.models import Alert, AlertSequence, Sequence, UserRole
20+
from app.models import Alert, AlertSequence, Camera, Sequence, UserRole
2021
from app.schemas.alerts import AlertReadWithSequences
2122
from app.schemas.login import TokenPayload
2223
from app.schemas.sequences import SequenceRead
24+
from app.services.risk import FwiClass, risk_service
25+
from app.services.sequence_confidence import max_conf_filter_clause
2326
from app.services.sequence_counts import get_detection_counts_by_sequence_ids
2427
from app.services.telemetry import telemetry_client
2528

@@ -31,22 +34,44 @@ def verify_org_rights(organization_id: int, alert: Alert) -> None:
3134
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access forbidden.")
3235

3336

34-
async def _fetch_sequences_by_alert_ids(session: AsyncSession, alert_ids: List[int]) -> Dict[int, List[Sequence]]:
37+
async def _fetch_sequences_by_alert_ids(
38+
session: AsyncSession,
39+
alert_ids: List[int],
40+
seq_filter: Union[ColumnElement[bool], None] = None,
41+
) -> Dict[int, List[Sequence]]:
3542
mapping: Dict[int, List[Sequence]] = {}
3643
if not alert_ids:
3744
return mapping
3845
seq_stmt: Any = (
3946
select(AlertSequence.alert_id, Sequence)
4047
.join(Sequence, cast(Any, Sequence.id == AlertSequence.sequence_id))
4148
.where(AlertSequence.alert_id.in_(alert_ids)) # type: ignore[attr-defined]
42-
.order_by(cast(Any, AlertSequence.alert_id), desc(cast(Any, Sequence.last_seen_at)))
4349
)
50+
if seq_filter is not None:
51+
seq_stmt = seq_stmt.where(seq_filter)
52+
seq_stmt = seq_stmt.order_by(cast(Any, AlertSequence.alert_id), desc(cast(Any, Sequence.last_seen_at)))
4453
res = await session.exec(seq_stmt)
4554
for alert_id, sequence in res.all():
4655
mapping.setdefault(int(alert_id), []).append(sequence)
4756
return mapping
4857

4958

59+
async def _resolve_fwi_class_per_camera(
    session: AsyncSession,
    organization_id: int,
    target_date: Union[date, None] = None,
    override_class: Union[str, None] = None,
) -> Dict[int, Union[str, None]]:
    """Build the ``{camera_id: fwi_class}`` mapping used to filter sequences.

    Sources are tried in this order: explicit override, per-date lookup, today's cache.

    Args:
        session: database session; only queried when ``override_class`` is set
        organization_id: selects the cameras for the override and is forwarded to the per-date lookup
        target_date: when set (and no override is given), scores are fetched for that date
        override_class: when set, this class is assigned to every camera of the organization

    Returns:
        a new dict mapping camera IDs to their FWI class
    """
    if override_class is None:
        if target_date is None:
            # No override and no date: fall back to the scores the risk service currently holds.
            return dict(risk_service.scores())
        dated_scores = await risk_service.get_scores_for_date(target_date, organization_id=organization_id)
        return dict(dated_scores)

    # Override requested: every camera of the organization gets the same class.
    org_cameras_stmt = select(Camera.id).where(Camera.organization_id == organization_id)
    org_camera_ids = (await session.exec(org_cameras_stmt)).all()
    return dict.fromkeys(org_camera_ids, override_class)
73+
74+
5075
def _serialize_sequence(sequence: Sequence, detections_count: int = 0) -> SequenceRead:
    """Convert a ``Sequence`` row into a ``SequenceRead`` that also carries ``detections_count``."""
    payload = sequence.model_dump()
    return SequenceRead(**payload, detections_count=detections_count)
5277

@@ -113,24 +138,39 @@ async def fetch_alert_sequences(
113138
summary="Fetch all the alerts with unlabeled sequences from the last 24 hours",
114139
)
115140
async def fetch_latest_unlabeled_alerts(
141+
risk_score: Union[FwiClass, None] = Query(
142+
None, description="Override FWI class applied to every sequence; bypasses risk-api lookup."
143+
),
116144
session: AsyncSession = Depends(get_session),
117145
token_payload: TokenPayload = Security(get_jwt, scopes=[UserRole.ADMIN, UserRole.AGENT, UserRole.USER]),
118146
) -> List[AlertReadWithSequences]:
119147
telemetry_client.capture(token_payload.sub, event="alerts-fetch-latest")
120148

121-
alerts_stmt: Any = select(Alert).join(AlertSequence, cast(Any, AlertSequence.alert_id == Alert.id))
122-
alerts_stmt = alerts_stmt.join(Sequence, cast(Any, Sequence.id == AlertSequence.sequence_id))
123-
alerts_stmt = (
124-
alerts_stmt.where(Alert.organization_id == token_payload.organization_id)
125-
.where(Sequence.last_seen_at > utcnow() - timedelta(hours=24))
126-
.where(Sequence.is_wildfire.is_(None)) # type: ignore[union-attr]
149+
fwi_classes_by_camera = await _resolve_fwi_class_per_camera(
150+
session, token_payload.organization_id, override_class=risk_score
151+
)
152+
seq_filter = max_conf_filter_clause(fwi_classes_by_camera)
153+
154+
seq_match: Any = cast(
155+
Any,
156+
select(AlertSequence.alert_id).join(Sequence, cast(Any, Sequence.id == AlertSequence.sequence_id)),
157+
)
158+
seq_match = (
159+
seq_match.where(Sequence.last_seen_at > utcnow() - timedelta(hours=24)).where(Sequence.is_wildfire.is_(None)) # type: ignore[union-attr]
160+
)
161+
if seq_filter is not None:
162+
seq_match = seq_match.where(seq_filter)
163+
164+
alerts_stmt: Any = (
165+
select(Alert)
166+
.where(Alert.organization_id == token_payload.organization_id)
167+
.where(cast(Any, Alert.id).in_(seq_match))
127168
.order_by(Alert.started_at.desc()) # type: ignore[attr-defined]
128169
.limit(15)
129170
)
130-
alerts_res = await session.exec(alerts_stmt)
131-
alerts = alerts_res.unique().all()
171+
alerts = list((await session.exec(alerts_stmt)).all())
132172
alert_ids = [alert.id for alert in alerts]
133-
seq_map = await _fetch_sequences_by_alert_ids(session, alert_ids)
173+
seq_map = await _fetch_sequences_by_alert_ids(session, alert_ids, seq_filter)
134174
detection_counts = await get_detection_counts_by_sequence_ids(
135175
session,
136176
list({sequence.id for sequences in seq_map.values() for sequence in sequences}),
@@ -143,23 +183,35 @@ async def fetch_alerts_from_date(
143183
from_date: date = Query(),
144184
limit: Union[int, None] = Query(15, description="Maximum number of alerts to fetch"),
145185
offset: Union[int, None] = Query(0, description="Number of alerts to skip before starting to fetch"),
186+
risk_score: Union[FwiClass, None] = Query(
187+
None, description="Override FWI class applied to every sequence; bypasses risk-api lookup."
188+
),
146189
session: AsyncSession = Depends(get_session),
147190
token_payload: TokenPayload = Security(get_jwt, scopes=[UserRole.ADMIN, UserRole.AGENT, UserRole.USER]),
148191
) -> List[AlertReadWithSequences]:
149192
telemetry_client.capture(token_payload.sub, event="alerts-fetch-from-date")
150193

194+
fwi_classes_by_camera = await _resolve_fwi_class_per_camera(
195+
session, token_payload.organization_id, target_date=from_date, override_class=risk_score
196+
)
197+
seq_filter = max_conf_filter_clause(fwi_classes_by_camera)
198+
151199
alerts_stmt: Any = (
152200
select(Alert)
153201
.where(Alert.organization_id == token_payload.organization_id)
154202
.where(func.date(Alert.started_at) == from_date)
155-
.order_by(Alert.started_at.desc()) # type: ignore[attr-defined]
156-
.limit(limit)
157-
.offset(offset)
158203
)
159-
alerts_res = await session.exec(alerts_stmt)
160-
alerts = alerts_res.all()
204+
if seq_filter is not None:
205+
seq_match: Any = select(AlertSequence.alert_id).join(
206+
Sequence, cast(Any, Sequence.id == AlertSequence.sequence_id)
207+
)
208+
seq_match = seq_match.where(seq_filter)
209+
alerts_stmt = alerts_stmt.where(cast(Any, Alert.id).in_(seq_match))
210+
alerts_stmt = alerts_stmt.order_by(Alert.started_at.desc()).limit(limit).offset(offset) # type: ignore[attr-defined]
211+
212+
alerts = list((await session.exec(alerts_stmt)).all())
161213
alert_ids = [alert.id for alert in alerts]
162-
seq_map = await _fetch_sequences_by_alert_ids(session, alert_ids)
214+
seq_map = await _fetch_sequences_by_alert_ids(session, alert_ids, seq_filter)
163215
detection_counts = await get_detection_counts_by_sequence_ids(
164216
session,
165217
list({sequence.id for sequences in seq_map.values() for sequence in sequences}),

src/app/api/api_v1/endpoints/detections.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66

77
import json
8+
import logging
89
import re
910
from ast import literal_eval
1011
from datetime import datetime, timedelta
@@ -56,11 +57,15 @@
5657
from app.schemas.sequences import SequenceUpdate
5758
from app.services.cones import resolve_cone
5859
from app.services.overlap import compute_overlap, haversine_km
60+
from app.services.risk import risk_service
61+
from app.services.sequence_confidence import max_conf_from_bboxes
5962
from app.services.slack import slack_client
6063
from app.services.storage import s3_service, upload_file
6164
from app.services.telegram import telegram_client
6265
from app.services.telemetry import telemetry_client
6366

67+
logger = logging.getLogger("uvicorn.error")
68+
6469
router = APIRouter()
6570

6671

@@ -427,6 +432,10 @@ async def create_detection(
427432
if matched_sequence is not None:
428433
await sequences.update(matched_sequence.id, SequenceUpdate(last_seen_at=det.created_at))
429434
det = await detections.update(det.id, DetectionSequence(sequence_id=matched_sequence.id))
435+
# Only the primary bbox tracks the sequence; siblings in others_bboxes are unrelated detections.
436+
det_max_conf = max_conf_from_bboxes(det.bbox)
437+
if det_max_conf is not None:
438+
await sequences.bump_max_conf(matched_sequence.id, det_max_conf)
430439
else:
431440
det_filters: List[tuple[str, Any]] = [
432441
("camera_id", token_payload.sub),
@@ -455,6 +464,7 @@ async def create_detection(
455464
if len(overlapping_dets) >= settings.SEQUENCE_MIN_INTERVAL_DETS:
456465
first_det = min(overlapping_dets, key=lambda item: item.created_at)
457466
cone_azimuth, cone_angle = resolve_cone(pose.azimuth, first_det.bbox, camera.angle_of_view)
467+
seq_max_conf = max_conf_from_bboxes(*[d.bbox for d in overlapping_dets])
458468
sequence_ = await sequences.create(
459469
Sequence(
460470
camera_id=token_payload.sub,
@@ -464,6 +474,7 @@ async def create_detection(
464474
cone_angle=cone_angle,
465475
started_at=first_det.created_at,
466476
last_seen_at=det.created_at,
477+
max_conf=seq_max_conf,
467478
)
468479
)
469480
for det_ in overlapping_dets:
@@ -490,11 +501,20 @@ async def create_detection(
490501
if org is None:
491502
org = cast(Organization, await organizations.get(token_payload.organization_id, strict=True))
492503
if org.slack_hook:
493-
slack_payload = jsonable_encoder(det)
494-
slack_payload["sequence_azimuth"] = sequence_.sequence_azimuth
495-
background_tasks.add_task(
496-
slack_client.notify, org.slack_hook, json.dumps(slack_payload), camera.name, alert_id
497-
)
504+
min_conf = risk_service.min_confidence(camera.id)
505+
if min_conf is None or sequence_.max_conf is None or sequence_.max_conf >= min_conf:
506+
slack_payload = jsonable_encoder(det)
507+
slack_payload["sequence_azimuth"] = sequence_.sequence_azimuth
508+
background_tasks.add_task(
509+
slack_client.notify, org.slack_hook, json.dumps(slack_payload), camera.name, alert_id
510+
)
511+
else:
512+
logger.info(
513+
"Skipping Slack notification for camera %s: max conf %.3f < threshold %.3f",
514+
camera.name,
515+
sequence_.max_conf,
516+
min_conf,
517+
)
498518

499519
created.append(det)
500520

src/app/api/api_v1/endpoints/sequences.py

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from app.schemas.login import TokenPayload
2323
from app.schemas.sequences import SequenceLabel, SequenceRead
2424
from app.services.overlap import compute_overlap
25+
from app.services.risk import FwiClass, risk_service
26+
from app.services.sequence_confidence import max_conf_filter_clause
2527
from app.services.sequence_counts import get_detection_counts_by_sequence_ids
2628
from app.services.storage import s3_service
2729
from app.services.telemetry import telemetry_client
@@ -146,22 +148,32 @@ async def fetch_sequence_detections(
146148
summary="Fetch all the unlabeled sequences from the last 24 hours",
147149
)
148150
async def fetch_latest_unlabeled_sequences(
151+
risk_score: Union[FwiClass, None] = Query(
152+
None, description="Override FWI class applied to every sequence; bypasses risk-api lookup."
153+
),
149154
session: AsyncSession = Depends(get_session),
150155
token_payload: TokenPayload = Security(get_jwt, scopes=[UserRole.ADMIN, UserRole.AGENT, UserRole.USER]),
151156
) -> List[SequenceRead]:
152157
telemetry_client.capture(token_payload.sub, event="sequence-fetch-latest")
153-
camera_ids = await session.exec(select(Camera.id).where(Camera.organization_id == token_payload.organization_id))
154-
155-
fetched_sequences = (
156-
await session.exec(
157-
select(Sequence)
158-
.where(Sequence.started_at > utcnow() - timedelta(hours=24))
159-
.where(Sequence.camera_id.in_(camera_ids.all())) # type: ignore[attr-defined]
160-
.where(Sequence.is_wildfire.is_(None)) # type: ignore[union-attr]
161-
.order_by(Sequence.started_at.desc()) # type: ignore[attr-defined]
162-
.limit(15)
163-
)
158+
camera_ids = (
159+
await session.exec(select(Camera.id).where(Camera.organization_id == token_payload.organization_id))
164160
).all()
161+
classes: dict[int, Union[str, None]] = (
162+
dict.fromkeys(camera_ids, risk_score) if risk_score is not None else dict(risk_service.scores())
163+
)
164+
165+
stmt: Any = (
166+
select(Sequence)
167+
.where(Sequence.started_at > utcnow() - timedelta(hours=24))
168+
.where(Sequence.camera_id.in_(camera_ids)) # type: ignore[attr-defined]
169+
.where(Sequence.is_wildfire.is_(None)) # type: ignore[union-attr]
170+
)
171+
seq_filter = max_conf_filter_clause(classes)
172+
if seq_filter is not None:
173+
stmt = stmt.where(seq_filter)
174+
stmt = stmt.order_by(Sequence.started_at.desc()).limit(15) # type: ignore[attr-defined]
175+
176+
fetched_sequences = (await session.exec(stmt)).all()
165177
counts = await get_detection_counts_by_sequence_ids(session, [sequence.id for sequence in fetched_sequences])
166178
return [_serialize_sequence(sequence, counts.get(sequence.id, 0)) for sequence in fetched_sequences]
167179

@@ -171,23 +183,32 @@ async def fetch_sequences_from_date(
171183
from_date: date = Query(),
172184
limit: Union[int, None] = Query(15, description="Maximum number of sequences to fetch"),
173185
offset: Union[int, None] = Query(0, description="Number of sequences to skip before starting to fetch"),
186+
risk_score: Union[FwiClass, None] = Query(
187+
None, description="Override FWI class applied to every sequence; bypasses risk-api lookup."
188+
),
174189
session: AsyncSession = Depends(get_session),
175190
token_payload: TokenPayload = Security(get_jwt, scopes=[UserRole.ADMIN, UserRole.AGENT, UserRole.USER]),
176191
) -> List[SequenceRead]:
177192
telemetry_client.capture(token_payload.sub, event="sequence-fetch-from-date")
178193
# Limit to cameras in the same organization
179-
camera_ids = await session.exec(select(Camera.id).where(Camera.organization_id == token_payload.organization_id))
180-
# Identify the sequences from that day
181-
fetched_sequences = (
182-
await session.exec(
183-
select(Sequence)
184-
.where(func.date(Sequence.started_at) == from_date)
185-
.where(Sequence.camera_id.in_(camera_ids.all())) # type: ignore[attr-defined]
186-
.order_by(Sequence.started_at.desc()) # type: ignore[attr-defined]
187-
.limit(limit)
188-
.offset(offset)
189-
)
194+
camera_ids = (
195+
await session.exec(select(Camera.id).where(Camera.organization_id == token_payload.organization_id))
190196
).all()
197+
if risk_score is not None:
198+
classes: dict[int, Union[str, None]] = dict.fromkeys(camera_ids, risk_score)
199+
else:
200+
scores = await risk_service.get_scores_for_date(from_date, organization_id=token_payload.organization_id)
201+
classes = dict(scores)
202+
203+
stmt: Any = (
204+
select(Sequence).where(func.date(Sequence.started_at) == from_date).where(Sequence.camera_id.in_(camera_ids)) # type: ignore[attr-defined]
205+
)
206+
seq_filter = max_conf_filter_clause(classes)
207+
if seq_filter is not None:
208+
stmt = stmt.where(seq_filter)
209+
stmt = stmt.order_by(Sequence.started_at.desc()).limit(limit).offset(offset) # type: ignore[attr-defined]
210+
211+
fetched_sequences = (await session.exec(stmt)).all()
191212
counts = await get_detection_counts_by_sequence_ids(session, [sequence.id for sequence in fetched_sequences])
192213
return [_serialize_sequence(sequence, counts.get(sequence.id, 0)) for sequence in fetched_sequences]
193214

src/app/core/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,12 @@ def sqlachmey_uri(cls, v: str) -> str:
7777
TELEGRAM_TOKEN: Union[str, None] = os.environ.get("TELEGRAM_TOKEN")
7878
PLATFORM_URL: str = os.environ.get("PLATFORM_URL", "")
7979

80+
# Risk API (daily fire-weather index per camera)
81+
RISK_API_URL: Union[str, None] = os.environ.get("RISK_API_URL")
82+
RISK_API_LOGIN: Union[str, None] = os.environ.get("RISK_API_LOGIN")
83+
RISK_API_PWD: Union[str, None] = os.environ.get("RISK_API_PWD")
84+
RISK_REFRESH_HOUR_UTC: int = int(os.environ.get("RISK_REFRESH_HOUR_UTC") or 4)
85+
8086
# Error monitoring
8187
SENTRY_DSN: Union[str, None] = os.environ.get("SENTRY_DSN")
8288
SERVER_NAME: str = os.environ.get("SERVER_NAME", socket.gethostname())

src/app/crud/crud_sequence.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.
55

66

7-
from typing import Union
7+
from typing import Any, Union, cast
88

9+
from sqlalchemy import case, or_, update
910
from sqlmodel.ext.asyncio.session import AsyncSession
1011

1112
from app.crud.base import BaseCRUD
@@ -18,3 +19,17 @@
1819
class SequenceCRUD(BaseCRUD[Sequence, Sequence, Union[SequenceUpdate, SequenceLabel]]):
    """CRUD helper bound to the ``Sequence`` model."""

    def __init__(self, session: AsyncSession) -> None:
        super().__init__(session, Sequence)

    async def bump_max_conf(self, sequence_id: int, candidate: float) -> None:
        """Raise ``max_conf`` of one sequence to ``candidate`` when that is an increase.

        The comparison is done inside a single UPDATE through a CASE expression, so there
        is no separate read step; CASE is used so the statement runs on SQLite as well as
        Postgres. A NULL ``max_conf`` is replaced by ``candidate``. The session is committed
        before returning.

        Args:
            sequence_id: primary key of the sequence to update
            candidate: confidence value compared against the stored one
        """
        stored: Any = cast(Any, Sequence.max_conf)
        # Replace when nothing is stored yet or when the stored value is lower.
        should_replace = or_(stored.is_(None), stored < candidate)
        new_value: Any = cast(Any, case)((should_replace, candidate), else_=stored)
        target_row = cast(Any, Sequence.id) == sequence_id
        stmt: Any = update(Sequence).where(target_row).values(max_conf=new_value)
        await self.session.exec(stmt)
        await self.session.commit()

0 commit comments

Comments
 (0)