Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 61 additions & 7 deletions chat/src/agent/search_agent.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
from typing import Literal, List

from agent.tools import aggregate, discover_fields, search
from langchain_core.messages import HumanMessage
from langchain_core.messages import HumanMessage, ToolMessage
from langchain_core.messages.base import BaseMessage
from langchain_core.language_models.chat_models import BaseModel
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages.system import SystemMessage
from langgraph.graph import END, START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode
from langgraph.errors import GraphRecursionError
from core.setup import checkpoint_saver
from agent.callbacks.socket import SocketCallbackHandler

DEFAULT_SYSTEM_MESSAGE = """
Please provide a brief answer to the question using the tools provided. Include specific details from multiple documents that
support your answer. Answer in raw markdown, but not within a code block. When citing source documents, construct Markdown
links using the document's canonical_link field. Do not include intermediate messages explaining your process.
links using the document's canonical_link field. Do not include intermediate messages explaining your process. If the user's
question is unclear, ask for clarification.
"""

MAX_RECURSION_LIMIT = 10

class SearchWorkflow:
def __init__(self, model: BaseModel, system_message: str):
self.model = model
Expand Down Expand Up @@ -77,9 +82,58 @@ def __init__(
def invoke(self, question: str, ref: str, *, callbacks: List[BaseCallbackHandler] = [], forget: bool = False, **kwargs):
if forget:
self.checkpointer.delete_checkpoints(ref)

try:
return self.search_agent.invoke(
{"messages": [HumanMessage(content=question)]},
config={
"configurable": {"thread_id": ref},
"callbacks": callbacks,
"recursion_limit": MAX_RECURSION_LIMIT,
},
**kwargs
)
except GraphRecursionError as e:
print(f"Recursion error: {e}")

# Retrieve the messages processed so far
checkpoint_tuple = self.checkpointer.get_tuple({"configurable": {"thread_id": ref}})
state = checkpoint_tuple.checkpoint if checkpoint_tuple else None
messages = state.get("channel_values", {}).get("messages", []) if state else []

# Extract relevant responses including tool outputs
responses = []
for msg in messages:
if isinstance(msg, (BaseMessage, ToolMessage)):
responses.append(msg.content)

if responses:
# Summarize the responses so far
summary_prompt = f"""
The following is what I have discovered so far based on multiple sources.
Summarize the key points concisely for the user:

{responses[-5:]} # Take the last few responses
"""

# Generate a summary using the LLM
summary = self.workflow_logic.model.invoke([HumanMessage(content=summary_prompt)])
summary_text = summary.content

# Send summary as an "answer" message before finalizing
for cb in callbacks:
if isinstance(cb, SocketCallbackHandler):
cb.socket.send({"type": "answer", "ref": ref, "message": summary_text})

else:
# Send a fallback message
fallback_message = "I reached my recursion limit but couldn't retrieve enough useful information."
for cb in callbacks:
if isinstance(cb, SocketCallbackHandler):
cb.socket.send({"type": "answer", "ref": ref, "message": fallback_message})


return self.search_agent.invoke(
{"messages": [HumanMessage(content=question)]},
config={"configurable": {"thread_id": ref}, "callbacks": callbacks},
**kwargs
)
for cb in callbacks:
if hasattr(cb, "on_agent_finish"):
cb.on_agent_finish(finish=None, run_id=ref, **kwargs)
return {"type": "final", "ref": ref, "message": "Finished"}