-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathpredict_cnn.py
More file actions
114 lines (101 loc) · 3.5 KB
/
Copy pathpredict_cnn.py
File metadata and controls
114 lines (101 loc) · 3.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#!/usr/bin/env python
# coding: utf-8
"""Run predictions with trained CNNs."""
import argparse
import logging
import os.path
import numpy as np
import pandas as pd
import tensorflow as tf
import uproot
from tensorflow.keras.models import load_model
def event_generator(filename, target, step_size=1):
    """Yield single events (two hitmaps plus truth value) from a ROOT tree.

    Args:
        filename: Location handed to ``uproot.open``. The opened object is
            iterated directly, so this must resolve to a tree
            (e.g. ``"file.root:tree"``).
        target: Name of the branch holding the truth value. If the name
            contains ``"energy"``, the natural logarithm of the value is
            yielded instead of the raw value.
        step_size: Forwarded to ``iterate``. The default of 1 reads one
            entry per step, which is the previous hard-coded behaviour.

    Yields:
        Tuples ``(X, X_mufilter, truth)``, one per entry, where both
        hitmaps are cast to ``np.float16``.
    """
    log = "energy" in target
    with uproot.open(filename) as events:
        for batch in events.iterate(step_size=step_size, library="np"):
            # Cast each batch once instead of once per yielded event.
            hitmaps = batch["X"].astype(np.float16)
            hitmaps_mufilter = batch["X_mufilter"].astype(np.float16)
            for x, x_mufilter, y in zip(hitmaps, hitmaps_mufilter, batch[target]):
                yield x, x_mufilter, (np.log(y) if log else y)
@tf.function
def reshape_data(hitmaps, hitmaps_mufilter, truth):
    """Split hitmaps per subsystem into hitmaps per view.

    The columns of each input are de-interleaved: even-indexed columns go
    to the ``v`` view and odd-indexed columns to the ``h`` view. For the
    mufilter hitmaps only the odd columns below index 10 are used for the
    ``h`` view. Every view is transposed and then gets a new axis inserted
    at position 2.

    Returns:
        ``((X_v, X_h, X_mufilter_v, X_mufilter_h), truth)`` with ``truth``
        passed through unchanged.
    """
    # Main hitmaps: even columns -> v, odd columns -> h.
    view_v = tf.expand_dims(tf.transpose(hitmaps[:, ::2]), 2)
    view_h = tf.expand_dims(tf.transpose(hitmaps[:, 1::2]), 2)

    # Mufilter hitmaps: the h view stops before column 10.
    mufilter_view_v = tf.expand_dims(tf.transpose(hitmaps_mufilter[:, ::2]), 2)
    mufilter_view_h = tf.expand_dims(tf.transpose(hitmaps_mufilter[:, 1:10:2]), 2)

    views = (view_v, view_h, mufilter_view_v, mufilter_view_h)
    return views, truth
def _parse_model_filename(model_path):
    """Return ``(model_name, epochs)`` parsed from a model file name.

    The base name is split on underscores. The first five fields joined by
    underscores form the model name. The seventh field carries the epoch
    count after a one-character prefix, optionally followed by a file
    extension.

    Raises:
        ValueError: If the file name does not contain a readable epoch field.
    """
    fields = os.path.basename(model_path).split("_")
    model_name = "_".join(fields[:5])
    try:
        epochs = int(fields[6][1:].split(".")[0])
    except (IndexError, ValueError) as err:
        raise ValueError(
            f"Cannot read the epoch count from model file name {model_path!r}."
        ) from err
    return model_name, epochs


def main():
    """Run a saved Keras CNN on a test dataset and store the predictions.

    Command-line arguments select the model file, the target branch, the
    batch size and the input dataset. The ``df`` tree of the dataset is
    read, the model is evaluated on it, and a CSV file named
    ``{model_name}_n{n_events}_e{epochs}.csv`` is written to the current
    directory with the columns ``{target}_pred`` and ``{target}_test``.

    Raises:
        ValueError: If the epoch count cannot be parsed from the model file
            name.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--model",
        help="""Keras model to load.""",
        required=True,
    )
    parser.add_argument(
        "--target",
        help="""Target observable.""",
        required=True,
    )
    parser.add_argument(
        "-b",
        "--batch-size",
        type=int,
        help="""Number of samples per batch.""",
        required=True,
    )
    parser.add_argument(
        "--data",
        help="""Test dataset to use. Supports retrieval via XRootD.""",
        required=True,
    )
    args = parser.parse_args()

    # Parse the file name up front so that a malformed name fails before the
    # expensive prediction instead of after it.
    model_name, epochs = _parse_model_filename(args.model)
    print(f"Predicting using model {model_name}.")

    tree_path = args.data + ":df"
    # Only the truth column and the entry count are needed from this handle;
    # the generator below opens the file on its own, so close it right away.
    with uproot.open(tree_path) as events:
        y_test = events[args.target].array()
        n_events = events.num_entries

    ds_test = (
        tf.data.Dataset.from_generator(
            (lambda: event_generator(tree_path, args.target)),
            output_signature=(
                tf.TensorSpec(shape=(3072, 200), dtype=tf.float16),
                tf.TensorSpec(shape=(4608, 42), dtype=tf.float16),
                tf.TensorSpec(shape=(), dtype=tf.float64),
            ),
        )
        .map(reshape_data)
        .apply(tf.data.experimental.assert_cardinality(n_events))
        .batch(args.batch_size)
    )
    model = load_model(args.model)
    if len(model.inputs) == 2:
        # Old format, target only: feed just the first two views.
        # TODO nonsense results, what's going on?
        ds_test = ds_test.map(lambda x, y: ((x[0], x[1]), y))
    y_pred = model.predict(ds_test)
    if "energy" in args.target:
        # The generator supplies the logarithm of energy targets, so the
        # prediction is exponentiated to be comparable with the raw y_test.
        # NOTE(review): assumes the model was trained on log targets - confirm.
        y_pred = np.exp(y_pred)
    df = pd.DataFrame(
        {
            f"{args.target}_pred": np.squeeze(y_pred),
            f"{args.target}_test": np.squeeze(y_test),
        }
    )
    df.to_csv(f"{model_name}_n{n_events}_e{epochs}.csv")
if __name__ == "__main__":
    # Script entry point: enable INFO-level log output, then run the
    # prediction. Nothing happens when the module is merely imported.
    logging.basicConfig(level=logging.INFO)
    main()