-
Notifications
You must be signed in to change notification settings - Fork 21
Loading batcher extension #305
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
dccb1d8
ecd2435
7c03554
8882523
302e09f
863ed34
a86bd1f
ea387e3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,14 +4,16 @@ | |
| import os | ||
| import numpy as np | ||
| import logging | ||
| import datetime as dt | ||
|
|
||
| from sotodlib.io.load_smurf import load_file, SmurfStatus | ||
| from sotodlib.io.load_smurf import load_file, SmurfStatus, make_datetime | ||
| from sotodlib.io.g3tsmurf_db import Observations, Files | ||
|
|
||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
|
|
||
| def get_obs_folder(obs_id, archive): | ||
| """ | ||
| Get the folder associated with the observation action. Assumes | ||
|
|
@@ -69,6 +71,9 @@ def get_batch( | |
| n_samps=None, | ||
| det_chunks=None, | ||
| samp_chunks=None, | ||
| start = None, | ||
| end = None, | ||
| startend_buffer = 10, | ||
| test = False, | ||
| load_file_args={}, | ||
| ): | ||
|
|
@@ -90,7 +95,7 @@ def get_batch( | |
|
|
||
| Arguments | ||
| ---------- | ||
| obs_id : string | ||
| obs_id : string, or Observation object | ||
| Level 2 observation IDs | ||
| archive : G3tSmurf Instance | ||
| The G3tSmurf database connected to the obs_id | ||
|
|
@@ -117,6 +122,10 @@ def get_batch( | |
| samp_chunks: None or list of tuples | ||
| if specified, each entry in the list is successively passed to load the | ||
| AxisManagers as `load_file(... samples = list[i] ... )` | ||
| start: Datetime or timestamp. Begin batching at this time. | ||
| end: Datetime or timestamp. End batching at this time. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. spelling |
||
| startend_buffer: int. Seconds of buffer to add on either side of the samples estimated | ||
| to correspond with start and end, based on the sampling rate. | ||
| test: bool | ||
| If true, yields a tuple of (det_chunks, samp_chunks) instead of a loaded | ||
| AxisManager | ||
|
|
@@ -129,13 +138,41 @@ def get_batch( | |
| """ | ||
|
|
||
| session = archive.Session() | ||
| obs = session.query(Observations).filter(Observations.obs_id==obs_id).one() | ||
|
|
||
| if isinstance(obs_id, Observations): | ||
| obs = obs_id | ||
| obs_id = obs.obs_id | ||
| else: | ||
| obs = session.query(Observations).filter(Observations.obs_id==obs_id).one() | ||
| db_files = session.query(Files).filter(Files.obs_id==obs_id).order_by(Files.start) | ||
| filenames = sorted( [f.name for f in db_files]) | ||
|
|
||
| ts = obs.tunesets[0] ## if this throws an error we have some fallbacks | ||
| obs_dets, obs_samps = len(ts.dets), obs.n_samples | ||
| session.close() | ||
|
|
||
| ## should we let folks overwrite this here? | ||
| if "archive" in load_file_args: | ||
| archive = load_file_args.pop("archive") | ||
|
|
||
| if "status" in load_file_args: | ||
| status = load_file_args.pop("status") | ||
| else: | ||
| status = SmurfStatus.from_file(filenames[0]) | ||
|
|
||
| if start == None: | ||
| start = obs.start | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
if start is None can we just set |
||
| if end == None: | ||
| end = obe.stop | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. typo. did you test this? |
||
| samprate = 4e3 / (status.downsample_enabled * status.downsample_factor) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Doesn't this divide by 0 if
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, that's definitely what it does. Will fix
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I expect |
||
| samp_s = (make_datetime(start).timestamp() - obs.start.timestamp()) * samprate \ | ||
| - startend_buffer * samprate | ||
| samp_s = max(int(samp_s), 0) | ||
| samp_e = obs_samps - (obs.stop.timestamp() - make_datetime(end).timestamp()) * samprate \ | ||
| + startend_buffer * samprate | ||
| samp_e = min(obs_samps, int(samp_e)) | ||
| obs_samps = samp_e-samp_s | ||
|
|
||
|
|
||
| if n_det_chunks is not None and n_dets is not None: | ||
| logger.warning("Both n_det_chunks and n_dets specified, n_det_chunks overrides") | ||
|
|
@@ -159,29 +196,20 @@ def get_batch( | |
| det_chunks = [range(i*n_dets,min((i+1)*n_dets,obs_dets)) for i in range(n_det_chunks)] | ||
| if n_samp_chunks is not None: | ||
| n_samps = int(np.ceil(obs_samps/n_samp_chunks)) | ||
| samp_chunks = [(i*n_samps,min((i+1)*n_samps,obs_samps)) for i in range(n_samp_chunks)] | ||
| samp_chunks = [(samp_s+i*n_samps,min((i+1)*n_samps+samp_s,samp_e)) for i in range(n_samp_chunks)] | ||
|
|
||
| if n_dets is not None: | ||
| n_det_chunks = int(np.ceil(obs_dets/n_dets)) | ||
| det_chunks = [range(i*n_dets,min((i+1)*n_dets,obs_dets)) for i in range(n_det_chunks)] | ||
| if n_samps is not None: | ||
| n_samp_chunks = int(np.ceil(obs_samps/n_samps)) | ||
| samp_chunks = [(i*n_samps,min((i+1)*n_samps,obs_samps)) for i in range(n_samp_chunks)] | ||
| samp_chunks = [(samp_s+i*n_samps,min((i+1)*n_samps+samp_s,samp_e)) for i in range(n_samp_chunks)] | ||
|
|
||
| if det_chunks is None: | ||
| det_chunks = [range(0,obs_dets)] | ||
| if samp_chunks is None: | ||
| samp_chunks = [(0,obs_samps)] | ||
|
|
||
| ## should we let folks overwrite this here? | ||
| if "archive" in load_file_args: | ||
| archive = load_file_args.pop("archive") | ||
|
|
||
| if "status" in load_file_args: | ||
| status = load_file_args.pop("status") | ||
| else: | ||
| status = SmurfStatus.from_file(filenames[0]) | ||
|
|
||
| samp_chunks = [(samp_s,samp_e)] | ||
|
|
||
| logger.debug(f"Loading data with det_chunks: {det_chunks}.") | ||
| logger.debug(f"Loading data in samp_chunks: {samp_chunks}.") | ||
|
|
||
|
|
@@ -191,14 +219,26 @@ def get_batch( | |
| if test: | ||
| yield (det_chunk, samp_chunk) | ||
| else: | ||
| yield load_file( | ||
| aman = load_file( | ||
| filenames, | ||
| channels=det_chunk, | ||
| samples=samp_chunk, | ||
| archive=archive, | ||
| status=status, | ||
| **load_file_args, | ||
| ) | ||
| if samp_s !=0 or samp_e != obs.n_samples: | ||
| msk = np.all( | ||
| [aman.timestamps >= start.timestamp(), aman.timestamps < end.timestamp()], | ||
| axis=0, | ||
| ) | ||
| idx = np.where(msk)[0] | ||
| if len(idx) == 0: | ||
| aman.restrict("samps", (aman.samps.offset, aman.samps.offset)) | ||
| else: | ||
| aman.restrict("samps", (aman.samps.offset+idx[0], aman.samps.offset+idx[-1])) | ||
| yield aman | ||
|
|
||
| except GeneratorExit: | ||
| pass | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
spelling