-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvisualize.py
More file actions
94 lines (67 loc) · 3.02 KB
/
visualize.py
File metadata and controls
94 lines (67 loc) · 3.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import json, os
import torch
import json
import numpy as np
import random
import argparse
from fidbench.misc import get_tokenizer
from pygments.console import colorize
import transformers
transformers.logging.set_verbosity_error()
def log_exceed(l, max_l):
    """Print a colorized warning that the input is longer than allowed.

    Args:
        l: Actual length (input token count plus `max_gen`).
        max_l: Maximum permitted length.
    """
    # Grammar fix in the message: "less equal than" -> "less than or equal to".
    msg = colorize('red', "[WARNING]: ") + f'Expect input token length plus `max_gen` to be less than or equal to {max_l}, but got {l}.'
    print(msg, flush=True)
def log_warning(msg):
    """Print `msg` to stdout behind a red "[WARNING]: " prefix."""
    prefix = colorize('red', "[WARNING]: ")
    print(prefix + msg, flush=True)
def log_info(msg):
    """Print `msg` to stdout behind a green "[INFO]: " prefix."""
    prefix = colorize('green', "[INFO]: ")
    print(prefix + msg, flush=True)
def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs so runs are reproducible.

    Args:
        seed: Integer seed applied to every random-number source.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade speed for determinism in cuDNN kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
if __name__ == '__main__':
    seed_everything(42)

    parser = argparse.ArgumentParser()
    # Positional: path to the run's JSON config; the path doubles as the run name.
    parser.add_argument("env_conf", type=str, default=None)
    # A wrong token is still shown green when prob[1] / prob[0] is below this ratio.
    parser.add_argument("--thresh", type=float, default=3)
    args = parser.parse_args()

    with open(args.env_conf, "r") as f:
        env_conf = json.load(f)

    # Derive run / method / base-model names from the config path, e.g.
    # "<dir>/<method>-rest.json" -> run_name "<dir>/<method>-rest".
    run_name = args.env_conf.replace('.json', '')
    method_name = run_name.split('/')[1]
    base_model_run_name = run_name.replace(f"-{method_name}", "").replace(f"/{method_name}", "")
    base_model_pred_path = os.path.join("pred", base_model_run_name)
    model_pred_path = os.path.join("pred", run_name)

    tokenizer = get_tokenizer(**env_conf['model'])

    for pred_file_name in os.listdir(base_model_pred_path):
        if 'jsonl' not in pred_file_name:
            continue
        base_model_file_path = os.path.join(base_model_pred_path, pred_file_name)
        model_file_path = os.path.join(model_pred_path, pred_file_name)
        # Every base-model prediction file must have a matching method file.
        assert os.path.exists(model_file_path)

        with open(base_model_file_path, 'r') as f1, open(model_file_path, 'r') as f2:
            print(colorize('blue', f'visualize {model_file_path} ...'), end='\n')
            for line1, line2 in zip(f1, f2):
                # Parse each JSONL record once (the original re-parsed the
                # same line for every extracted field).
                base_rec = json.loads(line1)
                pred_rec = json.loads(line2)
                p_ids = base_rec['p_ids']
                g_ids = base_rec['g_ids']
                acc1 = pred_rec['acc1']
                acc5 = pred_rec['acc5']
                toks = pred_rec['toks']
                probs = pred_rec['probs']
                # One accuracy entry per generated token past the first.
                assert len(g_ids) - 1 == len(acc1) == len(acc5)

                # Prompt and the first gold token are printed uncolored.
                print(tokenizer.decode(p_ids), end='')
                print(tokenizer.decode(g_ids[0]), end='')
                g_ids = g_ids[1:]
                for gid, correct, tok, prob in zip(g_ids, acc1, toks, probs):
                    # Green when top-1 correct, or when the prob[1]/prob[0]
                    # ratio is under the threshold (epsilon avoids div by 0).
                    # NOTE(review): assumes prob is ordered [top, runner-up]
                    # — confirm against the prediction writer.
                    if correct or prob[1] / (prob[0] + 1e-8) < args.thresh:
                        print(colorize('green', tokenizer.decode(gid)), end='')
                    else:
                        # Red: gold token with the model's alternative token
                        # (tok[1]) shown in brackets.
                        print(colorize('red', f"{tokenizer.decode(gid)}[{tokenizer.decode(tok[1])}]"), end='')
                print(colorize('blue', '\n' + '=' * 80))
                input('>> press enter to continue')