-
Notifications
You must be signed in to change notification settings - Fork 126
Add remote OpenAI-compatible embedding API support #776
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
bbdd7fe
fd280cd
59fec97
8db336c
9f991e7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,5 +1,9 @@ | ||||||||||||
| """Functions to compute vector embeddings.""" | ||||||||||||
|
|
||||||||||||
| from typing import Callable, List, Optional | ||||||||||||
|
|
||||||||||||
| import requests | ||||||||||||
|
|
||||||||||||
| from ..util import get_logger | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
|
|
@@ -16,3 +20,26 @@ def load_model(model_name: str): | |||||||||||
| model = SentenceTransformer(model_name) | ||||||||||||
| logger.debug("Done initializing embedding model.") | ||||||||||||
| return model | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def create_remote_embedding_function( | ||||||||||||
| base_url: str, model_name: str, api_key: Optional[str] = None | ||||||||||||
| ) -> Callable[[List[str]], List[List[float]]]: | ||||||||||||
| """Create an embedding function that calls a remote OpenAI-compatible API. | ||||||||||||
|
|
||||||||||||
| Returns a callable with signature (texts: list[str]) -> list[list[float]]. | ||||||||||||
| """ | ||||||||||||
| url = f"{base_url.rstrip('/')}/v1/embeddings" | ||||||||||||
|
||||||||||||
| url = f"{base_url.rstrip('/')}/v1/embeddings" | |
| stripped = base_url.rstrip("/") | |
| if stripped.endswith("/v1"): | |
| stripped = stripped[:-3] | |
| url = f"{stripped}/v1/embeddings" |
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -79,7 +79,9 @@ class DefaultConfig(object): | |||||||||||||
| LLM_MODEL = "" | ||||||||||||||
| LLM_MAX_CONTEXT_LENGTH = 50000 | ||||||||||||||
| LLM_SYSTEM_PROMPT = None | ||||||||||||||
| VECTOR_EMBEDDING_MODEL = "" | ||||||||||||||
| VECTOR_EMBEDDING_MODEL = "" # Model name for semantic search embeddings | ||||||||||||||
| EMBEDDING_BASE_URL = None # If set, use remote OpenAI-compatible API instead of local model | ||||||||||||||
| EMBEDDING_API_KEY = None # Optional API key for authenticated embedding providers | ||||||||||||||
|
||||||||||||||
| EMBEDDING_BASE_URL = None # If set, use remote OpenAI-compatible API instead of local model | |
| EMBEDDING_API_KEY = None # Optional API key for authenticated embedding providers | |
| # If set, use remote OpenAI-compatible API instead of local model | |
| EMBEDDING_BASE_URL = None | |
| # Optional API key for authenticated embedding providers | |
| EMBEDDING_API_KEY = None |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,130 @@ | ||||||||||||||||||||||||||||||||||
| """Tests for remote embedding function.""" | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| from unittest.mock import patch | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| import pytest | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| from gramps_webapi.api.search.embeddings import create_remote_embedding_function | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| @pytest.fixture | ||||||||||||||||||||||||||||||||||
| def mock_response_data(): | ||||||||||||||||||||||||||||||||||
| """Sample embedding API response with out-of-order indices.""" | ||||||||||||||||||||||||||||||||||
| return { | ||||||||||||||||||||||||||||||||||
| "object": "list", | ||||||||||||||||||||||||||||||||||
| "data": [ | ||||||||||||||||||||||||||||||||||
| {"object": "embedding", "index": 1, "embedding": [0.4, 0.5, 0.6]}, | ||||||||||||||||||||||||||||||||||
| {"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]}, | ||||||||||||||||||||||||||||||||||
| {"object": "embedding", "index": 2, "embedding": [0.7, 0.8, 0.9]}, | ||||||||||||||||||||||||||||||||||
| ], | ||||||||||||||||||||||||||||||||||
| "model": "test-model", | ||||||||||||||||||||||||||||||||||
| "usage": {"prompt_tokens": 10, "total_tokens": 10}, | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| class TestCreateRemoteEmbeddingFunction: | ||||||||||||||||||||||||||||||||||
| """Tests for create_remote_embedding_function.""" | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| @patch("gramps_webapi.api.search.embeddings.requests.post") | ||||||||||||||||||||||||||||||||||
| def test_returns_embeddings_in_order(self, mock_post, mock_response_data): | ||||||||||||||||||||||||||||||||||
| mock_post.return_value.json.return_value = mock_response_data | ||||||||||||||||||||||||||||||||||
| mock_post.return_value.raise_for_status.return_value = None | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| embed = create_remote_embedding_function( | ||||||||||||||||||||||||||||||||||
| base_url="http://localhost:11434", | ||||||||||||||||||||||||||||||||||
| model_name="test-model", | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| result = embed(["hello", "world", "foo"]) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| assert result == [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]] | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| @patch("gramps_webapi.api.search.embeddings.requests.post") | ||||||||||||||||||||||||||||||||||
| def test_posts_to_correct_url(self, mock_post, mock_response_data): | ||||||||||||||||||||||||||||||||||
| mock_post.return_value.json.return_value = mock_response_data | ||||||||||||||||||||||||||||||||||
| mock_post.return_value.raise_for_status.return_value = None | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| embed = create_remote_embedding_function( | ||||||||||||||||||||||||||||||||||
| base_url="http://localhost:11434", | ||||||||||||||||||||||||||||||||||
| model_name="test-model", | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| embed(["hello"]) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| mock_post.assert_called_once() | ||||||||||||||||||||||||||||||||||
| call_args = mock_post.call_args | ||||||||||||||||||||||||||||||||||
| assert call_args[0][0] == "http://localhost:11434/v1/embeddings" | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| @patch("gramps_webapi.api.search.embeddings.requests.post") | ||||||||||||||||||||||||||||||||||
| def test_strips_trailing_slash_from_base_url(self, mock_post, mock_response_data): | ||||||||||||||||||||||||||||||||||
| mock_post.return_value.json.return_value = mock_response_data | ||||||||||||||||||||||||||||||||||
| mock_post.return_value.raise_for_status.return_value = None | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| embed = create_remote_embedding_function( | ||||||||||||||||||||||||||||||||||
| base_url="http://localhost:11434/", | ||||||||||||||||||||||||||||||||||
| model_name="test-model", | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| embed(["hello"]) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| call_args = mock_post.call_args | ||||||||||||||||||||||||||||||||||
| assert call_args[0][0] == "http://localhost:11434/v1/embeddings" | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| @patch("gramps_webapi.api.search.embeddings.requests.post") | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
| @patch("gramps_webapi.api.search.embeddings.requests.post") | |
| @patch("gramps_webapi.api.search.embeddings.requests.post") | |
| def test_base_url_with_v1_segment(self, mock_post, mock_response_data): | |
| mock_post.return_value.json.return_value = mock_response_data | |
| mock_post.return_value.raise_for_status.return_value = None | |
| embed = create_remote_embedding_function( | |
| base_url="http://localhost:11434/v1", | |
| model_name="test-model", | |
| ) | |
| embed(["hello"]) | |
| call_args = mock_post.call_args | |
| assert call_args[0][0] == "http://localhost:11434/v1/embeddings" | |
| @patch("gramps_webapi.api.search.embeddings.requests.post") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The OpenAI example sets
GRAMPSWEB_EMBEDDING_BASE_URL=https://api.openai.com/v1, butcreate_remote_embedding_function()appends/v1/embeddingsto the base URL. With this example config, the effective URL becomeshttps://api.openai.com/v1/v1/embeddingsand will fail. Update the example (e.g., base URL without/v1) or adjust the code to accept both formats.