5 changes: 5 additions & 0 deletions README.md
@@ -170,6 +170,11 @@ Fill in your API endpoint in `config/api_config.yaml`. We support OpenAI compati

You may use an inference engine such as [vLLM](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html) or [SGLang](https://github.com/sgl-project/sglang?tab=readme-ov-file#using-local-models) to host your model with an OpenAI-compatible API server.

#### Tips for vLLM / OpenAI-compatible local servers

- **Timeouts**: add `timeout: <seconds>` under the endpoint entry in `config/api_config.yaml` for local servers; this helps with long generations.
- **Parallelism**: set `parallel` to a value your server can actually sustain; setting it too high increases queueing and can make progress look “stuck”. A minimal entry showing both settings is sketched below.
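
The sketch below mirrors the existing entries in `config/api_config.yaml`; the model name, port, and numeric values are placeholders, not recommendations.

```yaml
# Hypothetical entry; adjust the name, port, and limits for your deployment.
my-local-model:
  endpoints:
    - api_base: http://0.0.0.0:8000/v1
      api_key: '-'
      timeout: 1800      # optional, seconds; helps with long generations
  api_type: openai
  parallel: 32           # keep within what the server can sustain
  max_tokens: 4096
```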

We also include support for fast built-in inference with SGLang; see examples in `config/api_config.yaml` and the implementation in `utils/completion.py`. See `misc/sglang_setup.bash` for environment setup.

### Step 2. Generate Model Answers
2 changes: 2 additions & 0 deletions config/api_config.yaml
@@ -82,6 +82,8 @@ gemma-3-27b-it:
endpoints:
- api_base: http://0.0.0.0:<port_number>/v1
api_key: '-'
# Optional: timeout in seconds for OpenAI-compatible local servers (e.g., vLLM)
# timeout: 1800
api_type: openai
parallel: 128
max_tokens: 8196
50 changes: 50 additions & 0 deletions tests/test_openai_client_cache.py
@@ -0,0 +1,50 @@
"""Unit tests for OpenAI client caching functionality."""

import unittest
from unittest import mock


class TestOpenAIClientCache(unittest.TestCase):
"""Test suite for the OpenAI client caching mechanism."""

def setUp(self):
"""Reset thread-local cache before each test to ensure isolation."""
from utils import openai_client
# Ensure no cached client from earlier tests in this thread.
openai_client._thread_local.__dict__.pop("openai_clients", None)

def test_openai_client_is_cached_per_thread_and_endpoint(self):
"""Verify that OpenAI clients are cached per thread and endpoint configuration.

Tests that:
- Repeated calls with identical config return the same client instance
- Different configurations (e.g., timeout) create separate cached clients
"""
# Avoid importing the full completion stack (pulls optional deps).
# Instead, mock an `openai` module and test the small helper directly.
from utils import openai_client

fake_openai = mock.Mock()
fake_openai.OpenAI.side_effect = [object(), object(), object()]

with mock.patch.dict("sys.modules", {"openai": fake_openai}):
c1 = openai_client.get_openai_client(
{"api_base": "http://localhost:8000/v1", "api_key": "k", "timeout": 123}
)
c2 = openai_client.get_openai_client(
{"api_base": "http://localhost:8000/v1", "api_key": "k", "timeout": 123}
)
self.assertIs(c1, c2)
self.assertEqual(fake_openai.OpenAI.call_count, 1)

# Different timeout should create a new client
c3 = openai_client.get_openai_client(
{"api_base": "http://localhost:8000/v1", "api_key": "k", "timeout": 456}
)
self.assertIsNot(c1, c3)
self.assertEqual(fake_openai.OpenAI.call_count, 2)


if __name__ == "__main__":
unittest.main()

115 changes: 89 additions & 26 deletions utils/completion.py
Expand Up @@ -14,6 +14,7 @@
from tqdm import tqdm

from utils.bedrock_utils import create_llama3_body, create_nova_messages, extract_answer
from utils.openai_client import get_openai_client

# API setting constants
API_MAX_RETRY = 3
@@ -24,7 +25,31 @@
registered_engine_completion = {}


def _tqdm_write(msg: str) -> None:
"""Write a message to stdout without corrupting tqdm progress bars.

In multi-threaded environments, direct print() calls can interleave with
tqdm's progress bar output, causing garbled console output. This helper
uses tqdm.write() which properly coordinates with active progress bars.

Args:
msg: The message string to write to stdout.
"""
try:
tqdm.write(msg)
except Exception:
print(msg)


def register_api(api_type):
"""Decorator to register a function as an API completion handler.

Args:
api_type: String identifier for the API type (e.g., 'openai', 'anthropic').

Returns:
Decorator function that registers the wrapped function.
"""
def decorator(func):
registered_api_completion[api_type] = func
return func
@@ -33,6 +58,14 @@ def decorator(func):


def register_engine(engine_type):
"""Decorator to register a function as an engine completion handler.

Args:
engine_type: String identifier for the engine type (e.g., 'sglang').

Returns:
Decorator function that registers the wrapped function.
"""
def decorator(func):
registered_engine_completion[engine_type] = func
return func
@@ -119,15 +152,27 @@ def make_config(config_file: str) -> dict:

@register_api("openai")
def chat_completion_openai(model, messages, temperature, max_tokens, api_dict=None, **kwargs):
"""Send a chat completion request to an OpenAI or OpenAI-compatible server.

Uses a thread-local cached client to avoid per-request connection overhead,
which is especially important for local inference servers like vLLM.

Args:
model: The model identifier to use for completion.
messages: List of message dicts with 'role' and 'content' keys.
temperature: Sampling temperature for generation.
max_tokens: Maximum number of tokens to generate.
api_dict: Optional dict with 'api_base', 'api_key', 'timeout', and
'model_name' (to override the model parameter).
**kwargs: Additional arguments (unused).

Returns:
Dict with 'answer' key containing the model response, or API_ERROR_OUTPUT on failure.
"""
import openai
if api_dict:
client = openai.OpenAI(
base_url=api_dict["api_base"],
api_key=api_dict["api_key"],
)
else:
client = openai.OpenAI()


client = get_openai_client(api_dict)

if api_dict and "model_name" in api_dict:
model = api_dict["model_name"]

@@ -145,31 +190,49 @@ def chat_completion_openai(model, messages, temperature, max_tokens, api_dict=No
}
break
except openai.RateLimitError as e:
print(type(e), e)
_tqdm_write(f"{type(e).__name__}: {e}")
time.sleep(API_RETRY_SLEEP)
except openai.BadRequestError as e:
print(messages)
print(type(e), e)
except KeyError:
print(type(e), e)
# Usually deterministic; don't spam full messages in multi-threaded runs.
_tqdm_write(f"{type(e).__name__}: {e}")
break
except (openai.APITimeoutError, openai.APIConnectionError, openai.InternalServerError) as e:
# Common transient errors with local OpenAI-compatible servers (e.g., vLLM)
_tqdm_write(f"{type(e).__name__}: {e}")
time.sleep(API_RETRY_SLEEP)
except KeyError as e:
_tqdm_write(f"{type(e).__name__}: {e}")
break
except Exception as e:
# Keep the worker alive and allow the main progress loop to continue.
_tqdm_write(f"{type(e).__name__}: {e}")
time.sleep(API_RETRY_SLEEP)

return output
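
As a rough usage sketch (the endpoint URL, key, and model name below are placeholders, not values from this repo):

```python
# Hypothetical call against a local OpenAI-compatible server (e.g., vLLM).
# `timeout` is forwarded to the thread-local client built by get_openai_client().
api_dict = {
    "api_base": "http://0.0.0.0:8000/v1",
    "api_key": "-",
    "timeout": 1800,                 # seconds
    "model_name": "my-local-model",  # placeholder; overrides the `model` argument
}

result = chat_completion_openai(
    model="my-local-model",
    messages=[{"role": "user", "content": "Say hello."}],
    temperature=0.7,
    max_tokens=256,
    api_dict=api_dict,
)
print(result)  # {"answer": "..."} on success, API_ERROR_OUTPUT after exhausted retries
```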


@register_api("openai_thinking")
def chat_completion_openai_thinking(model, messages, api_dict=None, **kwargs):
"""Send a chat completion request to OpenAI models with reasoning/thinking support.

Uses the cached OpenAI client and supports models with extended reasoning
capabilities (e.g., o1, o3). Handles rate limits with fixed delay (API_RETRY_SLEEP).

Args:
model: The model identifier to use for completion.
messages: List of message dicts with 'role' and 'content' keys.
api_dict: Optional dict with 'api_base', 'api_key', and 'timeout' settings.
**kwargs: Additional arguments, notably 'reasoning_effort' (default: 'medium').

Returns:
Dict with 'answer' key containing the model response, or API_ERROR_OUTPUT on failure.
"""
import openai

if api_dict:
client = openai.OpenAI(
api_key=api_dict["api_key"],
)
else:
client = openai.OpenAI()

client = get_openai_client(api_dict)

output = API_ERROR_OUTPUT
for i in range(API_MAX_RETRY):
for _ in range(API_MAX_RETRY):
try:
completion = client.chat.completions.create(
model=model,
@@ -181,13 +244,13 @@ def chat_completion_openai_thinking(model, messages, api_dict=None, **kwargs):
}
break
except openai.RateLimitError as e:
print(type(e), e)
_tqdm_write(f"{type(e).__name__}: {e}")
time.sleep(API_RETRY_SLEEP)
except openai.BadRequestError as e:
print(messages)
print(type(e), e)
except KeyError:
print(type(e), e)
_tqdm_write(f"{type(e).__name__}: {e}")
break
except KeyError as e:
_tqdm_write(f"{type(e).__name__}: {e}")
break

return output
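
A similar hedged sketch for the reasoning variant (model name and key are placeholders; per the docstring, `reasoning_effort` defaults to "medium"):

```python
# Hypothetical call to a reasoning-capable model.
result = chat_completion_openai_thinking(
    model="o3-mini",  # placeholder reasoning model
    messages=[{"role": "user", "content": "Summarize the benchmark setup."}],
    api_dict={"api_key": "sk-...", "timeout": 600},  # placeholder credentials
    reasoning_effort="high",
)
print(result)  # {"answer": "..."} or API_ERROR_OUTPUT
```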
55 changes: 55 additions & 0 deletions utils/openai_client.py
@@ -0,0 +1,55 @@
"""Thread-local OpenAI client caching for improved performance.

This module provides a caching mechanism that reuses OpenAI client instances
per thread and endpoint configuration, avoiding the overhead of creating new
HTTP clients for each API request.
"""
from __future__ import annotations

import threading
from typing import Optional

_thread_local = threading.local()


def get_openai_client(api_dict: Optional[dict] = None):
"""Return a thread-local cached OpenAI client.

Creating a new OpenAI client per request is expensive (new HTTP client,
TLS/session setup, etc.) and can significantly slow down local inference
backends like vLLM that can otherwise serve quickly.

`api_dict` supports:
- api_base: str
- api_key: str
- timeout: float (seconds)
"""
import openai

api_base = None
api_key = None
timeout = None
if api_dict:
api_base = api_dict.get("api_base")
api_key = api_dict.get("api_key")
timeout = api_dict.get("timeout")

cache = getattr(_thread_local, "openai_clients", None)
if cache is None:
cache = {}
_thread_local.openai_clients = cache

key = (api_base, api_key, timeout)
client = cache.get(key)
if client is None:
kwargs = {}
if api_base:
kwargs["base_url"] = api_base
if api_key:
kwargs["api_key"] = api_key
if timeout is not None:
kwargs["timeout"] = timeout
client = openai.OpenAI(**kwargs) if kwargs else openai.OpenAI()
cache[key] = client
return client
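
A small sketch of the intended usage pattern from worker threads (endpoint values and model name are illustrative):

```python
from concurrent.futures import ThreadPoolExecutor

from utils.openai_client import get_openai_client

API_DICT = {"api_base": "http://0.0.0.0:8000/v1", "api_key": "-", "timeout": 1800}


def worker(prompt: str) -> str:
    # The first call in a thread builds the client; later calls in the same
    # thread with the same (api_base, api_key, timeout) reuse the cached instance.
    client = get_openai_client(API_DICT)
    completion = client.chat.completions.create(
        model="my-local-model",  # placeholder
        messages=[{"role": "user", "content": prompt}],
    )
    return completion.choices[0].message.content


with ThreadPoolExecutor(max_workers=8) as pool:
    print(list(pool.map(worker, ["hi", "how are you?"])))
```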