-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutil.py
More file actions
63 lines (46 loc) · 1.59 KB
/
util.py
File metadata and controls
63 lines (46 loc) · 1.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import errno
import torch
import shutil
import numpy as np
class AverageMeter(object):
    """Track a running average of a scalar metric (e.g. loss or accuracy).

    Attributes:
        value: most recently observed value.
        sum:   weighted sum of all observed values.
        count: total weight (number of samples) seen so far.
        ave:   running weighted average, ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.value = 0
        self.ave = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the average."""
        self.value = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.ave = self.sum / self.count
def mkdir_p(path):
    """Create ``path`` like ``mkdir -p``: make intermediate directories as
    needed and do nothing if the directory already exists.

    Bug fix: the original called ``os.mkdir``, which raises when a parent
    directory is missing — contradicting the ``-p`` semantics the name
    promises. ``os.makedirs(..., exist_ok=True)`` creates the full chain and
    is a no-op on an existing directory, while still raising if ``path``
    exists as a non-directory (matching the original's error behavior).
    """
    os.makedirs(path, exist_ok=True)
def adjust_learning_rate(optimizer, train_configuration, epoch, training_epoch_num, args):
    """Set per-group learning rates following a cosine-annealing schedule.

    The base rate from ``train_configuration['learning_rate']`` is scaled by
    ``(cos(pi * epoch / training_epoch_num) + 1) / 2``, so it decays from the
    configured value at epoch 0 toward 0 at ``training_epoch_num``. Group 0
    (presumably a backbone/pretrained part — confirm against the caller) gets
    a tenth of the annealed rate; groups 1 and 2 get the full rate. Assumes
    the optimizer has exactly three parameter groups. ``args`` is unused but
    kept for interface compatibility.
    """
    phase = np.pi * (epoch % training_epoch_num) / training_epoch_num
    annealed = float(train_configuration['learning_rate'] / 2 * (np.cos(phase) + 1))
    optimizer.param_groups[0]['lr'] = annealed / 10
    optimizer.param_groups[1]['lr'] = annealed
    optimizer.param_groups[2]['lr'] = annealed
    # Debug trace of the rates actually applied, one per group.
    for group in optimizer.param_groups:
        print(group['lr'])
def save_checkpoint(state, checkpoint, filename='checkpoint.pth.tar'):
    """Serialize ``state`` with ``torch.save`` into the ``checkpoint`` directory.

    Args:
        state: any picklable object (typically a dict of model/optimizer state).
        checkpoint: existing directory to write into.
        filename: output file name inside ``checkpoint``.

    Improvements: ``os.path.join`` replaces manual ``'/'`` concatenation
    (portable, handles trailing separators); dead commented-out "best model"
    copy removed — that behavior lives in ``save_prime``.
    """
    torch.save(state, os.path.join(checkpoint, filename))
def save_prime(state, is_best, checkpoint, filename='prime.pth.tar'):
    """Serialize ``state`` into ``checkpoint``; also snapshot it as the best.

    Args:
        state: any picklable object to save via ``torch.save``.
        is_best: when truthy, copy the saved file to ``prime_best.pth.tar``
            so the best-so-far model is preserved across later overwrites.
        checkpoint: existing directory to write into.
        filename: output file name inside ``checkpoint``.

    Improvement: ``os.path.join`` replaces manual ``'/'`` concatenation for
    both paths (portable, handles trailing separators).
    """
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'prime_best.pth.tar'))