-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathntuple_creator.py
More file actions
123 lines (106 loc) · 3.47 KB
/
Copy pathntuple_creator.py
File metadata and controls
123 lines (106 loc) · 3.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#!/usr/bin/env python
# coding: utf-8
"""Create features for use in MVA algorithms."""
import argparse
import logging
import pandas as pd
import ROOT
# One per-station hit-count column is produced for each station index in
# range(N). Hits whose station index falls outside that range still count
# towards the *_n_hits / *_n_stations totals but get no dedicated column.
TARGET_N_STATION_COLUMNS = 100
MUFILTER_N_STATION_COLUMNS = 20

logger = logging.getLogger(__name__)


def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--inputfiles",
        nargs="+",
        help="List of input files to use. "
        "Supports retrieving file from EOS via the XRootD protocol.",
        required=True,
    )
    parser.add_argument(
        "-o",
        "--outputfile",
        help="CSV file to write the features to "
        "(default: features_<n_events>.csv). "
        "Will be overwritten if it already exists.",
    )
    parser.add_argument(
        "-j",
        "--num_cpu",
        default=1,
        type=int,
        help="Number of threads to use.",
    )
    return parser.parse_args()


def _declare_cpp_helpers():
    """Load detector class headers and JIT-declare the C++ helpers.

    The helpers are referenced by name from the string expressions passed to
    ``RDataFrame.Define`` in ``_define_features``.
    """
    for header in ("ShipMCTrack.h", "AdvTargetHit.h", "AdvMuFilterHit.h"):
        ROOT.gInterpreter.ProcessLine(f'#include "{header}"')
    # Returns the detector ID shifted right by 17 bits. Presumably the station
    # number lives in those upper bits -- TODO confirm against the ID scheme.
    ROOT.gInterpreter.Declare(
        """
int station_from_id(int id) {
    return id >> 17;
}
"""
    )
    # Sorted copy of v with duplicates removed (v is taken by value).
    ROOT.gInterpreter.Declare(
        """
template<typename T>
ROOT::RVec<T> Deduplicate(ROOT::RVec<T> v) {
    std::sort(v.begin(), v.end());
    auto last = std::unique(v.begin(), v.end());
    v.erase(last, v.end());
    return v;
}
"""
    )


def _define_features(df):
    """Book every feature column on the RDataFrame node *df*.

    Returns:
        tuple: ``(df, col_names)`` -- the node with all feature columns
        defined, and the ordered list of column names to export. Building the
        list here keeps it in sync with the columns actually defined.
    """
    # NOTE(review): MCTrack[0] / MCTrack[1] are indexed without a size check
    # and the dynamic_cast result is dereferenced unchecked; this assumes every
    # event carries at least two ShipMCTrack entries -- TODO confirm.
    df = (
        df.Define("start_z", "dynamic_cast<ShipMCTrack*>(MCTrack[1])->GetStartZ()")
        .Define("nu_energy", "dynamic_cast<ShipMCTrack*>(MCTrack[0])->GetEnergy()")
        .Define("energy_dep_target", "Sum(AdvTargetPoint.fELoss)")
        .Define("energy_dep_mufilter", "Sum(AdvMuFilterPoint.fELoss)")
        .Define(
            "target_stations", "Map(Digi_AdvTargetHits.fDetectorID, station_from_id)"
        )
        .Define(
            "mufilter_stations",
            "Map(Digi_AdvMuFilterHits.fDetectorID, station_from_id)",
        )
        .Define("target_n_stations", "Deduplicate(target_stations).size()")
        .Define("mufilter_n_stations", "Deduplicate(mufilter_stations).size()")
        .Define("target_n_hits", "Digi_AdvTargetHits.GetEntries()")
        .Define("mufilter_n_hits", "Digi_AdvMuFilterHits.GetEntries()")
    )
    col_names = [
        "start_z",
        "nu_energy",
        "energy_dep_target",
        "energy_dep_mufilter",
        "target_n_hits",
        "target_n_stations",
        "mufilter_n_hits",
        "mufilter_n_stations",
    ]
    # One hit-count column per station index, target first, then muon filter.
    for detector, n_columns in (
        ("target", TARGET_N_STATION_COLUMNS),
        ("mufilter", MUFILTER_N_STATION_COLUMNS),
    ):
        stations = f"{detector}_stations"
        for i in range(n_columns):
            name = f"{detector}_n_hits_station_{i}"
            df = df.Define(
                name, f"std::count({stations}.begin(), {stations}.end(), {i})"
            )
            col_names.append(name)
    return df, col_names


def main():
    """Create features for use in MVA algorithms."""
    args = _parse_args()
    ROOT.EnableImplicitMT(args.num_cpu)
    df = ROOT.ROOT.RDataFrame("cbmsim", args.inputfiles)
    # Keep only events with at least one digitised hit in either detector.
    df = df.Filter(
        "Digi_AdvMuFilterHits.GetEntries() || Digi_AdvTargetHits.GetEntries()"
    )
    # Booked lazily; it is filled by the same event loop that AsNumpy triggers.
    count = df.Count()
    _declare_cpp_helpers()
    df, col_names = _define_features(df)
    # With implicit MT enabled, RDataFrame does not guarantee row order.
    cols = df.AsNumpy(col_names)
    n_events = count.GetValue()
    # Honour -o/--outputfile; previously it was parsed but silently ignored.
    outputfile = args.outputfile or f"features_{n_events}.csv"
    pd.DataFrame(cols).to_csv(outputfile)
    logger.info("Wrote features for %d events to %s", n_events, outputfile)
if __name__ == "__main__":
    # Script entry point: show INFO-level log records, then run the CLI.
    logging.basicConfig(level="INFO")
    main()