-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathtest.py
More file actions
53 lines (43 loc) · 1.74 KB
/
Copy pathtest.py
File metadata and controls
53 lines (43 loc) · 1.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
from tqdm import tqdm
from tools.get_config import get_cfg
from Dataset.build import build_testset
from Models.build import build_model
from Eval.build import build_evaluator
import argparse
def parse_args():
    """Build this script's command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``gpu_no`` (int),
        ``test_batch_size`` (int), ``cfg`` (str), ``result_path`` (str),
        ``weight_path`` (str), ``view_path`` (str), ``is_view`` (int) and
        ``is_val`` (int).
    """
    # (flag, value type, default). Options are registered in exactly this
    # order so the generated --help text matches the original layout.
    option_specs = (
        ('--gpu_no', int, 0),
        ('--test_batch_size', int, 32),
        ('--cfg', str, './Config/config.py'),
        ('--result_path', str, './result'),
        ('--weight_path', str, ''),
        ('--view_path', str, './view'),
        ('--is_view', int, 0),
        ('--is_val', int, 0),
    )
    cli = argparse.ArgumentParser(description='view-config')
    for flag, value_type, fallback in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback)
    return cli.parse_args()
# --- Script body: build config, load model and data, run inference, evaluate ---

# Parse the CLI options declared in parse_args() and hand them to the
# project-local get_cfg(); the resulting object is read below for gpu_no,
# weight_path, test_batch_size and is_view.
cfg = get_cfg(parse_args())

# Select the CUDA device before any .cuda() call so the model and the
# batches are placed on it.
torch.cuda.set_device(cfg.gpu_no)

# Build the network, load the checkpoint onto CPU first, then move to GPU.
# strict=True makes loading fail if checkpoint keys and model keys differ.
# NOTE(review): torch.load unpickles the file at weight_path — only load
# trusted checkpoints (consider weights_only=True if the installed torch
# version supports it).
net = build_model(cfg)
net.load_state_dict(torch.load(cfg.weight_path, map_location='cpu'), strict=True)
net.cuda().eval()  # eval(): switch layers such as dropout/batch-norm to inference behaviour

tsset = build_testset(cfg)
print('testset length:', len(tsset))

evaluator = build_evaluator(cfg)
evaluator.pre_process()

# shuffle=False keeps dataset order and drop_last=False keeps the final
# partial batch, so every test sample is visited exactly once. Batching is
# delegated to the dataset's own collate_fn.
tsloader = torch.utils.data.DataLoader(
    tsset,
    batch_size=cfg.test_batch_size,
    shuffle=False,
    num_workers=16,
    drop_last=False,
    collate_fn=tsset.collate_fn,
)

# Inference only: disable autograd once for the whole loop rather than
# re-entering the no_grad context on every batch.
with torch.no_grad():
    # Each batch is unpacked as (images, file_names, original_images);
    # presumably that is the tuple tsset.collate_fn builds — TODO confirm.
    for img, file_names, ori_imgs in tqdm(tsloader, desc='Model is running'):
        img = img.cuda()
        outputs = net(img)
        if cfg.is_view:
            # Visualisation mode: hand predictions plus original images over.
            evaluator.view_output(outputs, file_names, ori_imgs)
        else:
            # Evaluation mode: only predictions and file names are needed.
            evaluator.write_output(outputs, file_names)

# After all batches: either visualise ground truth or run the evaluation.
if cfg.is_view:
    evaluator.view_gt()
else:
    evaluator.evaluate()