Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 70 additions & 33 deletions client.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,87 @@
import requests
from typing import Dict, Any, List, Optional
import logging
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class InferenceClient:
"""Client for calling the inference service"""
"""Client for calling the inference service
Uses connection pooling for efficient HTTP requests.

def __init__(self, base_url: str, timeout: int = 30):
"""
Args:
base_url: Base URL of the inference service (e.g., "http://inference-service:8000")
timeout: Request timeout in seconds
"""
Parameters
----------
base_url : str
Base URL of the inference service (e.g., "http://inference-service:8000").
timeout : int, optional
Request timeout in seconds. Default is 30.
max_retries : int, optional
Maximum number of retries for failed requests. Default is 3.
pool_connections : int, optional
Number of connection pools to cache. Default is 10.
pool_maxsize : int, optional
Maximum number of connections to save in the pool. Default is 10.
"""

def __init__(
    self,
    base_url: str,
    timeout: int = 30,
    max_retries: int = 3,
    pool_connections: int = 10,
    pool_maxsize: int = 10
):
    """Set up the base URL, a pooled HTTP session, and the retry policy."""
    self.base_url = base_url.rstrip('/')
    self.timeout = timeout

    # One shared session so TCP connections are reused across requests.
    self.session = requests.Session()

    # Retry transient failures with exponential backoff (0.5s, 1s, 2s, ...).
    retry_policy = Retry(
        total=max_retries,
        backoff_factor=0.5,
        status_forcelist=[429, 500, 502, 503, 504],
        allowed_methods=["HEAD", "GET", "POST", "OPTIONS"],
    )

    # Single adapter carries both the retry policy and the connection pool.
    pooled_adapter = HTTPAdapter(
        max_retries=retry_policy,
        pool_connections=pool_connections,
        pool_maxsize=pool_maxsize,
    )
    for scheme in ("http://", "https://"):
        self.session.mount(scheme, pooled_adapter)

    # Defaults applied to every request made through this session.
    self.session.headers.update({
        'Connection': 'keep-alive',
        'Content-Type': 'application/json'
    })

    logger.info(f"Initialized client for {self.base_url} with connection pooling")

def close(self) -> None:
    """Deterministically release the session's pooled connections.

    Prefer calling this (or using the client as a context manager) over
    relying on garbage collection.
    """
    # hasattr guard: __init__ may have raised before creating the session.
    if hasattr(self, 'session'):
        self.session.close()

def __enter__(self):
    """Support `with InferenceClient(...) as client:` usage."""
    return self

def __exit__(self, exc_type, exc_val, exc_tb):
    # Always close on scope exit; do not suppress the caller's exception.
    self.close()
    return False

def __del__(self):
    """Best-effort cleanup; __del__ is not guaranteed to run — use close()."""
    try:
        self.close()
    except Exception:
        # During interpreter shutdown globals may already be torn down;
        # never let __del__ propagate an exception.
        pass


def health_check(self) -> bool:
"""Check if service is healthy"""
"""
Check if the inference service is healthy.

Returns
-------
bool
True if the service responds with status 200, False otherwise.
"""
try:
response = self.session.get(
f"{self.base_url}/health",
Expand Down Expand Up @@ -93,27 +153,4 @@ def predict_batch(self, inputs_list: List[Dict[str, float]]) -> Dict[str, Any]:
timeout=self.timeout * 2 # Longer timeout for batch
)
response.raise_for_status()
return response.json()

def load_model(self, model_name: str, model_version: Optional[str] = None) -> Dict[str, Any]:
    """
    Ask the inference service to load a different model.

    Args:
        model_name: Name of the model to load
        model_version: Version or stage (optional)

    Returns:
        Load response with status

    Raises:
        requests.HTTPError: If the service responds with an error status.
    """
    request_body: Dict[str, Any] = {"model_name": model_name}
    if model_version:
        request_body["model_version"] = model_version

    # Model loading can take much longer than a prediction, hence the
    # fixed 60-second timeout instead of self.timeout.
    response = self.session.post(
        f"{self.base_url}/model/load",
        json=request_body,
        timeout=60
    )
    response.raise_for_status()
    return response.json()
1 change: 1 addition & 0 deletions copier-template-k8s/copier.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# copier-template-k8s/copier.yml

_templates_suffix: .jinja
_answers_file: .copier-answers.yml

# Questions to ask when generating
service_name:
Expand Down
1 change: 1 addition & 0 deletions copier-template-k8s/{{_copier_conf.answers_file}}.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{{ _copier_answers|to_nice_yaml -}}
21 changes: 21 additions & 0 deletions inference_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ class ErrorResponse(BaseModel):
error: str
detail: Optional[str] = None

class VariableTypesResponse(BaseModel):
    """Response schema mapping variable names to their Python class names."""
    # name -> type name for each model input variable (var.__class__.__name__)
    input_types: Dict[str, str]
    # name -> type name for each model output variable
    output_types: Dict[str, str]




def download_model_artifacts(model_name: str, model_version: Optional[str] = None) -> tuple[str, str]:
"""
Expand Down Expand Up @@ -452,6 +458,21 @@ async def get_model_inputs(model: TorchModel = Depends(get_model)):
input_variables=input_variables
)

@app.get("/inputs/types", response_model=VariableTypesResponse)
async def get_variable_types(model: TorchModel = Depends(get_model)):
    """
    Get the types of all input and output variables.

    Returns a mapping from each variable's name to its Python class name,
    for both the model's inputs and its outputs.
    """
    # Dict comprehensions replace the manual append-style loops.
    input_types = {var.name: var.__class__.__name__ for var in model.input_variables}
    output_types = {var.name: var.__class__.__name__ for var in model.output_variables}

    return VariableTypesResponse(
        input_types=input_types,
        output_types=output_types
    )

@app.get("/outputs", response_model=ModelOutputsResponse)
async def get_model_outputs(model: TorchModel = Depends(get_model)):
Expand Down
24 changes: 24 additions & 0 deletions ingress-test/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
## Usage

Single client (1 request/second):
```bash
python stress_test.py
```

Faster rate (10 requests/second):
```bash
python stress_test.py --rate 10
```

Run for 60 seconds:
```bash
python stress_test.py --duration 60
```

Run exactly 100 requests:
```bash
python stress_test.py --num-requests 100
```

To run multiple clients at the same time:
```bash
chmod +x run_multiple.sh

# 4 clients, 1 Hz each, 60 seconds
./run_multiple.sh 4 1.0 60

# View logs
tail -f client_*.log
```
25 changes: 25 additions & 0 deletions ingress-test/ingress-cu-inj.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
  name: inference-service-cu-inj
  annotations:
    # Restrict access to specific hosts plus SLAC (134.79/16) and RFC1918 ranges.
    nginx.ingress.kubernetes.io/whitelist-source-range: "71.198.255.25,134.79.0.0/16,172.16.0.0/12,208.45.173.162,216.46.165.69"
    # Allow large request bodies (batch inference payloads).
    nginx.ingress.kubernetes.io/proxy-body-size: 2g
    # NOTE(review): 'client-max-body-size' does not look like a standard
    # ingress-nginx annotation (proxy-body-size above covers it) — confirm.
    nginx.ingress.kubernetes.io/client-max-body-size: 2g
    # Long-running inference: allow up to 30 minutes per request each way.
    nginx.ingress.kubernetes.io/proxy-read-timeout: "1800"
    nginx.ingress.kubernetes.io/proxy-send-timeout: "1800"
    # Stream request bodies to the backend instead of buffering on the proxy.
    nginx.ingress.kubernetes.io/proxy-request-buffering: "off"
    # Strip the '/cuinj' prefix; the backend sees only the second capture ($2).
    nginx.ingress.kubernetes.io/rewrite-target: /$2
spec:
  rules:
  - host: "ard-modeling-service.slac.stanford.edu"
    http:
      paths:
      - pathType: ImplementationSpecific # Change this
        path: "/cuinj(/|$)(.*)" # Change this
        backend:
          service:
            name: inference-service-cu-inj
            port:
              number: 8000


24 changes: 24 additions & 0 deletions ingress-test/ingress-fel.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
  name: inference-service-fel
  annotations:
    # Restrict access to specific hosts plus SLAC (134.79/16) and RFC1918 ranges.
    nginx.ingress.kubernetes.io/whitelist-source-range: "71.198.255.25,134.79.0.0/16,172.16.0.0/12,208.45.173.162,216.46.165.69"
    # Allow large request bodies (batch inference payloads).
    nginx.ingress.kubernetes.io/proxy-body-size: 2g
    # NOTE(review): 'client-max-body-size' does not look like a standard
    # ingress-nginx annotation (proxy-body-size above covers it) — confirm.
    nginx.ingress.kubernetes.io/client-max-body-size: 2g
    # Long-running inference: allow up to 30 minutes per request each way.
    nginx.ingress.kubernetes.io/proxy-read-timeout: "1800"
    nginx.ingress.kubernetes.io/proxy-send-timeout: "1800"
    # Stream request bodies to the backend instead of buffering on the proxy.
    nginx.ingress.kubernetes.io/proxy-request-buffering: "off"
    # Strip the '/fel' prefix; the backend sees only the second capture ($2).
    nginx.ingress.kubernetes.io/rewrite-target: /$2
spec:
  rules:
  - host: "ard-modeling-service.slac.stanford.edu"
    http:
      paths:
      - pathType: ImplementationSpecific # Change this
        path: "/fel(/|$)(.*)" # Change this
        backend:
          service:
            name: inference-service-fel
            port:
              number: 8000
25 changes: 25 additions & 0 deletions ingress-test/run_multiple.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/bin/bash
# Launch several stress-test clients in parallel and wait for all of them.
#
# Usage: ./run_multiple.sh [num_clients] [rate_hz] [duration_s]

NUM_CLIENTS=${1:-4}
RATE=${2:-1.0}
DURATION=${3:-60}

echo "Starting $NUM_CLIENTS clients at $RATE Hz for $DURATION seconds..."

for i in $(seq 1 "$NUM_CLIENTS"); do
    python stress_test_fel.py \
        --client-id "$i" \
        --rate "$RATE" \
        --duration "$DURATION" \
        > "client_${i}.log" 2>&1 &
    echo "Started client $i (PID: $!)"
    sleep 0.2  # Stagger client startups slightly
done

echo "All clients started. Logs: client_*.log"
# Fix: the launched script is stress_test_fel.py; the old hint
# 'pkill -f stress_test.py' would not match the running processes.
echo "To stop all: pkill -f stress_test_fel.py"

# Wait for all background jobs
wait

echo "All clients finished"
Loading