Skip to content
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,11 @@ Fill in your API endpoint in `config/api_config.yaml`. We support OpenAI compati

You may use inference engine such as [vLLM](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html) or [SGLang](https://github.com/sgl-project/sglang?tab=readme-ov-file#using-local-models) to host your model with an OpenAI compatible API server.

#### Tips for vLLM / OpenAI-compatible local servers

- **Timeouts**: you can add `timeout: <seconds>` under the endpoint entry in `config/api_config.yaml` for local servers (helps with long generations).
- **Parallelism**: set `parallel` to something your server can actually sustain (too high can increase queueing and make progress look “stuck”).

We also include support for fast built-in inference with SGLang, see examples in `config/api_config.yaml` and implementation in `utils/completion.py`. See `misc/sglang_setup.bash` for environment setup.

### Step 2. Generate Model Answers
Expand Down
2 changes: 2 additions & 0 deletions config/api_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ gemma-3-27b-it:
endpoints:
- api_base: http://0.0.0.0:<port_number>/v1
api_key: '-'
# Optional: timeout in seconds for OpenAI-compatible local servers (e.g., vLLM)
# timeout: 1800
api_type: openai
parallel: 128
max_tokens: 8196
Expand Down
34 changes: 34 additions & 0 deletions tests/test_openai_client_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import unittest
from unittest import mock


class TestOpenAIClientCache(unittest.TestCase):
    """Unit tests for the thread-local client cache in ``utils.openai_client``."""

    def test_openai_client_is_cached_per_thread_and_endpoint(self):
        """Identical endpoint config reuses one client; a changed config builds a new one."""
        # Avoid importing the full completion stack (pulls optional deps).
        # Instead, mock an `openai` module and test the small helper directly.
        from utils import openai_client

        fake_openai = mock.Mock()
        fake_openai.OpenAI.side_effect = [object(), object(), object()]

        with mock.patch.dict("sys.modules", {"openai": fake_openai}):
            c1 = openai_client.get_openai_client(
                {"api_base": "http://localhost:8000/v1", "api_key": "k", "timeout": 123}
            )
            c2 = openai_client.get_openai_client(
                {"api_base": "http://localhost:8000/v1", "api_key": "k", "timeout": 123}
            )
            self.assertIs(c1, c2)
            self.assertEqual(fake_openai.OpenAI.call_count, 1)

            # A different timeout is a different cache key, so a new client
            # must be constructed. This call must stay INSIDE the patch
            # context; otherwise the real `openai` module would be imported.
            c3 = openai_client.get_openai_client(
                {"api_base": "http://localhost:8000/v1", "api_key": "k", "timeout": 456}
            )
            self.assertIsNot(c1, c3)
            self.assertEqual(fake_openai.OpenAI.call_count, 2)


if __name__ == "__main__":
    # Allow running this test file directly, outside a pytest/unittest runner.
    unittest.main()

37 changes: 25 additions & 12 deletions utils/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from tqdm import tqdm

from utils.bedrock_utils import create_llama3_body, create_nova_messages, extract_answer
from utils.openai_client import get_openai_client

# API setting constants
API_MAX_RETRY = 3
Expand All @@ -24,6 +25,14 @@
registered_engine_completion = {}


def _tqdm_write(msg: str) -> None:
    """Print *msg* without garbling any active tqdm progress bars.

    ``tqdm.write`` repositions live bars around the message so concurrent
    worker threads can log safely; if it fails for any reason, fall back
    to a plain ``print`` so the message is never silently dropped.
    """
    try:
        tqdm.write(msg)
    except Exception:
        # Best-effort fallback: logging must never take down a worker.
        print(msg)


def register_api(api_type):
def decorator(func):
registered_api_completion[api_type] = func
Expand Down Expand Up @@ -120,13 +129,8 @@ def make_config(config_file: str) -> dict:
@register_api("openai")
def chat_completion_openai(model, messages, temperature, max_tokens, api_dict=None, **kwargs):
import openai
if api_dict:
client = openai.OpenAI(
base_url=api_dict["api_base"],
api_key=api_dict["api_key"],
)
else:
client = openai.OpenAI()

client = get_openai_client(api_dict)

if api_dict and "model_name" in api_dict:
model = api_dict["model_name"]
Expand All @@ -145,14 +149,23 @@ def chat_completion_openai(model, messages, temperature, max_tokens, api_dict=No
}
break
except openai.RateLimitError as e:
print(type(e), e)
_tqdm_write(f"{type(e).__name__}: {e}")
time.sleep(API_RETRY_SLEEP)
except openai.BadRequestError as e:
print(messages)
print(type(e), e)
except KeyError:
print(type(e), e)
# Usually deterministic; don't spam full messages in multi-threaded runs.
_tqdm_write(f"{type(e).__name__}: {e}")
break
except (openai.APITimeoutError, openai.APIConnectionError, openai.InternalServerError) as e:
# Common transient errors with local OpenAI-compatible servers (e.g., vLLM)
_tqdm_write(f"{type(e).__name__}: {e}")
time.sleep(API_RETRY_SLEEP)
except KeyError as e:
_tqdm_write(f"{type(e).__name__}: {e}")
break
except Exception as e:
# Keep the worker alive and allow the main progress loop to continue.
_tqdm_write(f"{type(e).__name__}: {e}")
time.sleep(API_RETRY_SLEEP)

return output

Expand Down
49 changes: 49 additions & 0 deletions utils/openai_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from __future__ import annotations

import threading
from typing import Optional

# Per-thread storage for cached OpenAI clients.
_thread_local = threading.local()


def get_openai_client(api_dict: Optional[dict] = None):
    """Return a thread-local cached OpenAI client.

    Building a fresh ``openai.OpenAI`` per request is expensive (new HTTP
    client, TLS/session setup, etc.) and can noticeably slow down local
    inference backends like vLLM that can otherwise serve quickly. Clients
    are therefore cached per thread, keyed by ``(api_base, api_key, timeout)``.

    Recognized ``api_dict`` keys:
    - ``api_base``: str — base URL of the endpoint
    - ``api_key``: str — API key for the endpoint
    - ``timeout``: float — request timeout in seconds
    """
    import openai

    cfg = api_dict or {}
    api_base = cfg.get("api_base")
    api_key = cfg.get("api_key")
    timeout = cfg.get("timeout")

    # Lazily create this thread's cache dict on first use.
    try:
        cache = _thread_local.openai_clients
    except AttributeError:
        cache = _thread_local.openai_clients = {}

    key = (api_base, api_key, timeout)
    if key not in cache:
        kwargs = {}
        if api_base:
            kwargs["base_url"] = api_base
        if api_key:
            kwargs["api_key"] = api_key
        if timeout is not None:
            kwargs["timeout"] = timeout
        cache[key] = openai.OpenAI(**kwargs) if kwargs else openai.OpenAI()
    return cache[key]