Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions sotodlib/preprocess/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1386,6 +1386,94 @@ def select(self, meta, proc_aman=None, in_place=True):
else:
return keep

class AzssGlobalTemplate(_Preprocess):
"""Fit and subtract a global azimuth synchronous signal (AzSS) template
from each detector's timestream. The template is a per-az-bin signal
(e.g. from ``azss.get_azss``) that's averaged over detectors and observations
then interpolated to sample times. A per-detector amplitude coefficient is fit
via masked linear regression, excludingflagged samples. Detectors with
insufficient valid samples are flagged as bad and can be cut in the select step.

Saves results in proc_aman under the ``azss_global_template`` field.

Example config block::

- name: "subtract_azss_global_template"
calc:
template: "azss_stats" # key in aman or string name
signal: "signal"
flags: "glitches" # key in aman.flags
min_valid_samples: 12000
save: True
process:
subtract_in_place: True
select:
kind: "any"
"""

name = "azss_global_template"

def __init__(self, step_cfgs):
self.save_name = "azss_global_template"
super().__init__(step_cfgs)

def calc_and_save(self, aman, proc_aman):
coeffs, bad_dets = subtract_azss_global_template(aman, **self.calc_cfgs)

azss_gt_aman = core.AxisManager(aman.dets)
azss_gt_aman.wrap("coeffs", coeffs, [(0, "dets")])
azss_gt_aman.wrap("bad_dets", bad_dets, [(0, "dets")])

self.save(proc_aman, azss_gt_aman)
return aman, proc_aman

def save(self, proc_aman, azss_gt_aman):
if self.save_cfgs is None:
return
if self.save_cfgs:
proc_aman.wrap(self.save_name, azss_gt_aman)

def process(self, aman, proc_aman, sim=False, data_aman=None):
if self.process_cfgs is None:
return aman, proc_aman

template = aman[self.calc_cfgs.get("template", "azss_stats")]
signal_key = self.calc_cfgs.get("signal", "signal")
flags_key = self.calc_cfgs.get("flags", "glitches")

f_template = interp1d(
template.binned_az, template.binned_signal, fill_value="extrapolate"
)
template_samps = f_template(aman.boresight.az)

coeffs = proc_aman[self.save_name].coeffs
bad_dets = proc_aman[self.save_name].bad_dets
valid_dets = ~bad_dets

if valid_dets.sum() == 0:
logger.warning("No valid detectors for AzSS global template subtraction")
return aman, proc_aman

aman[signal_key][valid_dets] -= (
coeffs[valid_dets, np.newaxis] * template_samps[np.newaxis, :]
)
return aman, proc_aman

def select(self, meta, proc_aman=None, in_place=True):
if self.select_cfgs is None:
return meta

if proc_aman is None:
proc_aman = meta.preprocess

bad_dets = proc_aman[self.save_name].bad_dets
keep = ~bad_dets

if in_place:
meta.restrict("dets", meta.dets.vals[keep])
return meta
else:
return keep

class SubtractAzSSTemplate(_Preprocess):
"""Subtract Azimuth Synchronous Signal (AzSS) common template.
Expand Down
34 changes: 34 additions & 0 deletions sotodlib/tod_ops/azss.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,3 +548,37 @@ def subtract_azss_template(
if subtract:
signal[:, scan_flags] -= model[:, scan_flags]
return model

def subtract_azss_global_template(aman, template, signal, flags, min_valid_samples=12000,
subtract_in_place=False):
"""
"""
if isinstance(template, str):
template = aman[template]
if isinstance(signal, str):
signal = aman.signal

f_template = interp1d(template.binned_az, template.binned_signal, fill_value='extrapolate')
template_samps = f_template(aman.boresight.az)

m = flags.mask()
n_valid = (~m).sum(axis=1)
bad_dets = n_valid < min_valid_samples
if np.sum(~bad_dets) == 0:
logger.warning('No detectors have enough valid samples')
return coeffs, bad_dets

valid = ~m[~bad_dets]
signal_fit = signal[~bad_dets]
template_masked = template_samps[np.newaxis, :] * valid

VtV = (template_masked ** 2).sum(axis=1)
Vty = (vects_masked * signal_fit).sum(axis=1)

coeffs[~bad_dets] = Vty / VtV
if subtract_in_place:
signal[~bad_dets] -= coeffs[~bad_dets, np.newaxis] * template_masked[np.newaxis,:]
return coeffs, bad_dets