-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathrun_net.py
More file actions
58 lines (48 loc) · 1.76 KB
/
Copy pathrun_net.py
File metadata and controls
58 lines (48 loc) · 1.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Modified from pycls.
https://github.com/facebookresearch/pycls/blob/main/tools/run_net.py
"""
import argparse
import sys
import os
import pycls.core.config as config
import pycls.core.distributed as dist
import pycls.core.trainer as trainer
from pycls.core.config import cfg
def parse_args():
parser = argparse.ArgumentParser(description="Run a model.")
help_s, choices = "Run mode", ["train", "test", "time"]
parser.add_argument("--mode", help=help_s, choices=choices, required=True, type=str)
help_s = "Config file location"
parser.add_argument("--cfg", help=help_s, required=True, type=str)
help_s = "See pycls/core/config.py for all options"
parser.add_argument("opts", help=help_s, default=None, nargs=argparse.REMAINDER)
if len(sys.argv) == 1:
parser.print_help()
sys.exit(1)
return parser.parse_args()
def main():
args = parse_args()
mode = args.mode
config.load_cfg(args.cfg)
cfg.merge_from_list(args.opts)
if cfg.OUT_DIR is None:
out_dir = os.path.join('work_dirs', os.path.splitext(os.path.basename(args.cfg))[0])
cfg.OUT_DIR = out_dir
config.assert_cfg()
cfg.freeze()
if not os.path.exists(cfg.OUT_DIR):
os.makedirs(cfg.OUT_DIR)
if mode == "train":
dist.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=trainer.train_model)
elif mode == "test":
dist.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=trainer.test_model)
elif mode == "time":
dist.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=trainer.time_model)
if __name__ == "__main__":
main()