-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathmain.py
More file actions
159 lines (136 loc) · 5.3 KB
/
Copy pathmain.py
File metadata and controls
159 lines (136 loc) · 5.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import asyncio
import logging
from pathlib import Path
import pandas as pd
import uvloop
from httpx import AsyncClient
from ferry.args_parser import Args, get_args, parse_seasons_arg
from ferry.crawler.cache import load_cache_json
from ferry.crawler.cas_request import USER_AGENT
from ferry.crawler.classes import crawl_classes
from ferry.crawler.evals import crawl_evals
from ferry.crawler.seasons import fetch_seasons
from ferry.database import sync_db_courses, sync_db_courses_old, sync_db_evals
from ferry.summarize import DEFAULT_MODEL, summarize_evals
from ferry.transform import transform, write_csvs
# Install uvloop's event loop policy so the asyncio.run(main()) call at the
# bottom of this file runs on a uvloop loop.
# NOTE(review): asyncio.set_event_loop_policy is deprecated as of Python 3.14;
# on 3.12+ consider asyncio.run(main(), loop_factory=uvloop.new_event_loop)
# or uvloop.run(main()) instead — confirm the project's target Python version.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
# Cap how much of a DataFrame pandas renders when one is printed, so large
# DataFrames produce bounded output. set_option accepts pattern/value pairs
# and applies them in the order given.
pd.set_option(
    "display.max_columns", 20,
    "display.max_rows", 100,
    "display.max_colwidth", 50,
    "display.width", 120,
)
def _count_located_meetings(classes) -> int:
    """Count meetings, across every season and course, that have a truthy ``location``.

    NOTE(review): assumes ``classes`` maps season -> iterable of course dicts,
    each with an optional ``"meetings"`` list of dicts; this shape is inferred
    from how the crawl result is iterated in ``start_crawl`` — confirm against
    ``crawl_classes``.
    """
    return sum(
        1
        for season_courses in classes.values()
        for course in season_courses
        for meeting in course.get("meetings", [])
        if meeting.get("location")
    )


async def start_crawl(args: Args) -> list[str]:
    """Run the crawl stages and return the resolved list of seasons.

    Stages, each gated by a flag on ``args``:

    1. Seasons: fetched when ``crawl_seasons`` is set, otherwise loaded from
       ``<data_dir>/course_seasons.json``; then narrowed by ``args.seasons``.
    2. Classes (``crawl_classes``): optionally using YCS personalization
       tokens (``ycs_pers``, parsed as JSON) and cookie (``ycs_cookie``).
    3. Evals (``crawl_evals``): the evals crawl builds its own client.

    If classes or evals were crawled, the resolved seasons are written
    comma-separated to ``<data_dir>/ferry_updated_seasons.txt``.

    Raises:
        RuntimeError: if both ``ycs_cookie`` and ``ycs_pers`` were provided
            but no crawled meeting has a location (likely an auth failure).
    """
    classes = None

    # This HTTPX client is used for fetching seasons and classes (the evals
    # crawl initializes its own client with CAS auth). ``async with`` ensures
    # it is closed even when a stage raises, e.g. the location-validation
    # RuntimeError below.
    async with AsyncClient(
        timeout=None, headers={"User-Agent": USER_AGENT}
    ) as client:
        if args.crawl_seasons:
            course_seasons = await fetch_seasons(
                data_dir=args.data_dir, client=client, use_cache=args.use_cache
            )
        else:
            # Still try to load from cache if it exists
            course_seasons = load_cache_json(args.data_dir / "course_seasons.json")

        seasons = parse_seasons_arg(
            arg_seasons=args.seasons, all_viable_seasons=course_seasons
        )
        print("-" * 80)

        if args.crawl_classes:
            ycs_pers = None
            if args.ycs_pers:
                import ujson

                try:
                    ycs_pers = ujson.loads(args.ycs_pers)
                except Exception as e:
                    # Best-effort: continue without personalization tokens.
                    # If a YCS cookie was also given, the location check
                    # below may then fail.
                    print(f"Error parsing YCS personalization tokens: {e}")

            classes = await crawl_classes(
                seasons=seasons,
                data_dir=args.data_dir,
                cws_api_key=args.cws_api_key,
                client=client,
                use_cache=args.use_cache,
                pers=ycs_pers,
                cookie_header=args.ycs_cookie,
            )

            # Validate that locations were fetched if credentials were provided
            if args.ycs_cookie and args.ycs_pers:
                location_count = _count_located_meetings(classes)
                print(f"Found {location_count} meetings with locations")
                if location_count == 0:
                    raise RuntimeError(
                        "YCS credentials were provided but no locations were fetched. "
                        "This likely means the authentication failed or the credentials are invalid."
                    )

        if args.crawl_evals:
            await crawl_evals(
                cas_cookie=args.cas_cookie,
                seasons=seasons,
                data_dir=args.data_dir,
                courses=classes,
            )

        # Track seasons updated during crawl for catalog refresh endpoint
        if args.crawl_classes or args.crawl_evals:
            updated_seasons_path = args.data_dir / "ferry_updated_seasons.txt"
            updated_seasons_path.write_text(",".join(seasons))

    print("-" * 80)
    return seasons
def _require_tables(tables, flag: str):
    """Return ``tables``, raising if the transform step produced none.

    Used instead of a bare ``assert``. Asserts are stripped under
    ``python -O``, which would let ``None`` flow into the CSV writer and
    DB sync functions.

    Raises:
        ValueError: if ``tables`` is ``None`` or empty.
    """
    if not tables:
        raise ValueError(
            f"{flag} requires non-empty transformed tables "
            "(the transform step must run first)"
        )
    return tables


async def main() -> None:
    """Parse CLI arguments and run each enabled pipeline stage in order.

    Stages: crawl -> transform -> CSV snapshot -> DB sync (courses, evals)
    -> eval summarization -> DB diagram generation.

    Raises:
        ValueError: if ``--summarize-evals`` is requested without an OpenAI
            API key, or a table-consuming stage runs without transformed
            tables.
    """
    args = get_args()

    if args.debug:
        logging.basicConfig(level=logging.DEBUG)

    args.data_dir.mkdir(parents=True, exist_ok=True)

    if args.release:
        import sentry_sdk

        sentry_sdk.init(
            args.sentry_url,
            # Set traces_sample_rate to 1.0 to capture 100%
            # of transactions for performance monitoring.
            # We recommend adjusting this value in production.
            traces_sample_rate=1.0,
        )
    else:
        print("Running in dev mode. Sentry not initialized.")

    # Validate configuration up front, so a missing key fails before the
    # (potentially long) crawl, transform and DB sync stages run rather than
    # after them.
    if args.summarize_evals and not args.openai_api_key:
        raise ValueError("API key is required for --summarize-evals")

    seasons = await start_crawl(args)

    tables = None
    if args.transform:
        tables = await transform(data_dir=args.data_dir)

    if args.snapshot_tables:
        write_csvs(
            _require_tables(tables, "--snapshot-tables"), data_dir=args.data_dir
        )

    if args.sync_db_courses:
        course_tables = _require_tables(tables, "--sync-db-courses")
        if args.rewrite:
            sync_db_courses_old(course_tables, args.database_connect_string)
        else:
            sync_db_courses(
                course_tables,
                args.database_connect_string,
                data_dir=args.data_dir,
                freeze_locations=args.freeze_locations,
            )

    if args.sync_db_evals:
        sync_db_evals(
            _require_tables(tables, "--sync-db-evals"),
            args.database_connect_string,
        )

    if args.summarize_evals:
        await summarize_evals(
            seasons=seasons,
            data_dir=args.data_dir,
            api_key=args.openai_api_key,
            model=args.llm_model or DEFAULT_MODEL,
            base_url=args.llm_base_url,
            max_courses_per_season=args.max_courses,
        )

    if args.generate_diagram:
        from ferry.generate_db_diagram import generate_db_diagram

        # Relative path: resolved against the current working directory.
        generate_db_diagram(path=Path("docs/db_diagram.pdf"))
# Script entry point: run the async pipeline. asyncio.run uses the uvloop
# event loop policy installed near the top of this file.
if __name__ == "__main__":
    asyncio.run(main())