Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,16 @@ env.json:
# Generate a per-developer SAM deploy config; DEV_PREFIX comes from the pattern stem ($*).
samconfig.%.yaml:
DEV_PREFIX=$* ./bin/make_deploy_config.sh
# Deploy the stack. Guard: refuse unless the caller's STS identity matches the
# AWSReservedSSO_AWSAdministratorAccess role (i.e. logged in as an SSO admin).
deploy: build samconfig.$(DEV_PREFIX).yaml
if ! aws sts get-caller-identity --query 'Arn' --output text | grep AWSReservedSSO_AWSAdministratorAccess > /dev/null; then \
echo "You must be logged in as an admin to deploy"; \
exit 1; \
fi
sam deploy --config-file samconfig.$(DEV_PREFIX).yaml --stack-name dc-api-$(DEV_PREFIX)
# Watch-mode sync of the stack (sam sync --watch); same admin guard as deploy.
sync: samconfig.$(DEV_PREFIX).yaml
if ! aws sts get-caller-identity --query 'Arn' --output text | grep AWSReservedSSO_AWSAdministratorAccess > /dev/null; then \
echo "You must be logged in as an admin to sync"; \
exit 1; \
fi
sam sync --config-file samconfig.$(DEV_PREFIX).yaml --stack-name dc-api-$(DEV_PREFIX) --watch $(ARGS)
# sync-code: same as sync but restricted to code-only updates (--code).
sync-code: ARGS=--code
sync-code: sync
Expand Down
133 changes: 99 additions & 34 deletions chat/src/agent/callbacks/metrics.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,109 @@
from datetime import datetime
from typing import Any, Dict
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from langchain_core.messages.tool import ToolMessage
import boto3
import json
import os

# NOTE(review): this span is the PRE-change (deleted) side of the diff view
# this file was scraped from; the updated MetricsCallbackHandler methods appear
# again below with log_stream/extra_data support. Kept verbatim as residue.
class MetricsCallbackHandler(BaseCallbackHandler):
def __init__(self, *args, **kwargs):
# accumulator: summed token-usage counts keyed by usage_metadata field
self.accumulator = {}
# answers: non-empty generation texts collected in on_llm_end
self.answers = []
# artifacts: tool outputs collected by the (not shown here) on_tool_end
self.artifacts = []
super().__init__(*args, **kwargs)

def on_llm_end(self, response: LLMResult, **kwargs: Dict[str, Any]):
# Guard: nothing to record without a response carrying generations.
if response is None:
return

if not response.generations or not response.generations[0]:
return

for generation in response.generations[0]:
if generation.text != "":
self.answers.append(generation.text)

# Only chat generations carry a message; skip plain text generations.
if not hasattr(generation, 'message') or generation.message is None:
continue

metadata = getattr(generation.message, 'usage_metadata', None)
if metadata is None:
continue

# Sum each usage counter (e.g. input/output token counts) across calls.
for k, v in metadata.items():
self.accumulator[k] = self.accumulator.get(k, 0) + v
def on_tool_end(self, output: ToolMessage, **kwargs: Dict[str, Any]):
def __init__(self, log_stream = None, *args, extra_data = None, **kwargs):
    """Collect LLM answers, tool artifacts, and token usage for metrics logging.

    Args:
        log_stream: CloudWatch log stream name to write metrics to; when
            None, log_metrics() is a no-op.
        extra_data: Optional dict merged into each metrics log message.
            Defaults to None (not ``{}``) to avoid the shared
            mutable-default-argument pitfall.
    """
    self.accumulator = {}   # summed usage_metadata token counts
    self.answers = []       # non-empty generation texts
    self.artifacts = []     # tool outputs (aggregations, source URLs)
    self.log_stream = log_stream
    # Fresh dict per instance; a literal {} default would be shared.
    self.extra_data = {} if extra_data is None else extra_data
    super().__init__(*args, **kwargs)

def on_llm_end(self, response: LLMResult, **kwargs: Dict[str, Any]):
    """Record generation texts and accumulate per-key token-usage metadata."""
    # Nothing to record without at least one generation batch.
    if response is None or not response.generations or not response.generations[0]:
        return

    for gen in response.generations[0]:
        if gen.text != "":
            self.answers.append(gen.text)

        # Only chat generations carry a message with usage metadata.
        message = getattr(gen, "message", None)
        if message is None:
            continue

        usage = getattr(message, "usage_metadata", None)
        if usage is None:
            continue

        # Sum each usage counter (input/output token counts, etc.) across calls.
        for key, count in usage.items():
            self.accumulator[key] = self.accumulator.get(key, 0) + count

def on_tool_end(self, output: ToolMessage, **kwargs: Dict[str, Any]):
content = output.content
if isinstance(content, str):
try:
content = json.loads(content)
except json.decoder.JSONDecodeError as e:
print(
f"Invalid json ({e}) returned from {output.name} tool: {output.content}"
)
return

match output.name:
case "aggregate":
self.artifacts.append({"type": "aggregation", "artifact": output.artifact.get("aggregation_result", {})})
self.artifacts.append(
{
"type": "aggregation",
"artifact": content.get("aggregation_result", {}),
}
)
case "search":
try:
source_urls = [doc.metadata["api_link"] for doc in output.artifact]
self.artifacts.append({"type": "source_urls", "artifact": source_urls})
except json.decoder.JSONDecodeError as e:
print(f"Invalid json ({e}) returned from {output.name} tool: {output.content}")
source_urls = [doc.get("api_link") for doc in content]
self.artifacts.append({"type": "source_urls", "artifact": source_urls})
case "summarize":
print(output)

def log_metrics(self):
    """Write accumulated answers, artifacts, and token counts to CloudWatch.

    No-op when no log stream was configured, when METRICS_LOG_GROUP is
    unset, or when the log stream cannot be created.
    """
    if self.log_stream is None:
        return

    log_group = os.getenv("METRICS_LOG_GROUP")
    if not log_group or not ensure_log_stream_exists(log_group, self.log_stream):
        return

    payload = {
        "answer": self.answers,
        "artifacts": self.artifacts,
        "token_counts": self.accumulator,
    }
    # Caller-supplied fields (e.g. request metadata) ride along in the message.
    payload.update(self.extra_data)

    log_client().put_log_events(
        logGroupName=log_group,
        logStreamName=self.log_stream,
        logEvents=[
            {
                "timestamp": timestamp(),
                "message": json.dumps(payload),
            }
        ],
    )


def ensure_log_stream_exists(log_group, log_stream):
    """Create the CloudWatch log stream if it does not already exist.

    Returns:
        True when the stream exists (just created or already present),
        False when creation failed for any other reason.
    """
    client = log_client()
    try:
        # create_log_stream returns only response metadata; the previous
        # debug print() of that response has been removed.
        client.create_log_stream(logGroupName=log_group, logStreamName=log_stream)
        return True
    except client.exceptions.ResourceAlreadyExistsException:
        return True
    except Exception as e:
        # Deliberately broad: metrics logging is best-effort and must never
        # break the request path. Include the error for diagnosability.
        print(f"Could not create log stream: {log_group}:{log_stream} ({e})")
        return False


def log_client():
    """Return a boto3 CloudWatch Logs client for AWS_REGION (default us-east-1)."""
    region = os.getenv("AWS_REGION", "us-east-1")
    return boto3.client("logs", region_name=region)


def timestamp():
    """Current time as integer milliseconds since the Unix epoch."""
    return round(datetime.now().timestamp() * 1000)
10 changes: 7 additions & 3 deletions chat/src/agent/callbacks/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,22 @@ def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Di
self.socket.send({"type": "tool_start", "ref": self.ref, "message": {"tool": serialized.get("name"), "input": input}})

def on_tool_end(self, output: ToolMessage, **kwargs: Dict[str, Any]):
content = output.content
if isinstance(content, str):
content = json.loads(content)

match output.name:
case "aggregate":
self.socket.send({"type": "aggregation_result", "ref": self.ref, "message": output.artifact.get("aggregation_result", {})})
self.socket.send({"type": "aggregation_result", "ref": self.ref, "message": content.get('aggregation_result', {})})
case "discover_fields":
pass
case "search":
result_fields = ("id", "title", "visibility", "work_type", "thumbnail")
docs: List[Dict[str, Any]] = [{k: doc.metadata.get(k) for k in result_fields} for doc in output.artifact]
docs: List[Dict[str, Any]] = [{k: doc.get(k) for k in result_fields} for doc in content]
self.socket.send({"type": "search_result", "ref": self.ref, "message": docs})
case "retrieve_documents":
result_fields = ("id", "title", "visibility", "work_type", "thumbnail")
docs: List[Dict[str, Any]] = [{k: doc.get(k) for k in result_fields} for doc in output.artifact]
docs: List[Dict[str, Any]] = [{k: doc.get(k) for k in result_fields} for doc in content]
self.socket.send({"type": "retrieved_documents", "ref": self.ref, "message": docs})
case _:
print(f"Unhandled tool_end message: {output}")
Expand Down
39 changes: 33 additions & 6 deletions chat/src/agent/search_agent.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from typing import Literal, List
from langchain_core.messages import HumanMessage, ToolMessage
from agent.tools import aggregate, discover_fields, search, retrieve_documents
Expand All @@ -8,21 +9,26 @@
from langgraph.graph import END, START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode
from langgraph.errors import GraphRecursionError
from core.document import minimize_documents
from core.setup import checkpoint_saver
from agent.callbacks.socket import SocketCallbackHandler
from typing import Optional
import time

# System prompt for the search agent: brief markdown answers, citations via
# canonical_link, and a graceful fallback after 6 tool calls.
# NOTE(review): the diff residue duplicated the deleted prompt line and the old
# MAX_RECURSION_LIMIT = 8 beside the new values; only the post-change constants remain.
DEFAULT_SYSTEM_MESSAGE = """
Please provide a brief answer to the question using the tools provided. Include specific details from multiple documents that
support your answer. Answer in raw markdown, but not within a code block. When citing source documents, construct Markdown
links using the document's canonical_link field. Do not include intermediate messages explaining your process. If the user's
question is unclear, ask for clarification. Use no more than 6 tool calls. If you still cannot answer the question after 6
tool calls, summarize the information you have gathered so far and suggest ways in which the user might narrow the scope
of their question to make it more answerable.
"""

# Hard cap on graph recursion, raised from 8 to accommodate the extra
# tools -> summarize -> agent hop per tool call.
MAX_RECURSION_LIMIT = 16

class SearchWorkflow:
def __init__(self, model: BaseModel, system_message: str):
def __init__(self, model: BaseModel, system_message: str, metrics = None):
self.metrics = metrics
self.model = model
self.system_message = system_message

Expand All @@ -35,6 +41,23 @@ def should_continue(self, state: MessagesState) -> Literal["tools", END]:
# Otherwise, we stop (reply to the user)
return END

def summarize(self, state: MessagesState):
    """Condense the latest search/retrieve tool message in place.

    Minimizes the document payload of the last message (when it came from
    the search or retrieve_documents tool) to shrink the context sent back
    to the model; all other messages pass through untouched.
    """
    messages = state["messages"]
    tail = messages[-1]
    if tail.name not in ["search", "retrieve_documents"]:
        return {"messages": messages}

    started = time.time()
    condensed = json.dumps(minimize_documents(json.loads(tail.content)), separators=(',', ':'))
    elapsed = time.time() - started
    # tail.content is still the original payload here, so the byte counts compare old vs new.
    print(f'Condensed {len(tail.content)} bytes to {len(condensed)} bytes in {elapsed:.2f} seconds. Savings: {100 * (1 - len(condensed) / len(tail.content)):.2f}%')

    tail.content = condensed

    return {"messages": messages}

def call_model(self, state: MessagesState):
messages = [SystemMessage(content=self.system_message)] + state["messages"]
response: BaseMessage = self.model.invoke(messages)
Expand All @@ -46,6 +69,7 @@ def __init__(
self,
model: BaseModel,
*,
metrics = None,
system_message: str = DEFAULT_SYSTEM_MESSAGE,
**kwargs
):
Expand All @@ -57,23 +81,26 @@ def __init__(
except NotImplementedError:
pass

self.workflow_logic = SearchWorkflow(model=model, system_message=system_message)
self.workflow_logic = SearchWorkflow(model=model, system_message=system_message, metrics=metrics)

# Define a new graph
workflow = StateGraph(MessagesState)

# Define the two nodes we will cycle between
workflow.add_node("agent", self.workflow_logic.call_model)
workflow.add_node("tools", tool_node)

workflow.add_node("summarize", self.workflow_logic.summarize)

# Set the entrypoint as `agent`
workflow.add_edge(START, "agent")

# Add a conditional edge
workflow.add_conditional_edges("agent", self.workflow_logic.should_continue)

# Add a normal edge from `tools` to `agent`
workflow.add_edge("tools", "agent")
#workflow.add_edge("tools", "agent")
workflow.add_edge("tools", "summarize")
workflow.add_edge("summarize", "agent")

self.checkpointer = checkpoint_saver()
self.search_agent = workflow.compile(checkpointer=self.checkpointer)
Expand Down
41 changes: 24 additions & 17 deletions chat/src/agent/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,19 @@ def get_keyword_fields(properties, prefix=''):
keyword_fields.extend(get_keyword_fields(nested_properties, prefix=current_path + '.'))
return keyword_fields

@tool(response_format="content_and_artifact")
def filter_results(results):
    """
    Strip the (large) embedding vector from each result's metadata and
    return the list of metadata dicts.

    Note: mutates each result's metadata dict in place.
    """
    stripped = []
    for item in results:
        metadata = item.metadata
        metadata.pop('embedding', None)
        stripped.append(metadata)
    return stripped

@tool(response_format="content")
def discover_fields():
"""
Discover the fields available in the OpenSearch index. This tool is useful for understanding the structure of the index and the fields available for aggregation queries.
Expand All @@ -32,15 +44,15 @@ def discover_fields():
fields = opensearch.client.indices.get_mapping(index=opensearch.index)
top_properties = list(fields.values())[0]['mappings']['properties']
result = get_keyword_fields(top_properties)
return json.dumps(result, default=str), result
return result

# NOTE(review): the diff residue kept both the old content_and_artifact
# decorator and the old tuple return beside the new ones; only the
# post-change content-only form remains.
@tool(response_format="content")
def search(query: str):
    """Perform a semantic search of Northwestern University Library digital collections. When answering a search query, ground your answer in the context of the results with references to the document's metadata."""
    # size=20 caps the result set to keep the tool message small for the model.
    query_results = opensearch_vector_store().similarity_search(query, size=20)
    # filter_results strips bulky embedding vectors before serialization.
    return filter_results(query_results)

@tool(response_format="content_and_artifact")
@tool(response_format="content")
def aggregate(agg_field: str, term_field: str, term: str):
"""
Perform a quantitative aggregation on the OpenSearch index. Use this tool for quantitative questions like "How many...?" or "What are the most common...?"
Expand All @@ -61,17 +73,18 @@ def aggregate(agg_field: str, term_field: str, term: str):
"""
try:
response = opensearch_vector_store().aggregations_search(agg_field, term_field, term)
return json.dumps(response, default=str), response
return response
except Exception as e:
return json.dumps({"error": str(e)}), None
return json.dumps({"error": str(e)})

@tool(response_format="content_and_artifact")
@tool(response_format="content")
def retrieve_documents(doc_ids: List[str]):
"""
Retrieve documents from the OpenSearch index based on a list of document IDs.

Use this instead of the search tool if the user has provided docs for context
and you need the full metadata.
and you need the full metadata, or if you're working with output from another
tool that only contains document IDs.
Provide an answer to their question based on the metadata of the documents.


Expand All @@ -84,12 +97,6 @@ def retrieve_documents(doc_ids: List[str]):

try:
response = opensearch_vector_store().retrieve_documents(doc_ids)
documents = []
for doc in response:
metadata = doc.metadata
if 'embedding' in metadata:
metadata.pop('embedding')
documents.append(metadata)
return json.dumps(documents, default=str), documents
return filter_results(response)
except Exception as e:
return json.dumps({"error": str(e)}), None
return {"error": str(e)}
Loading