-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
94 lines (76 loc) · 2.61 KB
/
Copy pathmain.py
File metadata and controls
94 lines (76 loc) · 2.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from src.graph.graph import build_graph
from src.memory.short_term_memory import (
merge_short_term_into_state,
short_term_after_turn,
)
from src.memory.build_paper_memory import build_paper_memory
import json
import sys
# Process-wide lazily-initialised singletons. Both start unset and are filled
# in on first use by get_graph_app() / get_paper_vectordb() respectively.
_graph_app = _vectordb = None
def get_graph_app():
    """Return the compiled agent graph, building it only on the first call.

    The object produced by ``build_graph()`` is stored in the module-level
    ``_graph_app`` so every later call hands back that same instance.
    """
    global _graph_app
    if _graph_app is not None:
        # Already built on an earlier call; reuse it.
        return _graph_app
    _graph_app = build_graph()
    return _graph_app
def get_paper_vectordb():
    """Return the paper vector store, opening it only on the first call.

    The store comes from ``build_paper_memory`` with a fixed persistence
    directory (``"data/chroma"``) and collection name (``"paper_memory"``),
    and is cached in the module-level ``_vectordb`` for later calls.
    """
    global _vectordb
    if _vectordb is None:
        # NOTE(review): "data/chroma" is a relative path, so it is presumably
        # resolved against the current working directory. Confirm inside
        # build_paper_memory.
        store_settings = {
            "persist_dir": "data/chroma",
            "collection_name": "paper_memory",
        }
        _vectordb = build_paper_memory(**store_settings)
    return _vectordb
def run_workflow(query: str, short_term: dict | None = None, debug: bool = False):
    """Run a single conversational turn through the agent graph.

    Args:
        query: The user's message for this turn.
        short_term: Short-term memory produced by a previous turn, or ``None``
            for a fresh conversation. It is merged into the initial state via
            ``merge_short_term_into_state`` so the router and executor can use
            it (messages, last_papers, etc.).
        debug: When true, print the full final state to stdout after the run.

    Returns:
        The full final graph state as returned by ``app.invoke``. Read
        ``result["final_answer"]`` for the reply.
    """
    app = get_graph_app()
    paper_vectordb = get_paper_vectordb()
    state_input = {
        "query": query,
        "paper_vectordb": paper_vectordb,
        # A new list is built on every call. Prior conversation context is
        # presumably supplied by merge_short_term_into_state below (confirm).
        "messages": [],
    }
    state_input = merge_short_term_into_state(state_input, short_term)
    result = app.invoke(state_input)
    if debug:
        print("=== DEBUG: full agent state ===")
        try:
            # default=str already stringifies any value json cannot encode
            # natively (e.g. a vector-store handle, if present in the state).
            # What can still fail is a non-str/int/float/bool/None dict key
            # (TypeError) or a circular reference (ValueError).
            dumped = json.dumps(result, indent=2, ensure_ascii=False, default=str)
        except (TypeError, ValueError):
            # Best-effort debug output: fall back to the plain string form
            # instead of crashing the turn over a printout.
            dumped = str(result)
        print(dumped)
    return result
def run_workflow_with_short_term(
    query: str, short_term: dict | None = None, debug: bool = False
):
    """Run one turn and also hand back the short-term memory for the next one.

    ``query``, ``short_term`` and ``debug`` are forwarded unchanged to
    ``run_workflow``. When ``debug`` is true, that call prints the full final
    state.

    Returns:
        A ``(final_answer, updated_short_term)`` pair. ``final_answer`` is
        ``""`` if the final state has no ``"final_answer"`` entry. Pass
        ``updated_short_term`` back in as ``short_term`` on the next call to
        continue a multi-turn conversation.
    """
    state = run_workflow(query, short_term=short_term, debug=debug)
    # Read the reply before short_term_after_turn sees the state, keeping the
    # original evaluation order.
    reply = state.get("final_answer", "")
    return reply, short_term_after_turn(state, query)
# CLI
def main() -> None:
    """Run an interactive multi-turn chat loop on stdin/stdout.

    Pass ``--debug`` on the command line to print the full agent state after
    every turn. Type ``exit``, send EOF (Ctrl-D), or press Ctrl-C to quit.
    Blank input lines are ignored.
    """
    # Build the graph and open the vector store up front, so their set-up cost
    # is paid before the first prompt rather than during the first turn. The
    # results are cached by the getters, so nothing needs to be kept here.
    get_graph_app()
    get_paper_vectordb()

    # Simple CLI flag: `python main.py --debug` (program name excluded).
    debug = "--debug" in sys.argv[1:]
    if debug:
        print("[debug] State printing is enabled (--debug).")

    short_term = None
    while True:
        try:
            query = input("You: ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / end of piped input, or Ctrl-C: exit cleanly instead of
            # dumping a traceback. The newline keeps the shell prompt tidy.
            print()
            break
        command = query.strip()
        if command == "exit":
            break
        if not command:
            # Nothing to ask; don't spend a full graph run on a blank line.
            continue
        answer, short_term = run_workflow_with_short_term(
            query, short_term, debug=debug
        )
        print(answer)


if __name__ == "__main__":
    main()