-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtarget_builder.py
More file actions
347 lines (288 loc) · 11.5 KB
/
Copy pathtarget_builder.py
File metadata and controls
347 lines (288 loc) · 11.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
"""Target vector generation for RL-NPO.
Generates empirical target vectors by:
1. Using an LLM to generate neuroscience-grounded reference text for each ROI
2. Running TRIBE v2 on the reference text to get cortical predictions
3. Time-averaging predictions and saving as .npy files
Runs once per ROI, offline. Cached results are never regenerated during the RL loop.
"""
import os
import sys
import tempfile
import json
import logging
import numpy as np
from dotenv import load_dotenv
from openai import OpenAI
logger = logging.getLogger(__name__)
# Load variables from a local .env file (if present) into the process
# environment, so OPENROUTER_KEY can live in .env instead of being exported.
load_dotenv()
# ─────────────────────────────────────────────────────────────────────
# Reference text generation prompts (neuroscience-grounded)
# ─────────────────────────────────────────────────────────────────────
# Per-ROI instruction sent as the user message when generating reference
# text (see generate_reference_text). The keys are the ROI names accepted by
# this module, and build_all_targets iterates them in this order. Each prompt
# asks for ~90 seconds of text with properties the authors associate with the
# named brain system.
REFERENCE_PROMPTS: dict[str, str] = {
    # Episodic memory: hippocampal / parahippocampal encoding.
    "memory": (
        "Write a 90-second vivid episodic narrative with clear temporal "
        "structure, scene boundaries, and spatial context. Use first-person. "
        "Include specific sensory details — textures, temperatures, smells — "
        "anchored in a concrete place and time. These properties are known "
        "from neuroscience to maximally activate the hippocampal-"
        "parahippocampal memory encoding system."
    ),
    # High-arousal emotion: amygdala / anterior cingulate.
    "emotion": (
        "Write a 90-second emotionally charged monologue with high arousal "
        "content — urgency, stakes, interpersonal tension. Use concrete "
        "emotional language: fear, loss, desperate hope, the physical "
        "sensations of adrenaline. These properties activate amygdala and "
        "anterior cingulate emotional processing circuits."
    ),
    # Working-memory load: DLPFC / IPS attention networks.
    "attention": (
        "Write 90 seconds of dense, technically complex instructional "
        "content requiring active working memory. Use numbered steps, "
        "conditional logic, and backward references (e.g., 'recall from "
        "step 2'). These properties activate DLPFC and IPS attentional "
        "networks."
    ),
    # Syntactic complexity: Broca's / Wernicke's language areas.
    "language": (
        "Write 90 seconds of syntactically complex prose with deeply "
        "embedded clauses, passive constructions, garden-path sentences, "
        "and semantic ambiguity that resolves late in the sentence. These "
        "properties activate Broca's area and Wernicke's area language "
        "processing networks."
    ),
    # Narrative engagement: intended to suppress default mode network activity.
    "narrative": (
        "Write a 90-second gripping story opening that creates narrative "
        "tension and suppresses mind-wandering. Use vivid scene-setting, "
        "an immediate conflict, and a forward-pulling unanswered question. "
        "These properties suppress default mode network activation (the "
        "'locked in' state)."
    ),
}
def _get_openrouter_client() -> OpenAI:
    """Build an OpenAI-compatible client pointed at the OpenRouter API.

    The API key is read from the ``OPENROUTER_KEY`` environment variable.

    Raises
    ------
    EnvironmentError
        If ``OPENROUTER_KEY`` is unset or empty.
    """
    key = os.getenv("OPENROUTER_KEY")
    if key:
        return OpenAI(
            base_url="https://openrouter.ai/api/v1",
            api_key=key,
        )
    raise EnvironmentError(
        "OPENROUTER_KEY not found in environment. "
        "Set it in .env or export it."
    )
def generate_reference_text(
    roi: str,
    model: str = "anthropic/claude-3.5-haiku",
    *,
    temperature: float = 0.9,
    max_tokens: int = 2000,
) -> str:
    """Use an LLM to generate neuroscience-grounded reference text for an ROI.

    The prompt for ``roi`` is taken from ``REFERENCE_PROMPTS`` and sent via
    OpenRouter with a system message asking for the raw text only.

    Parameters
    ----------
    roi : str
        Target ROI name; must be a key of ``REFERENCE_PROMPTS``
        (memory, emotion, attention, language, narrative).
    model : str
        OpenRouter model identifier.
    temperature : float
        Sampling temperature forwarded to the completion request.
    max_tokens : int
        Maximum number of tokens the completion may generate.

    Returns
    -------
    str
        Generated reference text with surrounding whitespace stripped.

    Raises
    ------
    KeyError
        If ``roi`` has no entry in ``REFERENCE_PROMPTS``.
    RuntimeError
        If the response has no choices or no (non-blank) text content.
    """
    if roi not in REFERENCE_PROMPTS:
        raise KeyError(f"Unknown ROI '{roi}'. Available: {list(REFERENCE_PROMPTS.keys())}")

    client = _get_openrouter_client()
    prompt = REFERENCE_PROMPTS[roi]
    response = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "system",
                "content": (
                    "You are an expert writer. Produce exactly the text requested. "
                    "No preamble, no commentary, no meta-discussion. "
                    "Just the raw text, nothing else."
                ),
            },
            {"role": "user", "content": prompt},
        ],
        temperature=temperature,
        max_tokens=max_tokens,
    )

    # The SDK types ``message.content`` as optional; a missing or blank reply
    # must not flow silently into TRIBE v2 as an empty reference text.
    choices = response.choices or []
    content = choices[0].message.content if choices else None
    text = (content or "").strip()
    if not text:
        raise RuntimeError(
            f"Model '{model}' returned no text for ROI '{roi}'."
        )

    logger.info("Generated reference text for '%s': %d chars", roi, len(text))
    return text
def generate_target_vector(
    roi: str,
    reference_text: str,
    output_dir: str = "targets",
    cache_folder: str = "./cache",
    device: str = "auto",
) -> np.ndarray:
    """Run TRIBE v2 text-only inference on reference text to produce a target vector.

    The text is written to a temporary ``.txt`` file (TRIBE v2 takes a file
    path) and converted to an events dataframe. That dataframe is run through
    the model, and the predictions are averaged over their first (time) axis.
    The result is saved to ``<output_dir>/<roi>.npy``. The temporary file is
    always removed. The model reference is dropped before the CUDA cache is
    emptied, so repeated calls (one per ROI) do not keep model memory alive.

    Parameters
    ----------
    roi : str
        Target ROI name; also used as the output file stem.
    reference_text : str
        The reference text to process through TRIBE v2.
    output_dir : str
        Directory to save the .npy target vector (created if missing).
    cache_folder : str
        TRIBE v2 cache directory for feature extraction.
    device : str
        Device for inference: "auto" (CUDA if available), "cuda", or "cpu".

    Returns
    -------
    np.ndarray
        Time-averaged cortical prediction of shape (20484,).

    Raises
    ------
    ValueError
        If the prediction contains no timesteps. Averaging an empty
        prediction would otherwise yield a NaN-filled target silently.
    """
    import gc

    import torch
    from tribev2.tribev2 import TribeModel

    os.makedirs(output_dir, exist_ok=True)

    tmp_path = None
    model = None
    try:
        # TRIBE v2 requires file path input. delete=False keeps the file on
        # disk after the handle closes so TRIBE can reopen it by name; it is
        # removed in ``finally``. Creating it inside ``try`` ensures a failed
        # write does not leak the file.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".txt", delete=False, encoding="utf-8"
        ) as f:
            tmp_path = f.name
            f.write(reference_text)
            f.flush()
            os.fsync(f.fileno())

        logger.info("Loading TRIBE v2 model for target generation...")
        model = TribeModel.from_pretrained(
            "facebook/tribev2", cache_folder=cache_folder, device=device
        )

        logger.info("Building events dataframe from text...")
        df = model.get_events_dataframe(text_path=tmp_path)

        logger.info("Running prediction...")
        preds, _ = model.predict(events=df)  # shape: (n_timesteps, 20484)
        if len(preds) == 0:
            raise ValueError(
                f"TRIBE v2 produced no timesteps for ROI '{roi}'; "
                "cannot build a target vector from an empty prediction."
            )

        # Time-average to get a stable activation profile.
        target_vec = preds.mean(axis=0)  # shape: (20484,)

        out_path = os.path.join(output_dir, f"{roi}.npy")
        np.save(out_path, target_vec)
        logger.info(
            "Saved target vector for '%s': shape=%s -> %s",
            roi,
            target_vec.shape,
            out_path,
        )
        return target_vec
    finally:
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)
        # Free GPU memory between targets to prevent OOM. Drop the model
        # reference and collect *before* empty_cache(): the CUDA caching
        # allocator can only release blocks that are no longer referenced.
        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
def build_all_targets(
    output_dir: str = "targets",
    cache_folder: str = "./cache",
    openrouter_model: str = "anthropic/claude-3.5-haiku",
    reference_texts_dir: str = "targets/reference_texts",
    device: str = "auto",
) -> dict[str, np.ndarray]:
    """Generate target vectors for all predefined ROIs.

    Each key of ``REFERENCE_PROMPTS`` is handled in one of two ways:

    - If ``<output_dir>/<roi>.npy`` already exists, it is loaded and reused.
    - Otherwise, reference text is generated with the LLM and written to
      ``<reference_texts_dir>/<roi>.txt`` for inspection. It is then passed
      through TRIBE v2 to produce and save the target vector.

    Parameters
    ----------
    output_dir : str
        Directory to save .npy files.
    cache_folder : str
        TRIBE v2 cache folder.
    openrouter_model : str
        LLM model for reference text generation.
    reference_texts_dir : str
        Directory to save generated reference texts for inspection.
    device : str
        Device for TRIBE v2 inference: "auto" (CUDA if available), "cuda",
        or "cpu". Forwarded to ``generate_target_vector``.

    Returns
    -------
    dict[str, np.ndarray]
        Mapping from ROI name to target vector.
    """
    os.makedirs(reference_texts_dir, exist_ok=True)
    results: dict[str, np.ndarray] = {}

    for roi in REFERENCE_PROMPTS:
        out_path = os.path.join(output_dir, f"{roi}.npy")

        # Skip if already generated: cached targets are reused, not rebuilt.
        if os.path.exists(out_path):
            logger.info(
                "Target vector for '%s' already exists at %s, skipping.", roi, out_path
            )
            results[roi] = np.load(out_path)
            continue

        print(f"\n{'=' * 60}")
        print(f" Generating target vector: {roi}")
        print(f"{'=' * 60}")

        # Step 1: Generate reference text via LLM.
        print(f" [1/3] Generating reference text via {openrouter_model}...")
        ref_text = generate_reference_text(roi, model=openrouter_model)

        # Save reference text for inspection.
        ref_path = os.path.join(reference_texts_dir, f"{roi}.txt")
        with open(ref_path, "w", encoding="utf-8") as f:
            f.write(ref_text)
        print(f" [1/3] Reference text saved to {ref_path} ({len(ref_text)} chars)")

        # Steps 2-3: Run TRIBE v2 and save.
        print(" [2/3] Running TRIBE v2 inference (this may take a few minutes on CPU)...")
        target_vec = generate_target_vector(
            roi,
            ref_text,
            output_dir=output_dir,
            cache_folder=cache_folder,
            device=device,
        )
        print(f" [3/3] Target vector saved: shape={target_vec.shape}")
        results[roi] = target_vec

    return results
# ─────────────────────────────────────────────────────────────────────
# Custom target generation from arbitrary files
# ─────────────────────────────────────────────────────────────────────
# File extensions mapped to TRIBE v2 input types
_EXT_TO_TYPE = {
".txt": "text",
}
def generate_custom_target(
    input_path: str,
    output_path: str = "targets/custom.npy",
    cache_folder: str = "./cache",
    name: str = "custom",
    device: str = "auto",
) -> np.ndarray:
    """Generate a target vector from a custom text file.

    The input is validated before any heavy imports or model loading, so a
    bad path fails fast. The TRIBE v2 model is always released, even if
    inference raises.

    Parameters
    ----------
    input_path : str
        Path to the input text file (.txt).
    output_path : str
        Where to save the resulting .npy target vector (parent directory is
        created if missing).
    cache_folder : str
        TRIBE v2 feature cache directory.
    name : str
        Name for logging purposes (printed at the start of the run).
    device : str
        Device for inference: "auto" (CUDA if available), "cuda", or "cpu".

    Returns
    -------
    np.ndarray
        Time-averaged cortical prediction of shape (20484,).

    Raises
    ------
    ValueError
        If the file extension is not in ``_EXT_TO_TYPE``, or if the prediction
        contains no timesteps.
    FileNotFoundError
        If ``input_path`` does not point to an existing file.
    """
    from pathlib import Path

    # Cheap validation first: no point importing torch or loading the model
    # for an input we are going to reject.
    src = Path(input_path).resolve()
    ext = src.suffix.lower()
    if ext not in _EXT_TO_TYPE:
        raise ValueError(
            f"Unsupported file type '{ext}'. "
            f"RL-NPO is strictly text-only. Supported: {list(_EXT_TO_TYPE.keys())}"
        )
    if not src.is_file():
        raise FileNotFoundError(f"Input file not found: {src}")

    import gc

    import torch
    from tribev2.tribev2 import TribeModel

    # Report the device that will actually be used, honouring an explicit override.
    if device == "auto":
        use_gpu = torch.cuda.is_available()
    else:
        use_gpu = device.startswith("cuda")

    print(f" Target: {name}")
    print(f" Input type: {_EXT_TO_TYPE[ext]}")
    print(f" Device: {'GPU' if use_gpu else 'CPU'}")
    print(" Loading TRIBE v2 model...")

    model = None
    try:
        model = TribeModel.from_pretrained(
            "facebook/tribev2", cache_folder=cache_folder, device=device
        )

        print(f" Building events dataframe from {src.name}...")
        df = model.get_events_dataframe(text_path=str(src))

        print(" Running prediction...")
        preds, _ = model.predict(events=df)  # shape: (n_timesteps, 20484)
        if len(preds) == 0:
            raise ValueError(
                f"TRIBE v2 produced no timesteps for '{name}'; "
                "cannot build a target vector from an empty prediction."
            )

        # Time-average to get a stable activation profile.
        target_vec = preds.mean(axis=0)  # shape: (20484,)

        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
        np.save(output_path, target_vec)
        print(f" Saved: {output_path} (shape={target_vec.shape})")
        return target_vec
    finally:
        # Release the model even on failure. Drop the reference and collect
        # before empty_cache() so the CUDA allocator can return its blocks.
        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
print("Building all target vectors...")
build_all_targets()
print("\nDone. All target vectors saved to targets/")