
Add loader and det cal functionality to load in bandpass info using bg #1206


Open
wants to merge 9 commits into base: master
12 changes: 7 additions & 5 deletions sotodlib/core/metadata/loader.py
@@ -845,11 +845,13 @@ class MetadataSpec:
with fields found in this metadata item. See notes below.

``load_fields`` (list of str or None)
List of fields to load. This may include entire child
AxisManagers, or fields within them using "." for hierarchical
addressing. This is only for AxisManager metadata. Default
is None, which meaning to load all fields. Wildcards are not
supported.
List of fields to load. For AxisManager metadata this may
include entire child AxisManagers, or fields within them
using "." for hierarchical addressing. For ResultSet metadata,
subitems of this list can be specified as dictionaries
(original_fieldname -> new_fieldname) which allow the user to
change the fieldname of the subitem. Default is None, which
means to load all fields. Wildcards are not supported.

``drop_fields`` (list of str)
List of fields (which may contain wildcard character ``*``) to
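
For reference, a metadata list entry using the new rename syntax might look like the following minimal sketch (based on the test added at the bottom of this diff; the db path, unpack target, and field names are placeholders):

    # Hypothetical metadata spec for a ResultSet on disk. Plain strings load a
    # field under its original name; dict entries rename the field on load.
    metadata_entry = {
        'db': 'path/to/manifest.sqlite',        # placeholder ManifestDb
        'unpack': 'my_resultset',
        'load_fields': [
            {'readout_id': 'dets:readout_id'},  # rename; used for det matching
            'pol_code',                         # loaded under its original name
        ],
    }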
37 changes: 29 additions & 8 deletions sotodlib/io/metadata.py
@@ -112,17 +112,38 @@ def from_loadspec(self, load_params, **kwargs):


class ResultSetHdfLoader(LoaderInterface):
def _prefilter_data(self, data_in, key_map={}):
"""When a dataset is loaded and converted to a structured numpy
array, this function is called before the data are returned to
the user. The key_map can be used to rename fields, on load.
def _check_key_map(self, key_map, data_in):
"""Check that all keys in key_map are valid fields in data_in."""
for key in key_map.keys():
if key not in data_in.dtype.names:
raise KeyError(f"{key} not included inside dataset with keys"
f"{data_in.dtype.names}")

def _prefilter_data(self, data_in, load_fields=None):
"""When a dataset is loaded and converted to a structured numpy array,
this function is called before the data are returned to the user. If
'load_fields' has data for this instance, this function adds any fields
specified within to the key map and checks that these fields are all
valid fields in data_in.

This function may be extended in subclasses, but you will
likely want to call the super() handler before doing
additional processing. The loading functions do not pass in
key_map -- this is for the exclusive use of subclasses.

additional processing.
"""
key_map = {}
if load_fields is not None:
for field in load_fields:
if isinstance(field, dict):
key_map.update(field)
else:
key_map[field] = field

for k in data_in.dtype.names:
if k not in key_map.keys():
key_map[k] = None

self._check_key_map(key_map, data_in)

return _decode_array(data_in, key_map=key_map)

def _populate(self, data, keys=None, row_order=None):
@@ -192,7 +213,7 @@ def batch_from_loadspec(self, load_params, **kwargs):
dataset = load_params[idx]['dataset']
if dataset is not last_dataset:
data = fin[dataset][()]
data = self._prefilter_data(data)
data = self._prefilter_data(data, **kwargs)
last_dataset = dataset

# Dereference the extrinsic axis request. Every
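
To make the new key-map construction concrete, here is a standalone sketch of the logic in _prefilter_data above, using the load_fields value from the test at the end of this diff:

    # Sketch of the key_map that _prefilter_data builds for a given load_fields.
    load_fields = [{'readout_id': 'dets:readout_id'}, 'pol_code']

    key_map = {}
    for field in load_fields:
        if isinstance(field, dict):
            key_map.update(field)    # dict entries rename the field on load
        else:
            key_map[field] = field   # plain strings keep the original name

    # Dataset fields that were not requested map to None for _decode_array to
    # handle; per the tests, such fields do not appear in the loaded result.
    for k in ('readout_id', 'band', 'pol_code'):  # stand-in for data_in.dtype.names
        key_map.setdefault(k, None)

    # key_map == {'readout_id': 'dets:readout_id', 'pol_code': 'pol_code',
    #             'band': None}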
7 changes: 4 additions & 3 deletions sotodlib/preprocess/processes.py
@@ -1344,6 +1344,7 @@ class PCARelCal(_Preprocess):
def __init__(self, step_cfgs):
self.signal = step_cfgs.get('signal', 'signal')
self.run = step_cfgs.get('pca_run', 'run1')
self.bandpass = step_cfgs.get('bandpass_key', 'wafer.bandpass')
self.run_name = f'{self.signal}_{self.run}'

super().__init__(step_cfgs)
@@ -1366,15 +1367,15 @@ def calc_and_save(self, aman, proc_aman):
if self.plot_cfgs:
self.plot_signal = filt_aman[self.signal]

bands = np.unique(aman.det_info.wafer.bandpass)
bands = np.unique(aman.det_info[self.bandpass])
bands = bands[bands != 'NC']
# align samps w/ proc_aman to include samps restriction when loading back from db.
rc_aman = core.AxisManager(proc_aman.dets, proc_aman.samps)
pca_det_mask = np.full(aman.dets.count, False, dtype=bool)
relcal = np.zeros(aman.dets.count)
pca_weight0 = np.zeros(aman.dets.count)
for band in bands:
m0 = aman.det_info.wafer.bandpass == band
m0 = aman.det_info[self.bandpass] == band
if self.plot_cfgs is not None:
rc_aman.wrap(f'{band}_idx', m0, [(0, 'dets')])
band_aman = aman.restrict('dets', aman.dets.vals[m0], in_place=False)
@@ -1430,7 +1431,7 @@ def plot(self, aman, proc_aman, filename):
det = aman.dets.vals[0]
ufm = det.split('_')[2]

bands = np.unique(aman.det_info.wafer.bandpass)
bands = np.unique(aman.det_info[self.bandpass])
bands = bands[bands != 'NC']
for band in bands:
if f'{band}_pca_mode0' in proc_aman[self.run_name]:
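
The new bandpass_key option makes the det_info field used for band grouping configurable. A minimal sketch of a step configuration, assuming the step is constructed directly for illustration (keys follow the __init__ above; the value shown is the default):

    # Hypothetical PCARelCal step configuration using 'bandpass_key'.
    step_cfgs = {
        'signal': 'signal',
        'pca_run': 'run1',
        # Any det_info field holding band labels; 'wafer.bandpass' is the default.
        'bandpass_key': 'wafer.bandpass',
    }
    step = PCARelCal(step_cfgs)
    # calc_and_save then groups detectors via np.unique(aman.det_info[step.bandpass])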
18 changes: 18 additions & 0 deletions sotodlib/site_pipeline/update_det_cal.py
@@ -33,6 +33,12 @@
DEFAULT_pA_per_phi0 = 9e6
TES_BIAS_COUNT = 12 # per detset / primary file group

# For converting bias group to bandpass.
BGS = {'lb': [0, 1, 4, 5, 8, 9], 'hb': [2, 3, 6, 7, 10, 11]}
BAND_STR = {'mf': {'lb': 'f090', 'hb': 'f150'},
'uhf': {'lb': 'f220', 'hb': 'f280'},
'lf': {'lb': 'f030', 'hb': 'f040'}}

logger = logging.getLogger("det_cal")
if not logger.hasHandlers():
sp_util.init_logger("det_cal")
@@ -248,6 +254,8 @@ class CalInfo:
Current responsivity of the TES [1/V] computed using bias steps at the
bias point. This is based on the naive bias step estimation without
using any additional corrections.
bandpass: str
Detector bandpass, computed from bias group information.
"""

readout_id: str = ""
@@ -269,6 +277,7 @@
naive_r_frac: float = np.nan
naive_p_bias: float = np.nan
naive_s_i: float = np.nan
bandpass: str = "NC"

@classmethod
def dtype(cls) -> List[Tuple[str, Any]]:
@@ -277,6 +286,9 @@ def dtype(cls) -> List[Tuple[str, Any]]:
for field in fields(cls):
if field.name == "readout_id":
dt: Tuple[str, Any] = ("dets:readout_id", "<U40")
            elif field.name == 'bandpass':
                # Our bandpass str is at most 4 characters
                dt = ("bandpass", "<U4")
else:
dt = (field.name, field.type)
dtype.append(dt)
@@ -617,6 +629,12 @@ def find_correction_results(band, chan, dset):
else:
cal.phase_to_pW = pA_per_phi0 / (2 * np.pi) / cal.s_i * cal.polarity

            # Add bandpass information from bias group
if cal.bg in BGS['lb']:
cal.bandpass = BAND_STR[tube_flavor]['lb']
elif cal.bg in BGS['hb']:
cal.bandpass = BAND_STR[tube_flavor]['hb']

res.result_set = np.array([astuple(c) for c in cals], dtype=CalInfo.dtype())
res.success = True
except Exception as e:
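
The bias-group-to-bandpass conversion added above reduces to a small lookup. A standalone sketch, with bg_to_bandpass as a hypothetical helper (the constants are copied from the top of this file):

    # Sketch of the bias group -> bandpass lookup used in find_correction_results.
    BGS = {'lb': [0, 1, 4, 5, 8, 9], 'hb': [2, 3, 6, 7, 10, 11]}
    BAND_STR = {'mf': {'lb': 'f090', 'hb': 'f150'},
                'uhf': {'lb': 'f220', 'hb': 'f280'},
                'lf': {'lb': 'f030', 'hb': 'f040'}}

    def bg_to_bandpass(bg, tube_flavor):
        """Return the bandpass string for a bias group, or 'NC' if unassigned."""
        if bg in BGS['lb']:
            return BAND_STR[tube_flavor]['lb']
        if bg in BGS['hb']:
            return BAND_STR[tube_flavor]['hb']
        return 'NC'

    assert bg_to_bandpass(0, 'mf') == 'f090'    # low-band group, MF tube
    assert bg_to_bandpass(3, 'uhf') == 'f280'   # high-band group, UHF tube
    assert bg_to_bandpass(-1, 'lf') == 'NC'     # unassigned bias group

The resulting string is at most four characters, which is why CalInfo.dtype() maps the bandpass field to ('bandpass', '<U4') before the rows are serialized with np.array([astuple(c) for c in cals], dtype=CalInfo.dtype()).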
45 changes: 42 additions & 3 deletions tests/test_context.py
@@ -311,6 +311,19 @@ def test_120_load_fields(self):
self.assertCountEqual(tod.ondisk._fields.keys(), ['disk1', 'subaman'])
self.assertCountEqual(tod.ondisk.subaman._fields.keys(), ['disk2'])

def test_load_fields_resultset(self):
dataset_sim = DatasetSim()
obs_id = dataset_sim.obss['obs_id'][1]
ctx = dataset_sim.get_context(with_resultset_ondisk=True)
tod = ctx.get_obs(obs_id)
# Make sure that 'band1' was loaded into det_info
self.assertTrue('band1' in tod.det_info._fields)
        # Make sure that 'band1' loads the same information as 'band'
self.assertTrue((tod.det_info.band1 == tod.det_info.band).all())
        # Make sure that only 'pol_code' was loaded into ondisk_resultset
self.assertCountEqual(tod.ondisk_resultset._fields.keys(), ['pol_code'])
self.assertTrue((tod.det_info.pol_code == tod.ondisk_resultset.pol_code).all())

def test_200_load_metadata(self):
"""Test the simple metadata load wrapper."""
dataset_sim = DatasetSim()
@@ -376,8 +389,8 @@ def __init__(self):
self.sample_count = 100

class _TestML(metadata.LoaderInterface):
def from_loadspec(_self, load_params):
return self.metadata_loader(load_params)
def from_loadspec(_self, load_params, **load_kwargs):
return self.metadata_loader(load_params, **load_kwargs)

OBSLOADER_REGISTRY['unittest_loader'] = self.tod_loader
metadata.SuperLoader.register_metadata('unittest_loader', _TestML)
@@ -388,6 +401,7 @@ def get_context(self, with_detdb=True, with_metadata=True,
with_incomplete_metadata=False,
with_inconcatable=False,
with_axisman_ondisk=False,
with_resultset_ondisk=False,
on_missing='trim'):
"""Args:
with_detdb: if False, no detdb is included.
@@ -636,10 +650,35 @@ def _db_multi_dataset(filename, detsets=['neard', 'fard']):
output['subaman'].wrap_new(f'disk{i}', shape=('dets',))
output.save(filename, 'xyz')

if with_resultset_ondisk:
_scheme = metadata.ManifestScheme() \
.add_exact_match('obs:obs_id') \
.add_data_field('dataset')
ondisk_db = metadata.ManifestDb(scheme=_scheme)
ondisk_db._tempdir = tempfile.TemporaryDirectory()
filename = os.path.join(ondisk_db._tempdir.name,
'ondisk_resultset_metadata.h5')
write_dataset(self.dets, filename, 'obs_number_12')
ondisk_db.add_entry(
{'obs:obs_id': 'obs_number_12',
'dataset': 'obs_number_12'},
filename)
ctx['metadata'].insert(0, {
'db': ondisk_db,
'det_info': True,
'load_fields': [{'band': 'dets:band1'}, {'readout_id': 'dets:readout_id'}],
})

ctx['metadata'].insert(0, {
'db': ondisk_db,
'unpack': 'ondisk_resultset',
'load_fields': [{'readout_id': 'dets:readout_id'}, 'pol_code'],
})

return ctx


def metadata_loader(self, kw):
def metadata_loader(self, kw, **load_kwargs):
        # For SuperLoader.
filename = os.path.split(kw['filename'])[1]
if filename == 'bands.h5':
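
Taken together, the two metadata entries registered above yield exactly what test_load_fields_resultset checks; a condensed restatement:

    # After ctx.get_obs(obs_id): the det_info entry injects 'band' under the
    # new name 'band1', while the unpack entry loads only 'pol_code' into the
    # 'ondisk_resultset' child, with 'readout_id' consumed for det matching.
    tod = ctx.get_obs(obs_id)
    assert 'band1' in tod.det_info._fields
    assert (tod.det_info.band1 == tod.det_info.band).all()
    assert list(tod.ondisk_resultset._fields.keys()) == ['pol_code']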