Commit cbca6e8

Add network embedding batching (#276)
When using the Mistral embedding API (through [LiteLLM](https://www.litellm.ai/) for OpenAI API compatibility) with Context Chat, I got many `Too many tokens overall, split into more batches` errors. This is due to Mistral's lower token limit per API request, [~16000](langchain-ai/langchain#20523), compared to OpenAI's [300000](https://github.com/langchain-ai/langchain/blob/18230f625f79aba25cbf9fb5500ab504cbb8f0bc/libs/partners/openai/langchain_openai/embeddings/base.py#L22).

The idea for the fix is to implement the same pattern as the [LangChain OpenAI integration](https://github.com/langchain-ai/langchain/blob/18230f625f79aba25cbf9fb5500ab504cbb8f0bc/libs/partners/openai/langchain_openai/embeddings/base.py#L598): batching API requests. A better solution would be to allow using LangChain's built-in provider class, but that refactor is too big for my first PR x)

Signed-off-by: Florian Charlaix <fcharlaix@open-dsi.fr>
1 parent 257c5ce commit cbca6e8
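
As a rough, standalone illustration of the batching pattern the message describes (not the project's actual code), the sketch below splits a list of texts into fixed-size slices and sends each slice as its own request. The names `embed_in_batches` and `fake_embed` are hypothetical; `fake_embed` only stands in for a real embedding API call and records how many texts each request carried.

```python
from collections.abc import Callable, Sequence


def embed_in_batches(
    texts: Sequence[str],
    embed_fn: Callable[[list[str]], list[list[float]]],
    batch_size: int = 100,
) -> list[list[float]]:
    """Embed texts with at most batch_size texts per call to embed_fn."""
    items = list(texts)
    # 0 (or a negative value) disables batching; short inputs need one request
    if batch_size <= 0 or len(items) <= batch_size:
        return embed_fn(items)

    results: list[list[float]] = []
    for start in range(0, len(items), batch_size):
        # Each slice becomes one request; extend() keeps the input order
        results.extend(embed_fn(items[start:start + batch_size]))
    return results


# Hypothetical stand-in for the real API call: records each request size
request_sizes: list[int] = []


def fake_embed(batch: list[str]) -> list[list[float]]:
    request_sizes.append(len(batch))
    return [[float(len(text))] for text in batch]


vectors = embed_in_batches([f"text {i}" for i in range(250)], fake_embed, batch_size=100)
print(len(vectors), request_sizes)  # 250 [100, 100, 50]
```

Note that this caps the number of texts per request, not the number of tokens, so the batch size still has to be chosen small enough that a full batch of typical chunks stays under the provider's token limit.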

5 files changed

Lines changed: 13 additions & 1 deletion


config.cpu.yaml

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@ embedding:
   base_url: http://localhost:5000/v1
   workers: 1
   request_timeout: 1800 # in seconds
+  # batch_size: 100 # max texts per embedding API request, 0 = no batching
   # only for external embedding service
   # remote_service: true
   # model_name: text-embedding-3-small

config.gpu.yaml

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@ embedding:
   base_url: http://localhost:5000/v1
   workers: 1
   request_timeout: 1800 # in seconds
+  # batch_size: 100 # max texts per embedding API request, 0 = no batching
   # only for external embedding service
   # remote_service: true
   # model_name: text-embedding-3-small

context_chat_backend/config_parser.py

Lines changed: 1 addition & 0 deletions
@@ -78,6 +78,7 @@ def get_config(file_path: str) -> TConfig:
             remote_service=True,
             workers=0,
             request_timeout=embedding.get('request_timeout', 1800) if embedding else 1800,
+            batch_size=int(os.getenv('CC_EM_BATCH_SIZE', 100)),
         )
     except Exception as e:
         raise AssertionError(

context_chat_backend/network_em.py

Lines changed: 9 additions & 1 deletion
@@ -117,7 +117,15 @@ def _get_embedding(self, input_: str | list[str], try_: int = 3) -> list[float]
         return [d['embedding'] for d in resp['data']] # pyright: ignore[reportReturnType]

     def embed_documents(self, texts: list[str]) -> list[list[float]]:
-        return self._get_embedding(texts) # pyright: ignore[reportReturnType]
+        batch_size = self.app_config.embedding.batch_size
+        if batch_size <= 0 or len(texts) <= batch_size:
+            return self._get_embedding(texts) # pyright: ignore[reportReturnType]
+
+        results: list[list[float]] = []
+        for i in range(0, len(texts), batch_size):
+            batch_embeddings = self._get_embedding(texts[i:i + batch_size])
+            results.extend(batch_embeddings) # pyright: ignore[reportArgumentType]
+        return results

     def embed_query(self, text: str) -> list[float]:
         return self._get_embedding(text) # pyright: ignore[reportReturnType]

context_chat_backend/types.py

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ class TEmbeddingConfig(BaseModel):
     model_name: str | None = DEFAULT_EM_MODEL_ALIAS
     auth: TEmbeddingAuthApiKey | TEmbeddingAuthBasic | None = None
     remote_service: bool = False
+    batch_size: int = 100 # max texts per embedding API request, 0 = no batching
     llama: dict = dict() # noqa: C408