Skip to content

Commit 85e6ebd

Browse files
committed
Add streaming, lower temperature
1 parent 32b114f commit 85e6ebd

8 files changed

Lines changed: 401 additions & 106 deletions

File tree

core/api.py

Lines changed: 85 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import tomli
1111
from fastapi import Depends, FastAPI, Form, Header, HTTPException, Query, UploadFile
1212
from fastapi.middleware.cors import CORSMiddleware # Import CORSMiddleware
13+
from fastapi.responses import StreamingResponse
1314
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
1415
from starlette.middleware.sessions import SessionMiddleware
1516

@@ -480,7 +481,7 @@ async def query_completion(
480481

481482
# Main query processing
482483
perf.start_phase("document_service_query")
483-
response = await document_service.query(
484+
result = await document_service.query(
484485
request.query,
485486
auth,
486487
request.filters,
@@ -499,30 +500,93 @@ async def query_completion(
499500
request.schema,
500501
history,
501502
perf, # Pass performance tracker
503+
request.stream_response,
502504
)
503505

504-
# Chat history storage
505-
perf.start_phase("chat_history_storage")
506-
if history_key:
507-
history.append(
508-
{
509-
"role": "assistant",
510-
"content": response.completion,
511-
"timestamp": datetime.now(UTC).isoformat(),
512-
}
513-
)
514-
await redis.set(history_key, json.dumps(history))
515-
await document_service.db.upsert_chat_history(
516-
request.chat_id,
517-
auth.user_id,
518-
auth.app_id,
519-
history,
520-
)
506+
# Handle streaming vs non-streaming responses
507+
if request.stream_response:
508+
# For streaming responses, unpack the tuple
509+
response_stream, sources = result
510+
511+
async def generate_stream():
512+
full_content = ""
513+
first_token_time = None
514+
515+
async for chunk in response_stream:
516+
# Track time to first token
517+
if first_token_time is None:
518+
first_token_time = time.time()
519+
completion_start_to_first_token = first_token_time - perf.start_time
520+
perf.add_suboperation("completion_start_to_first_token", completion_start_to_first_token)
521+
logger.info(f"Completion start to first token: {completion_start_to_first_token:.2f}s")
522+
523+
full_content += chunk
524+
yield f"data: {json.dumps({'content': chunk})}\n\n"
525+
526+
# Convert sources to the format expected by frontend
527+
sources_info = [
528+
{"document_id": source.document_id, "chunk_number": source.chunk_number, "score": source.score}
529+
for source in sources
530+
]
531+
532+
# Send completion signal with sources
533+
yield f"data: {json.dumps({'done': True, 'sources': sources_info})}\n\n"
534+
535+
# Handle chat history after streaming is complete
536+
if history_key:
537+
history.append(
538+
{
539+
"role": "assistant",
540+
"content": full_content,
541+
"timestamp": datetime.now(UTC).isoformat(),
542+
}
543+
)
544+
await redis.set(history_key, json.dumps(history))
545+
await document_service.db.upsert_chat_history(
546+
request.chat_id,
547+
auth.user_id,
548+
auth.app_id,
549+
history,
550+
)
551+
552+
# Log consolidated performance summary for streaming
553+
streaming_time = time.time() - first_token_time if first_token_time else 0
554+
perf.add_suboperation("streaming_duration", streaming_time)
555+
perf.log_summary(f"Generated streaming completion with {len(sources)} sources")
556+
557+
headers = {
558+
"Cache-Control": "no-cache",
559+
"Connection": "keep-alive",
560+
"Access-Control-Allow-Origin": "*",
561+
"Access-Control-Allow-Headers": "*",
562+
}
563+
return StreamingResponse(generate_stream(), media_type="text/event-stream", headers=headers)
564+
else:
565+
# For non-streaming responses, result is just the CompletionResponse
566+
response = result
567+
568+
# Chat history storage for non-streaming responses
569+
perf.start_phase("chat_history_storage")
570+
if history_key:
571+
history.append(
572+
{
573+
"role": "assistant",
574+
"content": response.completion,
575+
"timestamp": datetime.now(UTC).isoformat(),
576+
}
577+
)
578+
await redis.set(history_key, json.dumps(history))
579+
await document_service.db.upsert_chat_history(
580+
request.chat_id,
581+
auth.user_id,
582+
auth.app_id,
583+
history,
584+
)
521585

522-
# Log consolidated performance summary
523-
perf.log_summary(f"Generated completion with {len(response.sources) if response.sources else 0} sources")
586+
# Log consolidated performance summary
587+
perf.log_summary(f"Generated completion with {len(response.sources) if response.sources else 0} sources")
524588

525-
return response
589+
return response
526590
except ValueError as e:
527591
validate_prompt_overrides_with_http_exception(operation_type="query", error=e)
528592
except PermissionError as e:

core/completion/litellm_completion.py

Lines changed: 107 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
import re # Import re for parsing model name
3-
from typing import Any, Dict, List, Optional, Tuple, Union
3+
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
44

55
import litellm
66

@@ -425,15 +425,107 @@ async def _handle_standard_litellm(
425425
finish_reason=response.choices[0].finish_reason,
426426
)
427427

428-
async def complete(self, request: CompletionRequest) -> CompletionResponse:
428+
async def _handle_streaming_litellm(
    self,
    user_content: str,
    image_urls: List[str],
    request: CompletionRequest,
    history_messages: List[Dict[str, str]],
) -> AsyncGenerator[str, None]:
    """Stream completion text from LiteLLM chunk by chunk.

    Args:
        user_content: Text of the user turn.
        image_urls: Image URLs (full data URIs) to attach to the user turn.
            Only the first five are sent.
        request: Supplies ``max_tokens`` and ``temperature`` for the call.
        history_messages: Earlier conversation messages, placed between the
            system prompt and the new user turn.

    Yields:
        Each non-empty text delta, in the order LiteLLM produces it.
    """
    logger.debug(f"Using LiteLLM streaming for model: {self.model_config['model_name']}")

    # Cap the number of attached images; slicing already copes with shorter lists.
    max_images = 5
    attached_images = image_urls[:max_images] if image_urls else []

    # LiteLLM uses list content format: one text part followed by image parts.
    content_list = [{"type": "text", "text": user_content}]
    content_list.extend({"type": "image_url", "image_url": {"url": img_url}} for img_url in attached_images)

    user_message = {"role": "user", "content": content_list}
    litellm_messages = [get_system_message()] + history_messages + [user_message]

    model_params = {
        "model": self.model_config["model_name"],
        "messages": litellm_messages,
        "max_tokens": request.max_tokens,
        "temperature": request.temperature,
        "stream": True,  # Enable streaming
        "num_retries": 3,
    }
    # Every other model_config entry is forwarded as-is and may override the defaults above.
    model_params.update({key: value for key, value in self.model_config.items() if key != "model_name"})

    # Log a summary only: the full params hold the whole prompt, image data
    # URIs and whatever model_config carries (possibly credentials).
    logger.debug(
        "Calling LiteLLM streaming: model=%s, messages=%d, images=%d, param_keys=%s",
        model_params["model"],
        len(litellm_messages),
        len(attached_images),
        sorted(model_params),
    )
    response = await litellm.acompletion(**model_params)

    # Forward only chunks that actually carry text.
    async for chunk in response:
        if not chunk.choices:
            continue
        delta = chunk.choices[0].delta
        if delta and delta.content:
            yield delta.content
472+
473+
async def _handle_streaming_ollama(
    self,
    user_content: str,
    ollama_image_data: List[str],
    request: CompletionRequest,
    history_messages: List[Dict[str, str]],
) -> AsyncGenerator[str, None]:
    """Stream completion text from a direct Ollama chat call.

    Args:
        user_content: Text of the user turn.
        ollama_image_data: Image payloads attached to the user turn under the
            ``images`` key when the list is non-empty.
        request: Supplies ``temperature`` and ``max_tokens`` for the call.
        history_messages: Earlier conversation messages, placed between the
            system prompt and the new user turn.

    Yields:
        The ``message.content`` of every streamed chunk that has content.

    Raises:
        Exception: Any error from the Ollama call or stream is logged and
            re-raised unchanged.
    """
    logger.debug(f"Using direct Ollama streaming for model: {self.ollama_base_model_name}")
    client = ollama.AsyncClient(host=self.ollama_api_base)

    # Only the text of the shared system prompt is sent to Ollama.
    system_turn = {"role": "system", "content": get_system_message()["content"]}

    user_turn = {"role": "user", "content": user_content}
    if ollama_image_data:
        # All images go on the user message itself.
        user_turn["images"] = ollama_image_data

    conversation = [system_turn] + history_messages + [user_turn]

    # -1 is sent when the request sets no token limit.
    # NOTE(review): Ollama documents num_predict=-1 as unlimited generation; confirm that is intended.
    token_limit = -1 if request.max_tokens is None else request.max_tokens
    generation_options = {
        "temperature": request.temperature,
        "num_predict": token_limit,
    }

    try:
        stream = await client.chat(
            model=self.ollama_base_model_name,
            messages=conversation,
            options=generation_options,
            stream=True,
        )

        async for part in stream:
            # Skip chunks without text (for example, a final metadata-only chunk).
            if not part.get("message", {}).get("content"):
                continue
            yield part["message"]["content"]

    except Exception as e:
        logger.error(f"Error during direct Ollama streaming call: {e}")
        raise
518+
519+
async def complete(self, request: CompletionRequest) -> Union[CompletionResponse, AsyncGenerator[str, None]]:
429520
"""
430521
Generate completion using LiteLLM or direct Ollama client if configured.
431522
432523
Args:
433524
request: CompletionRequest object containing query, context, and parameters
434525
435526
Returns:
436-
CompletionResponse object with the generated text and usage statistics
527+
CompletionResponse object with the generated text and usage statistics or
528+
AsyncGenerator for streaming responses
437529
"""
438530
# Process context chunks and handle images
439531
context_text, image_urls, ollama_image_data = process_context_chunks(request.context_chunks, self.is_ollama)
@@ -446,6 +538,18 @@ async def complete(self, request: CompletionRequest) -> CompletionResponse:
446538
# Check if structured output is requested
447539
structured_output = request.schema is not None
448540

541+
# Streaming is not supported with structured output
542+
if request.stream_response and structured_output:
543+
logger.warning("Streaming is not supported with structured output. Falling back to non-streaming.")
544+
request.stream_response = False
545+
546+
# If streaming is requested and no structured output
547+
if request.stream_response and not structured_output:
548+
if self.is_ollama:
549+
return self._handle_streaming_ollama(user_content, ollama_image_data, request, history_messages)
550+
else:
551+
return self._handle_streaming_litellm(user_content, image_urls, request, history_messages)
552+
449553
# If structured output is requested, use instructor to handle it
450554
if structured_output:
451555
# Get dynamic model from schema

core/models/completion.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,10 @@ class CompletionRequest(BaseModel):
3232
query: str
3333
context_chunks: List[str]
3434
max_tokens: Optional[int] = 1000
35-
temperature: Optional[float] = 0.7
35+
temperature: Optional[float] = 0.3
3636
prompt_template: Optional[str] = None
3737
folder_name: Optional[str] = None
3838
end_user_id: Optional[str] = None
3939
schema: Optional[Union[Type[BaseModel], Dict[str, Any]]] = None
4040
chat_history: Optional[List[ChatMessage]] = None
41+
stream_response: Optional[bool] = False

core/models/request.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ class CompletionQueryRequest(RetrieveRequest):
4444
None,
4545
description="Optional chat session ID for persisting conversation history",
4646
)
47+
stream_response: Optional[bool] = Field(
48+
False,
49+
description="Whether to stream the response back in chunks",
50+
)
4751

4852

4953
class IngestTextRequest(BaseModel):

core/services/document_service.py

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import uuid
99
from datetime import UTC, datetime
1010
from io import BytesIO
11-
from typing import Any, Dict, List, Optional, Type, Union
11+
from typing import Any, AsyncGenerator, Dict, List, Optional, Type, Union
1212

1313
import arq
1414
import filetype
@@ -588,7 +588,8 @@ async def query(
588588
schema: Optional[Union[Type[BaseModel], Dict[str, Any]]] = None,
589589
chat_history: Optional[List[ChatMessage]] = None,
590590
perf_tracker: Optional[Any] = None, # Performance tracker from API layer
591-
) -> CompletionResponse:
591+
stream_response: Optional[bool] = False,
592+
) -> Union[CompletionResponse, tuple[AsyncGenerator[str, None], List[ChunkSource]]]:
592593
"""Generate completion using relevant chunks as context.
593594
594595
When graph_name is provided, the query will leverage the knowledge graph
@@ -717,28 +718,46 @@ async def query(
717718
prompt_template=custom_prompt_template,
718719
schema=schema,
719720
chat_history=chat_history,
721+
stream_response=stream_response,
720722
)
721723

722724
response = await self.completion_model.complete(request)
723725

724726
if not perf_tracker:
725727
phase_times["completion_generation"] = time.time() - completion_start
726728

727-
# Add sources information at the document service level
728-
response.sources = sources
729-
730-
# Log performance summary only for standalone calls
731-
if local_perf:
732-
total_time = time.time() - query_start_time
733-
logger.info("=== DocumentService.query Performance Summary ===")
734-
logger.info(f"Total query time: {total_time:.2f}s")
735-
for phase, duration in sorted(phase_times.items(), key=lambda x: x[1], reverse=True):
736-
percentage = (duration / total_time) * 100 if total_time > 0 else 0
737-
logger.info(f" - {phase}: {duration:.2f}s ({percentage:.1f}%)")
738-
logger.info(f"Generated completion with {len(sources)} sources")
739-
logger.info("================================================")
740-
741-
return response
729+
# Handle streaming vs non-streaming responses
730+
if stream_response:
731+
# For streaming responses, return the async generator and sources separately
732+
733+
# Log performance summary for streaming calls
734+
if local_perf:
735+
total_time = time.time() - query_start_time
736+
logger.info("=== DocumentService.query Performance Summary (Streaming) ===")
737+
logger.info(f"Total setup time: {total_time:.2f}s")
738+
for phase, duration in sorted(phase_times.items(), key=lambda x: x[1], reverse=True):
739+
percentage = (duration / total_time) * 100 if total_time > 0 else 0
740+
logger.info(f" - {phase}: {duration:.2f}s ({percentage:.1f}%)")
741+
logger.info(f"Starting streaming with {len(sources)} sources")
742+
logger.info("=" * 59)
743+
744+
return response, sources
745+
else:
746+
# Add sources information at the document service level for non-streaming
747+
response.sources = sources
748+
749+
# Log performance summary only for standalone calls
750+
if local_perf:
751+
total_time = time.time() - query_start_time
752+
logger.info("=== DocumentService.query Performance Summary ===")
753+
logger.info(f"Total query time: {total_time:.2f}s")
754+
for phase, duration in sorted(phase_times.items(), key=lambda x: x[1], reverse=True):
755+
percentage = (duration / total_time) * 100 if total_time > 0 else 0
756+
logger.info(f" - {phase}: {duration:.2f}s ({percentage:.1f}%)")
757+
logger.info(f"Generated completion with {len(sources)} sources")
758+
logger.info("================================================")
759+
760+
return response
742761

743762
async def ingest_text(
744763
self,

0 commit comments

Comments
 (0)