Skip to content

Commit ee9d007

Browse files
Adityav369claude
andauthored
Fix: re-probe unhealthy ColPali endpoints after cooldown (#398)
Previously, when an endpoint failed once it was discarded from healthy_endpoints and only re-added if ALL endpoints failed at the same time. A single transient OOM on one of N endpoints could silently halve sustained ingestion throughput until the worker process restarted. Now each endpoint is timestamped when marked unhealthy and re-included after a 60s cooldown. If it's still failing, the existing failure path will mark it unhealthy again — no change to error semantics. Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 6fdfe25 commit ee9d007

1 file changed

Lines changed: 22 additions & 0 deletions

File tree

core/embedding/colpali_api_embedding_model.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,27 @@ def __init__(self):
5151

5252
# Track endpoint health for failover
5353
self.healthy_endpoints: set[str] = set(self.endpoints)
54+
self._endpoint_unhealthy_since: Dict[str, float] = {}
55+
self._unhealthy_recovery_seconds: float = 60.0
5456
self._endpoint_latencies: dict[str, float] = {}
5557
self.endpoint = self.endpoints[0]
5658
self._latest_ingest_metrics: Dict[str, float] = {}
5759

60+
def _recover_endpoints(self) -> None:
61+
"""Re-include endpoints whose unhealthy cooldown has elapsed."""
62+
if not self._endpoint_unhealthy_since:
63+
return
64+
now = time.monotonic()
65+
recovered = [
66+
ep
67+
for ep, marked_at in self._endpoint_unhealthy_since.items()
68+
if now - marked_at >= self._unhealthy_recovery_seconds
69+
]
70+
for ep in recovered:
71+
self.healthy_endpoints.add(ep)
72+
self._endpoint_unhealthy_since.pop(ep, None)
73+
logger.info("Re-probing previously unhealthy ColPali endpoint: %s", ep)
74+
5875
async def embed_for_ingestion(self, chunks: Union[Chunk, List[Chunk]]) -> List[MultiVector]:
5976
ingest_start = time.monotonic()
6077
# Normalize to list
@@ -131,6 +148,7 @@ async def _embed_inputs_distributed(
131148
if not indexed_inputs:
132149
return {}
133150

151+
self._recover_endpoints()
134152
# Use healthy endpoints, fall back to all if none healthy
135153
endpoints = list(self.healthy_endpoints) if self.healthy_endpoints else self.endpoints
136154
n_endpoints = len(endpoints)
@@ -166,6 +184,7 @@ async def _embed_inputs_distributed(
166184
elif isinstance(result, Exception):
167185
logger.warning(f"Endpoint {endpoint} failed: {result}")
168186
self.healthy_endpoints.discard(endpoint)
187+
self._endpoint_unhealthy_since[endpoint] = time.monotonic()
169188
failed_inputs.extend(batch)
170189
else:
171190
merged.update(result)
@@ -182,6 +201,7 @@ async def _embed_inputs_distributed(
182201
# All endpoints failed, reset health and raise
183202
logger.error("All ColPali endpoints failed, resetting health status")
184203
self.healthy_endpoints = set(self.endpoints)
204+
self._endpoint_unhealthy_since.clear()
185205
raise RuntimeError(
186206
f"All {len(self.endpoints)} ColPali endpoints failed for {len(failed_inputs)} {input_type} inputs"
187207
)
@@ -289,6 +309,7 @@ async def _call_api_endpoint(self, endpoint: str, inputs: List[str], input_type:
289309

290310
async def embed_for_query(self, text: str) -> MultiVector:
291311
# Use first healthy endpoint for queries (single text, fast)
312+
self._recover_endpoints()
292313
endpoint = next(iter(self.healthy_endpoints), self.endpoints[0])
293314
data = await self._call_api_endpoint(endpoint, [text], "text")
294315
if not data:
@@ -304,6 +325,7 @@ async def generate_embeddings(self, content: Union[str, Image]) -> np.ndarray:
304325
Returns:
305326
numpy array of embeddings.
306327
"""
328+
self._recover_endpoints()
307329
endpoint = next(iter(self.healthy_endpoints), self.endpoints[0])
308330

309331
if isinstance(content, Image):

0 commit comments

Comments
 (0)