-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_model.py
More file actions
170 lines (145 loc) · 6.11 KB
/
train_model.py
File metadata and controls
170 lines (145 loc) · 6.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import tensorflow as tf
import matplotlib.pyplot as plt
import os
import click
from datetime import date
import build_net
import pandas as pd
from contextlib import redirect_stdout
import numpy as np
from tensorflow.keras.callbacks import Callback
class Metrics(Callback):
    """Keras callback that reports a validation percent error after every epoch.

    At the end of each epoch the model predicts on the held-out inputs and the
    mean of ``|y_true - y_pred| / y_true * 100`` over column 0 of the targets is
    printed and appended to ``self.val_percent_errors``.
    """

    def __init__(self, val_features, val_features_conv, val_targets):
        """
        :param val_features: validation inputs for the model's first input
        :param val_features_conv: validation inputs for the model's second input
        :param val_targets: validation targets; indexed as ``[:, 0]``, so assumed 2-D
        """
        # Initialise the base Callback state (model, validation_data, flags).
        super().__init__()
        self.val_features = val_features
        self.val_features_conv = val_features_conv
        self.val_targ = val_targets
        # Created here too so on_epoch_end works even if on_train_begin never ran.
        self.val_percent_errors = []

    def on_train_begin(self, logs=None):
        """Reset the per-epoch error history at the start of each fit() call."""
        self.val_percent_errors = []

    def on_epoch_end(self, epoch, logs=None):
        """Compute, print and record the mean validation percent error for column 0."""
        val_predict = np.asarray(
            self.model.predict([self.val_features, self.val_features_conv]))
        true_first = self.val_targ[:, 0]
        error = np.abs(true_first - val_predict[:, 0])
        # NOTE(review): any target equal to 0 makes this average inf/nan.
        _val_percent_error = np.average((error / true_first) * 100)
        self.val_percent_errors.append(_val_percent_error)
        print(' - validation single pulse percent error: {} % '.format(_val_percent_error))
        print(' ')
def plot_loss(loss, val_loss, mae, save_dir):
    """
    Makes a plot of training and validation loss over the course of training
    and saves it as ``Loss_history.png`` in ``save_dir``.

    The plot is drawn on its own figure, so it never mixes with any figure the
    caller has open, and that figure is always closed afterwards.

    :param loss: per-epoch training loss
    :param val_loss: per-epoch validation loss
    :param mae: per-epoch Mean Absolute Error series to plot alongside the losses
    :param save_dir: directory to save the image to
    """
    epochs = range(1, len(loss) + 1)
    fig, ax = plt.subplots()
    try:
        ax.plot(epochs, loss, label='Training Loss (MSE)')
        ax.plot(epochs, val_loss, label='Validation Loss (MSE)')
        ax.plot(epochs, mae, label='MAE')
        ax.set_xlabel('Epochs')
        ax.set_ylabel('Error')
        ax.set_title('Training/Validation Loss and MAE')
        ax.legend()
        fig.savefig(os.path.join(save_dir, 'Loss_history.png'))
    finally:
        # Release exactly the figure created here, even if saving failed.
        plt.close(fig)
def write_info_file(save_dir, data_path, batch_size, epochs, lr, run_number, times):
    """
    Record the hyper-parameters of a training run in ``run_info.txt``.

    Each setting is written on its own line as "<label> <value> \\n".
    Any existing file is overwritten.

    :param str save_dir: directory the file is written into
    :param str data_path: path of the training data file
    :param int batch_size: batch size used in training
    :param int epochs: number of training epochs
    :param float lr: optimizer learning rate
    :param run_number: identifier of this run within the day
    :param bool times: True if trained on times, False if trained on energies
    """
    entries = (
        ('ContextEncoder Hyper-parameters: Run', run_number),
        ('Training data found at:', data_path),
        ('Batch Size:', batch_size),
        ('Epochs:', epochs),
        ('Learning Rate:', lr),
        ('Times:', times),
    )
    text = ''.join(f'{label} {value} \n' for label, value in entries)
    with open(os.path.join(save_dir, 'run_info.txt'), 'w') as handle:
        handle.write(text)
def _prepare_save_dir(save_dir):
    """Create ``save_dir``, or ask the user whether to reuse it if it already exists.

    :param str save_dir: output directory for this run
    :return: True if training should proceed, False if the user declined
    """
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
        return True
    ans = input(
        'The directory this run will write to already exists, would you like to overwrite it? ([y/n])')
    # Accept 'y', 'Y', 'yes' (ignoring surrounding whitespace); anything else aborts.
    return ans.strip().lower() in ('y', 'yes')


def _load_data(data_path, num_pulses, times):
    """Load the data array and return normalised (features, features_conv, targets).

    Columns 0-499 are used as features. Columns 500-501 are the time targets and
    columns 502 onward are the energy targets.

    :param str data_path: path to a file readable by ``np.load``
    :param int num_pulses: use only the first ``num_pulses`` rows; <= 0 means all rows
    :param bool times: load time targets (True) or energy targets (False)
    :return: features scaled by their global max, the same features with a trailing
             channel axis for the convolutional input, and the scaled targets
    """
    data = np.load(data_path)
    print('Loaded Data')
    if num_pulses > 0:
        data = data[:num_pulses]
    if times:
        targets = data[:, 500:502]
        print(targets.max(axis=0))
        # Out-of-place: '/=' would write into `data` and fails on integer arrays.
        targets = targets / targets.max(axis=0)
    else:
        print('Loading energies as target data')
        targets = data[:, 502:]
        targets = targets / np.std(targets, axis=0)
    features = data[:, :500]
    features = features / np.max(features)
    # The convolutional branch receives the same normalised features plus a channel axis.
    features_conv = np.expand_dims(features, -1)
    return features, features_conv, targets


def _save_history(history, save_dir):
    """Write the per-epoch losses to losses.csv and plot them to Loss_history.png.

    :param history: the object returned by ``model.fit``
    :param str save_dir: directory to write the outputs to
    """
    loss = pd.Series(history.history['loss'])
    val_loss = pd.Series(history.history['val_loss'])
    # NOTE(review): this is the *validation* MAE, although it is labelled just 'MAE'.
    mae = pd.Series(history.history['val_mae'])
    loss_df = pd.DataFrame({'Training Loss': loss,
                            'Val Loss': val_loss,
                            'MAE': mae})
    loss_df.to_csv(os.path.join(save_dir, 'losses.csv'))  # for further plotting/analysis
    plot_loss(loss, val_loss, mae, save_dir)


@click.command()
@click.argument('data_path', type=click.Path(exists=True, readable=True))
@click.option('--batch_size', default=32)
@click.option('--num_pulses', default=-1, help='Use only the first N pulses (-1 = all)')
@click.option('--epochs', default=50)
@click.option('--lr', default=0.0001, help='Learning rate for Adam optimizer')
@click.option('--run_number', default=1, help='ith run of the day')
@click.option('--times/--energies', default=True)
def main(data_path, batch_size, num_pulses, epochs, lr, run_number, times):
    """Train the time (or energy) network on the data file DATA_PATH.

    Run info, model summary, checkpoints, losses.csv and the loss plot are
    written to ./Run_<date>_<run_number>.
    """
    save_dir = os.path.join('.', 'Run_{}_{}'.format(date.today(), run_number))
    if not _prepare_save_dir(save_dir):
        return
    # Log the plain run number (not the '_'-prefixed directory suffix).
    write_info_file(save_dir, data_path, batch_size, epochs, lr, run_number, times)

    model = build_net.time_model() if times else build_net.energy_model()
    # write .txt file with model summary
    with open(os.path.join(save_dir, 'modelsummary.txt'), 'w') as f:
        with redirect_stdout(f):
            model.summary()
    # `learning_rate` is the supported keyword; `lr` is deprecated and rejected by newer TF.
    adam = tf.keras.optimizers.Adam(learning_rate=lr)
    model.compile(optimizer=adam, loss='mse', metrics=['mae'])

    features, features_conv, targets = _load_data(data_path, num_pulses, times)

    # Chronological 80/20 train/validation split (no shuffling).
    split = round(0.8 * len(features))
    train_features, val_features = features[:split], features[split:]
    train_features_conv, val_features_conv = features_conv[:split], features_conv[split:]
    train_targets, val_targets = targets[:split], targets[split:]

    checkpoint_path = os.path.join(save_dir, "checkpoints/cp-{epoch:04d}.ckpt")
    # NOTE(review): `period` is deprecated in tf.keras and removed in Keras 3
    # (which uses `save_freq` and requires a `.weights.h5` suffix for weights).
    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        checkpoint_path, verbose=1, save_weights_only=True,
        period=5)  # save weights every 5 epochs
    metrics = Metrics(val_features, val_features_conv, val_targets)
    history = model.fit([train_features, train_features_conv],
                        train_targets,
                        validation_data=([val_features, val_features_conv], val_targets),
                        epochs=epochs,
                        batch_size=batch_size,
                        callbacks=[cp_callback, metrics])

    _save_history(history, save_dir)


if __name__ == '__main__':
    main()