main_qa.py

import os
import json  # json.load() is used below; not guaranteed to come in via "from util import *"
from pathlib import Path
from util import *
from eval import *
from dataset import get_dataset
from prompts import PromptFactory
from model import get_model
from tqdm import tqdm
from pprint import pprint
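
# Pipeline overview: parse CLI args, optionally load per-item tree-node indices,
# resume from an existing output file, build the dataset / prompter / model, run
# inference over every remaining item with periodic checkpointing, optionally
# back-fill failed predictions from a backup file, then evaluate.
#
# Example invocation (a sketch only; the exact flag names are defined by
# parse_args() in util.py and may differ):
#   python main_qa.py --dataset egoschema --task qa \
#       --output_base_path output --output_filename egoschema_qa.json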


def launch():
    args = parse_args()
    pprint(args)

    # Read the optional tree-node index JSON, if provided
    if args.tree_node_idx is not None:
        with open(args.tree_node_idx, 'r') as file:
            tree_node_idx = json.load(file)
        tree_node_idx_dict = {item['name']: item['sorted_values'] for item in tree_node_idx}
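        # Assumption: each entry looks like {"name": ..., "sorted_values": ...},
        # where sorted_values holds the pre-selected node/frame indices that are
        # later passed to the prompter as loc_pred.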

    # output
    makedir(args.output_base_path)
    output_path = os.path.join(args.output_base_path, args.output_filename)

    # resume
    processed = {}
    if not args.start_from_scratch and os.path.exists(output_path):
        processed = load_json(output_path)
        if 'data' in processed:
            processed = processed['data']

    # get input
    quids_to_exclude = set(list(processed.keys()))
    dataset = get_dataset(args, quids_to_exclude=quids_to_exclude, num_examples_to_run=args.num_examples_to_run)
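    # (quids_to_exclude keeps already-answered items out of this run, so resuming
    # only processes the examples that are still missing from the output file.)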

    # configure prompt
    prompter = PromptFactory().get(args.prompt_type)

    # get model
    model = get_model(args)
    model.set_post_process_fn(prompter.post_process_fn)

    loc = None

    # answer
    pbar = tqdm(total=len(dataset))
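    # Main inference loop: fill the prompt template for each item (optionally with
    # few-shot examples and a loc_pred hint), query the model, record the prediction
    # and metadata, and checkpoint every args.save_every items.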
    for i, item in enumerate(dataset):
        ukey_1 = item['quid'] if 'quid' in item else item['uid']
        clip_length = int(1/args.fps) if args.fps < 1 else 1/args.fps
        few_shot_examples = build_fewshot_examples(args.fewshot_example_path, args.data_path)
        if args.tree_node_idx is not None:
            loc = tree_node_idx_dict.get(ukey_1, None)
        model.set_post_process_fn(prompter.post_process_fn)
        prompt = prompter.fill(**item, fps=args.fps, clip_length=clip_length, num_words=args.num_words_in_sum, examplars=few_shot_examples, loc_pred=loc)
        pred, info = model.forward(prompter.head, prompt)

        ukey_name = 'quid' if 'quid' in item else 'uid'
        # print("pred: ", pred)
        ukey = item[ukey_name]
        processed[ukey] = item
        processed[ukey]['prompt'] = prompt
        processed[ukey]['prompt_template'] = prompter.get_template_str()
        processed[ukey]['response'] = info['response']
        processed[ukey]['pred'] = pred
        if args.save_info:
            processed[ukey]['info'] = {k: v for k, v in info.items() if k != 'response'}

        if i % args.save_every == 0:
            save_json(processed, output_path)

        pbar.update(1)

    # merge in backup predictions
    if len(args.backup_pred_path) > 0:
        backup = load_json(args.backup_pred_path)
        if 'data' in backup:
            backup = backup['data']
        for uid in processed:
            if processed[uid]['pred'] == -1:
                processed[uid]['pred'] = backup[uid]['pred']
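    # Note: pred == -1 is assumed to be the sentinel for a failed or unparseable
    # model answer, which is why only those entries are overwritten above.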

    # evaluate, unless disabled
    if not args.disable_eval:
        if args.task == 'qa':
            if args.dataset == 'egoschema':
                processed = eval_qa_egoschema(processed)
            elif args.dataset in ['nextqa', 'intentqa', 'nextgqa']:
                processed = eval_qa_nextqa(args.anno_path, processed)
        elif args.task == 'gqa':
            if args.dataset == 'nextgqa':
                pred_qa_path = args.nextgqa_pred_qa_path if len(args.nextgqa_pred_qa_path) > 0 else None
                processed = eval_gqa(args.nextgqa_gt_ground_path, processed, pred_qa_path=pred_qa_path)
        elif args.task == 'sum':
            processed, sum_data = eval_sum(processed)
            save_json(sum_data, f'{Path(output_path).parent / Path(output_path).stem}_data.json')

    save_json(processed, output_path)


if __name__ == '__main__':
    launch()