Skip to content
24 changes: 20 additions & 4 deletions src/sntools/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from scipy import integrate, interpolate


def gen_evts(_channel, _flux, n_targets, seed, verbose):
def gen_evts(_channel, _flux, mode, binsize, n_targets, seed, verbose):
"""Generate events.

* Get event rate by interpolating from time steps in the input data.
Expand Down Expand Up @@ -41,20 +41,36 @@ def gen_evts(_channel, _flux, n_targets, seed, verbose):
for t in flux.raw_times]
event_rate = interpolate.pchip(flux.raw_times, raw_nevts)

bin_width = 1 # in ms
# appropriate bin width is different for each mode
if mode == "sn":
bin_width = 1 # in ms
elif mode == "presn":
bin_width = binsize*60*1000 # converting minutes into ms


n_bins = int((flux.endtime - flux.starttime) / bin_width) # number of full-width bins; int() implies floor()
if verbose:
print(f"[{tag}] Generating events in {bin_width} ms bins from {flux.starttime} to {flux.endtime} ms ...")
if mode == "presn":
print(f"[{tag}] Generating events in {binsize} minute bins from {(flux.starttime)/60000} to {(flux.endtime)/60000} minutes ...")
else:
print(f"[{tag}] Generating events in {bin_width} ms bins from {flux.starttime} to {flux.endtime} ms ...")

# scipy is optimized for operating on large arrays, making it orders of
# magnitude faster to pre-compute all values of the interpolated functions.
binned_t = [flux.starttime + (i + 0.5) * bin_width for i in range(n_bins)]
binned_nevt_th = event_rate(binned_t)

# if bin width is not 1ms, binned_nevt_th must be multipled by bin width to get the correct rate of events as event_rate is in units of per ms.
binned_nevt_th = event_rate(binned_t) * bin_width

# check for unphysical values of interpolated function event_rate(t)
for _i, _n in enumerate(binned_nevt_th):
if _n < 0:
binned_nevt_th[_i] = 0


binned_nevt = np.random.poisson(binned_nevt_th) # Get random number of events in each bin from Poisson distribution


flux.prepare_evt_gen(binned_t) # give flux script a chance to pre-compute values

events = []
Expand Down
15 changes: 11 additions & 4 deletions src/sntools/formats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def __init__(self, sn_model, flv, starttime, endtime) -> None:
self._sn_model = sn_model

times = [t.to(u.ms).value for t in sn_model.get_time()]
times.sort()
self.starttime = get_starttime(starttime, times[0])
self.endtime = get_endtime(endtime, times[-1])
self.raw_times = get_raw_times(times, self.starttime, self.endtime)
Expand All @@ -155,14 +156,20 @@ class SNEWPYCompositeFlux(CompositeFlux):
"""Adapter class to turn a SNEWPY.models.SupernovaModel into an sntools.formats.CompositeFlux"""

@classmethod
def from_file(cls, file, format, starttime=None, endtime=None):
def from_file(cls, file, mode, format, starttime=None, endtime=None):
"""Create a SNEWPYCompositeFlux from an input file."""
self = SNEWPYCompositeFlux()
self._repr = f"SNEWPYCompositeFlux.from_file('{file}', format='{format}', starttime={starttime}, endtime={endtime})"

# snewpy.models.loaders classes treat relative file paths as relative to the snewpy cache directory,
# so we’ll turn it into an absolute path first.
sn_model = getattr(import_module('snewpy.models.ccsn_loaders'), format)(abspath(file))

# need an if statement here to import the snewpy.models.presn_loaders module if the mode is set to presn.
if mode == "sn":
sn_model = getattr(import_module('snewpy.models.ccsn_loaders'), format)(abspath(file))
elif mode == "presn":
sn_model = getattr(import_module('snewpy.models.presn_loaders'), format)(abspath(file))


for flv in ('e', 'eb', 'x', 'xb'):
f = SNEWPYFlux(sn_model, flv, starttime, endtime)
Expand All @@ -181,7 +188,7 @@ def get_starttime(starttime, minimum):
if starttime is None:
starttime = ceil(minimum)
elif starttime < minimum:
raise ValueError(f"Start time cannot be earlier than {minimum} (first entry in input file).")
raise ValueError(f"Start time cannot be earlier than {minimum} ms (first entry in input file).")
return starttime


Expand All @@ -195,7 +202,7 @@ def get_endtime(endtime, maximum):
if endtime is None:
endtime = floor(maximum)
elif endtime > maximum:
raise ValueError(f"End time cannot be later than {maximum} (last entry in input file).")
raise ValueError(f"End time cannot be later than {maximum} ms (last entry in input file).")
return endtime


Expand Down
75 changes: 58 additions & 17 deletions src/sntools/genevts.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,44 @@ def main():
flux_at_detector = args.flux.transformed_by(args.transformation, args.distance)

# Generate events for each (sub-)channel and combine them
pool = ProcessPoolExecutor(max_workers=args.maxworkers)
results = []
for channel in sorted(args.channels):
mod_channel = import_module("sntools.interaction_channels." + channel)
n_targets = args.detector.n_molecules * args.detector.material["channel_weights"][channel]
for flv in mod_channel.possible_flavors:
channel_instance = mod_channel.Channel(flv)
for flux in flux_at_detector.components[flv]:
results.append(pool.submit(gen_evts, channel_instance, flux, n_targets, args.randomseed + random.random(), args.verbose))

events = []
for result in as_completed(results):
events.extend(result.result())


# using the process pool executor for sn mode
if args.mode == "sn":
print("we are going through the ProcessPoolExecutor route")
pool = ProcessPoolExecutor(max_workers=args.maxworkers)
results = []

for channel in sorted(args.channels):
mod_channel = import_module("sntools.interaction_channels." + channel)
n_targets = args.detector.n_molecules * args.detector.material["channel_weights"][channel]
for flv in mod_channel.possible_flavors:
channel_instance = mod_channel.Channel(flv)
for flux in flux_at_detector.components[flv]:
results.append(pool.submit(gen_evts, channel_instance, flux, args.mode, args.binsize, n_targets, args.randomseed + random.random(), args.verbose))

events = []
for result in as_completed(results):
events.extend(result.result())


# but the process pool executor isn't compatible with presn code from snewpy so we can't use it in this case.
elif args.mode == "presn":
print("We are not in the ProcessPoolExecutor loop")
events = []

for channel in sorted(args.channels):
mod_channel = import_module("sntools.interaction_channels." + channel)
n_targets = args.detector.n_molecules * args.detector.material["channel_weights"][channel]
for flv in mod_channel.possible_flavors:
channel_instance = mod_channel.Channel(flv)
for flux in flux_at_detector.components[flv]:
events.extend(gen_evts(_channel=channel_instance, _flux=flux, mode=args.mode, binsize=args.binsize, n_targets=n_targets, seed=args.randomseed + random.random(), verbose=args.verbose))



#===================================================================================================================================================


# Sort events by time and write them to an output file
events.sort(key=lambda evt: evt.time)
Expand Down Expand Up @@ -77,10 +102,18 @@ def parse_command_line_options():

parser.add_argument("input_file", help="Name or common prefix of the input file(s). Required.")

# comment in place of a possible new argument to specify pre SN fluxes needed
choices = ("sn", "presn")

parser.add_argument("--mode", metavar="MODE", choices=choices, default=choices[0],
help="Mode of operation: supernova burst or pre supernova. Choices: %(choices)s. Default: %(default)s.")
Comment thread
eobrien2502 marked this conversation as resolved.
Outdated

choices = ("gamma", "nakazato", "princeton", "totani", "warren2020",
"SNEWPY-Bollig_2016", "SNEWPY-Fornax_2021", "SNEWPY-Fornax_2022", "SNEWPY-Kuroda_2020",
"SNEWPY-Mori_2023", "SNEWPY-Nakazato_2013", "SNEWPY-OConnor_2015", "SNEWPY-Sukhbold_2015",
"SNEWPY-Tamborra_2014", "SNEWPY-Walk_2018", "SNEWPY-Walk_2019", "SNEWPY-Zha_2021")
"SNEWPY-Tamborra_2014", "SNEWPY-Walk_2018", "SNEWPY-Walk_2019", "SNEWPY-Zha_2021", "SNEWPY-Odrzywolek_2010",
"SNEWPY-Patton_2017", "SNEWPY-Kato_2017", "SNEWPY-Yoshida_2016")
Comment thread
eobrien2502 marked this conversation as resolved.

parser.add_argument("-f", "--format", metavar="FORMAT", choices=choices, default=choices[1],
help="Format of input file(s). Choices: %(choices)s. Default: %(default)s.")

Expand Down Expand Up @@ -113,10 +146,13 @@ def parse_command_line_options():
parser.add_argument("--distance", type=float, default=10.0, help="Distance to supernova in kpc. Default: %(default)s.")

parser.add_argument("--starttime", metavar="T", type=float,
help="Start generating events at T milliseconds. Default: First time bin in input file.")
help="Start generating events at T milliseconds (ccsn) or T minutes (presn). Default: First time bin in input file.")

parser.add_argument("--endtime", metavar="T", type=float,
help="Stop generating events at T milliseconds. Default: Last time bin in input file.")
help="Stop generating events at T milliseconds (ccsn) or T minutes (presn). Default: Last time bin in input file.")

parser.add_argument("--binsize", metavar="BIN", type=float, default=0.5,
help="Size of bins used in presn rate calculations, given in minutes. Default: 0.5 minutes (30 seconds).")

parser.add_argument("--randomseed", metavar="SEED", default=random.randint(0, 2**32 - 1), type=int, # non-ints may not give reproducible results
help="Integer between 0 and 2^32 - 1 used as a random number seed to reproducibly generate events. Default: Random.")
Expand All @@ -132,8 +168,13 @@ def parse_command_line_options():
args.detector = Detector(args.detector)
args.channels = args.detector.material["channel_weights"] if args.channel == "all" else [args.channel]

# converting minutes into milliseconds for the presn mode
if args.mode == "presn":
args.starttime = args.starttime * 60* 1000
args.endtime = args.endtime * 60 * 1000

if args.format[:7] == "SNEWPY-":
args.flux = SNEWPYCompositeFlux.from_file(args.input_file, args.format[7:], args.starttime, args.endtime)
args.flux = SNEWPYCompositeFlux.from_file(args.input_file, args.mode, args.format[7:], args.starttime, args.endtime)
else:
args.flux = CompositeFlux.from_file(args.input_file, args.format, args.starttime, args.endtime)

Expand Down
Loading