Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 57 additions & 17 deletions sotodlib/io/g3tsmurf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@
import os
import numpy as np
import logging
import datetime as dt

from sotodlib.io.load_smurf import load_file, SmurfStatus
from sotodlib.io.load_smurf import load_file, SmurfStatus, make_datetime
from sotodlib.io.g3tsmurf_db import Observations, Files


logger = logging.getLogger(__name__)



def get_obs_folder(obs_id, archive):
"""
Get the folder associated with the observation action. Assumes
Expand Down Expand Up @@ -69,6 +71,9 @@ def get_batch(
n_samps=None,
det_chunks=None,
samp_chunks=None,
start = None,
end = None,
startend_buffer = 10,
test = False,
load_file_args={},
):
Expand All @@ -90,7 +95,7 @@ def get_batch(

Arguments
----------
obs_id : string
obs_id : string, or Observation object
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

spelling

Level 2 observation IDs
archive : G3tSmurf Instance
The G3tSmurf database connected to the obs_id
Expand All @@ -117,6 +122,10 @@ def get_batch(
samp_chunks: None or list of tuples
if specified, each entry in the list is successively passed to load the
AxisManagers as `load_file(... samples = list[i] ... )`
start: Datetime or timestamp. Begin batching at this time.
end: Datetime or timestamp. End batching at this time.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

spelling

startend_buffer: int. Seconds of buffer to add on either side of the samples
estimated to correspond with start, end based on the sampling rate.
test: bool
If true, yields a tuple of (det_chunks, samp_chunks) instead of a loaded
AxisManager
Expand All @@ -129,13 +138,41 @@ def get_batch(
"""

session = archive.Session()
obs = session.query(Observations).filter(Observations.obs_id==obs_id).one()

if isinstance(obs_id, Observations):
obs = obs_id
obs_id = obs.obs_id
else:
obs = session.query(Observations).filter(Observations.obs_id==obs_id).one()
db_files = session.query(Files).filter(Files.obs_id==obs_id).order_by(Files.start)
filenames = sorted( [f.name for f in db_files])

ts = obs.tunesets[0] ## if this throws an error we have some fallbacks
obs_dets, obs_samps = len(ts.dets), obs.n_samples
session.close()

## should we let folks overwrite this here?
if "archive" in load_file_args:
archive = load_file_args.pop("archive")

if "status" in load_file_args:
status = load_file_args.pop("status")
else:
status = SmurfStatus.from_file(filenames[0])

if start is None:
start = obs.start
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if start is None: (I don't think == is standard for getting Nones in python)

if start is None can we just set samp_s to 0 and skip the math? that will avoid rounding errors and missed edge cases by default. Same if stop is None, just set samp_e to obs_samps.

if end is None:
    end = obs.stop
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo. did you test this?

samprate = 4e3 / (status.downsample_factor if status.downsample_enabled else 1)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't this divide by 0 if status.downsample_enabled is False?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that's definitely what it does. Will fix

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I expect status.downsample_factor is 1 when it isn't downsampled. But I may be wrong.

samp_s = (make_datetime(start).timestamp() - obs.start.timestamp()) * samprate \
- startend_buffer * samprate
samp_s = max(int(samp_s), 0)
samp_e = obs_samps - (obs.stop.timestamp() - make_datetime(end).timestamp()) * samprate \
+ startend_buffer * samprate
samp_e = min(obs_samps, int(samp_e))
obs_samps = samp_e-samp_s


if n_det_chunks is not None and n_dets is not None:
logger.warning("Both n_det_chunks and n_dets specified, n_det_chunks overrides")
Expand All @@ -159,29 +196,20 @@ def get_batch(
det_chunks = [range(i*n_dets,min((i+1)*n_dets,obs_dets)) for i in range(n_det_chunks)]
if n_samp_chunks is not None:
n_samps = int(np.ceil(obs_samps/n_samp_chunks))
samp_chunks = [(i*n_samps,min((i+1)*n_samps,obs_samps)) for i in range(n_samp_chunks)]
samp_chunks = [(samp_s+i*n_samps,min((i+1)*n_samps+samp_s,samp_e)) for i in range(n_samp_chunks)]

if n_dets is not None:
n_det_chunks = int(np.ceil(obs_dets/n_dets))
det_chunks = [range(i*n_dets,min((i+1)*n_dets,obs_dets)) for i in range(n_det_chunks)]
if n_samps is not None:
n_samp_chunks = int(np.ceil(obs_samps/n_samps))
samp_chunks = [(i*n_samps,min((i+1)*n_samps,obs_samps)) for i in range(n_samp_chunks)]
samp_chunks = [(samp_s+i*n_samps,min((i+1)*n_samps+samp_s,samp_e)) for i in range(n_samp_chunks)]

if det_chunks is None:
det_chunks = [range(0,obs_dets)]
if samp_chunks is None:
samp_chunks = [(0,obs_samps)]

## should we let folks overwrite this here?
if "archive" in load_file_args:
archive = load_file_args.pop("archive")

if "status" in load_file_args:
status = load_file_args.pop("status")
else:
status = SmurfStatus.from_file(filenames[0])

samp_chunks = [(samp_s,samp_e)]

logger.debug(f"Loading data with det_chunks: {det_chunks}.")
logger.debug(f"Loading data in samp_chunks: {samp_chunks}.")

Expand All @@ -191,14 +219,26 @@ def get_batch(
if test:
yield (det_chunk, samp_chunk)
else:
yield load_file(
aman = load_file(
filenames,
channels=det_chunk,
samples=samp_chunk,
archive=archive,
status=status,
**load_file_args,
)
if samp_s !=0 or samp_e != obs.n_samples:
msk = np.all(
[aman.timestamps >= start.timestamp(), aman.timestamps < end.timestamp()],
axis=0,
)
idx = np.where(msk)[0]
if len(idx) == 0:
aman.restrict("samps", (aman.samps.offset, aman.samps.offset))
else:
aman.restrict("samps", (aman.samps.offset+idx[0], aman.samps.offset+idx[-1]))
yield aman

except GeneratorExit:
pass

62 changes: 33 additions & 29 deletions sotodlib/io/load_smurf.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,34 @@
num_bias_lines = 16


"""
Used for the input of many functions below
"""
def make_datetime(x):
    """Coerce a timestamp or datetime-like input to a timezone-aware UTC datetime.

    Intended to allow flexibility in inputs for various other functions.
    Naive datetimes are assumed to be in UTC.

    Args
    ----
    x: float/int POSIX timestamp, numpy.datetime64, datetime.datetime,
        or datetime.date

    Returns
    ----
    datetime: timezone-aware (UTC) datetime equivalent of x

    Raises
    ----
    TypeError: if x is not a recognized timestamp or datetime type
    """
    if np.issubdtype(type(x), np.floating) or np.issubdtype(type(x), np.integer):
        # fromtimestamp(..., tz=utc) returns an *aware* datetime that
        # round-trips through .timestamp() regardless of local timezone.
        # (utcfromtimestamp is deprecated and returned a naive datetime,
        # inconsistent with the aware results of the branches below.)
        return dt.datetime.fromtimestamp(x, tz=dt.timezone.utc)
    if isinstance(x, np.datetime64):
        # Cast to microsecond precision first: coarse units (e.g. "D")
        # would otherwise astype to datetime.date, whose .replace()
        # cannot accept tzinfo.
        return x.astype("datetime64[us]").astype(dt.datetime).replace(
            tzinfo=dt.timezone.utc
        )
    if isinstance(x, dt.datetime):
        if x.tzinfo is None:
            return x.replace(tzinfo=dt.timezone.utc)
        return x
    if isinstance(x, dt.date):
        # A bare date carries no time or tzinfo; interpret as midnight UTC.
        # (Previously this fell into the datetime branch and crashed on
        # the missing .tzinfo attribute.)
        return dt.datetime(x.year, x.month, x.day, tzinfo=dt.timezone.utc)
    raise TypeError("Input not a datetime or timestamp")


"""
Actions used to define when observations happen
Could be expanded to other Action Based Indexing as well
Expand Down Expand Up @@ -202,30 +230,6 @@ def from_configs(cls, configs):
configs["g3tsmurf_db"],
meta_path=os.path.join(configs["data_prefix"],"smurf"))

@staticmethod
def _make_datetime(x):
    """
    Takes an input (either a timestamp or datetime), and returns a datetime.
    Intended to allow flexibility in inputs for various other functions.
    Note that x will be assumed to be in UTC if timezone is not specified.

    Args
    ----
    x: input datetime or timestamp

    Returns
    ----
    datetime: datetime of x if x is a timestamp

    NOTE(review): numeric timestamps return a *naive* datetime
    (utcfromtimestamp) while the other branches return tz-aware ones —
    confirm callers never compare or mix the two directly.
    """
    if np.issubdtype(type(x), np.floating) or np.issubdtype(type(x), np.integer):
        # POSIX timestamp -> naive UTC datetime.
        # NOTE(review): utcfromtimestamp is deprecated since Python 3.12.
        return dt.datetime.utcfromtimestamp(x)
    elif isinstance(x, np.datetime64):
        # NOTE(review): day-precision datetime64 astypes to datetime.date,
        # whose .replace() rejects tzinfo — confirm inputs are always
        # second-or-finer precision.
        return x.astype(dt.datetime).replace(tzinfo=dt.timezone.utc)
    elif isinstance(x, dt.datetime) or isinstance(x, dt.date):
        # NOTE(review): a plain dt.date has no .tzinfo attribute, so that
        # input would raise AttributeError here; also `is None` is the
        # idiomatic check.
        if x.tzinfo == None:
            # Naive datetime: tag it as UTC per the documented assumption.
            return x.replace(tzinfo=dt.timezone.utc)
        return x
    raise (Exception("Input not a datetime or timestamp"))

def add_file(self, path, session, overwrite=False):
"""
Expand Down Expand Up @@ -1162,8 +1166,8 @@ def _stream_ids_in_range(self, start, end):
stream_ids: List of stream ids.
"""
session = self.Session()
start = self._make_datetime(start)
end = self._make_datetime(end)
start = make_datetime(start)
end = make_datetime(end)
all_ids = (
session.query(Files.stream_id)
.filter(Files.start < end, Files.stop >= start)
Expand Down Expand Up @@ -1238,8 +1242,8 @@ def load_data(

"""
session = self.Session()
start = self._make_datetime(start)
end = self._make_datetime(end)
start = make_datetime(start)
end = make_datetime(end)

if stream_id is None:
sids = self._stream_ids_in_range(start, end)
Expand Down Expand Up @@ -1632,7 +1636,7 @@ def from_time(cls, time, archive, stream_id=None, show_pb=False):
status : (SmurfStatus instance)
object indexing of rogue variables at specified time.
"""
time = archive._make_datetime(time)
time = make_datetime(time)
session = archive.Session()
q = (
session.query(Frames)
Expand Down