Add loader and det cal functionality to load in bandpass info using bg #1206

Merged
merged 11 commits on Jun 6, 2025
12 changes: 7 additions & 5 deletions sotodlib/core/metadata/loader.py
@@ -845,11 +845,13 @@ class MetadataSpec:
with fields found in this metadata item. See notes below.

``load_fields`` (list of str or None)
List of fields to load. This may include entire child
AxisManagers, or fields within them using "." for hierarchical
addressing. This is only for AxisManager metadata. Default
is None, which meaning to load all fields. Wildcards are not
supported.
List of fields to load. For AxisManager metadata this may
include entire child AxisManagers, or fields within them
using "." for hierarchical addressing. For ResultSet metadata,
subitems of this list can be specified as dictionaries
(original_fieldname -> new_fieldname) which allow the user to
change the fieldname of the subitem. Default is None, which
means to load all fields. Wildcards are not supported.

``drop_fields`` (list of str)
List of fields (which may contain wildcard character ``*``) to
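To illustrate the new ResultSet behavior described above, a metadata spec entry using per-field renames might look like the following sketch (modeled on the test added in tests/test_context.py below; the db object and field names are placeholders, not part of the library API):

    # Sketch: load two ResultSet fields, renaming 'readout_id' on load and
    # keeping 'pol_code' under its original name; all other fields are skipped.
    metadata_entry = {
        'db': ondisk_db,                        # a metadata.ManifestDb instance
        'unpack': 'ondisk_resultset',
        'load_fields': [
            {'readout_id': 'dets:readout_id'},  # original_fieldname -> new_fieldname
            'pol_code',
        ],
    }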
37 changes: 29 additions & 8 deletions sotodlib/io/metadata.py
@@ -112,17 +112,38 @@ def from_loadspec(self, load_params, **kwargs):


class ResultSetHdfLoader(LoaderInterface):
def _prefilter_data(self, data_in, key_map={}):
"""When a dataset is loaded and converted to a structured numpy
array, this function is called before the data are returned to
the user. The key_map can be used to rename fields, on load.
def _check_key_map(self, key_map, data_in):
"""Check that all keys in key_map are valid fields in data_in."""
for key in key_map.keys():
if key not in data_in.dtype.names:
raise KeyError(f"{key} not included inside dataset with keys"
f"{data_in.dtype.names}")

def _prefilter_data(self, data_in, load_fields=None):
"""When a dataset is loaded and converted to a structured numpy array,
this function is called before the data are returned to the user. If
'load_fields' is given, the fields (and renames) it specifies are added
to the key map, which is then checked against the fields actually
present in data_in.

This function may be extended in subclasses, but you will
likely want to call the super() handler before doing
additional processing. The loading functions do not pass in
key_map -- this is for the exclusive use of subclasses.

additional processing.
"""
key_map = {}
if load_fields is not None:
for field in load_fields:
if isinstance(field, dict):
key_map.update(field)
else:
key_map[field] = field

for k in data_in.dtype.names:
if k not in key_map.keys():
key_map[k] = None

self._check_key_map(key_map, data_in)

return _decode_array(data_in, key_map=key_map)

def _populate(self, data, keys=None, row_order=None):
@@ -192,7 +213,7 @@ def batch_from_loadspec(self, load_params, **kwargs):
dataset = load_params[idx]['dataset']
if dataset is not last_dataset:
data = fin[dataset][()]
data = self._prefilter_data(data)
data = self._prefilter_data(data, **kwargs)
last_dataset = dataset

# Dereference the extrinsic axis request. Every
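As a standalone sketch of the key_map logic added to _prefilter_data (the field names here are hypothetical), requested fields are kept or renamed and all other fields map to None:

    import numpy as np

    data_in = np.zeros(3, dtype=[('band', 'S4'), ('readout_id', 'S16'), ('pol_code', 'S2')])
    load_fields = [{'band': 'dets:band1'}, 'pol_code']

    key_map = {}
    for field in load_fields:
        if isinstance(field, dict):
            key_map.update(field)    # rename on load: original -> new
        else:
            key_map[field] = field   # keep under the original name
    for k in data_in.dtype.names:
        if k not in key_map:
            key_map[k] = None        # fields not requested map to None and are dropped on load

    # key_map == {'band': 'dets:band1', 'pol_code': 'pol_code', 'readout_id': None}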
7 changes: 4 additions & 3 deletions sotodlib/preprocess/processes.py
@@ -1344,6 +1344,7 @@ class PCARelCal(_Preprocess):
def __init__(self, step_cfgs):
self.signal = step_cfgs.get('signal', 'signal')
self.run = step_cfgs.get('pca_run', 'run1')
self.bandpass = step_cfgs.get('bandpass_key', 'wafer.bandpass')
self.run_name = f'{self.signal}_{self.run}'

super().__init__(step_cfgs)
@@ -1366,15 +1367,15 @@ def calc_and_save(self, aman, proc_aman):
if self.plot_cfgs:
self.plot_signal = filt_aman[self.signal]

bands = np.unique(aman.det_info.wafer.bandpass)
bands = np.unique(aman.det_info[self.bandpass])
bands = bands[bands != 'NC']
# align samps w/ proc_aman to include samps restriction when loading back from db.
rc_aman = core.AxisManager(proc_aman.dets, proc_aman.samps)
pca_det_mask = np.full(aman.dets.count, False, dtype=bool)
relcal = np.zeros(aman.dets.count)
pca_weight0 = np.zeros(aman.dets.count)
for band in bands:
m0 = aman.det_info.wafer.bandpass == band
m0 = aman.det_info[self.bandpass] == band
if self.plot_cfgs is not None:
rc_aman.wrap(f'{band}_idx', m0, [(0, 'dets')])
band_aman = aman.restrict('dets', aman.dets.vals[m0], in_place=False)
@@ -1430,7 +1431,7 @@ def plot(self, aman, proc_aman, filename):
det = aman.dets.vals[0]
ufm = det.split('_')[2]

bands = np.unique(aman.det_info.wafer.bandpass)
bands = np.unique(aman.det_info[self.bandpass])
bands = bands[bands != 'NC']
for band in bands:
if f'{band}_pca_mode0' in proc_aman[self.run_name]:
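For reference, the step configuration consumed by PCARelCal.__init__ now accepts a bandpass_key entry; a minimal sketch with the default values is below (bandpass_key may name any det_info field):

    # Sketch of the PCARelCal step_cfgs dict; values shown are the defaults
    # read via step_cfgs.get() in __init__.
    step_cfgs = {
        'signal': 'signal',
        'pca_run': 'run1',
        'bandpass_key': 'wafer.bandpass',   # field of aman.det_info used to group detectors
    }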
43 changes: 31 additions & 12 deletions sotodlib/site_pipeline/update_det_cal.py
@@ -33,18 +33,27 @@
DEFAULT_pA_per_phi0 = 9e6
TES_BIAS_COUNT = 12 # per detset / primary file group

# For converting bias group to bandpass.
BGS = {'lb': [0, 1, 4, 5, 8, 9], 'hb': [2, 3, 6, 7, 10, 11]}
BAND_STR = {'mf': {'lb': 'f090', 'hb': 'f150'},
'uhf': {'lb': 'f220', 'hb': 'f280'},
'lf': {'lb': 'f030', 'hb': 'f040'}}

logger = logging.getLogger("det_cal")
if not logger.hasHandlers():
sp_util.init_logger("det_cal")


def get_data_root(ctx: core.Context) -> str:
def get_data_root(ctx: core.Context, obs_id: str) -> str:
"Get root data directory based on context file and obs_id"
c = ctx.obsfiledb.conn.execute("select name from files limit 1")
res = [r[0] for r in c][0]
# split out <data_root>/<tel>/obs/<timecode>/<obsid>/fname
for _ in range(4):
for _ in range(5):
res = os.path.dirname(res)
# add telescope str to path
tel_str = obs_id.split('_')[2]
res = os.path.join(res, tel_str)
return res
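A worked sketch of the revised path handling (the file path below is hypothetical; the obs_id is taken from the commented examples later in this file): five dirname calls strip the filename, obs_id, timecode, 'obs' and telescope directories, and the telescope string parsed from the obs_id is then re-appended:

    import os

    # Hypothetical obsfiledb entry of the form <root>/<tel>/obs/<timecode>/<obsid>/<fname>
    res = '/data/so/satp1/obs/17139/obs_1713962395_satp1_0000100/file_000.g3'
    for _ in range(5):
        res = os.path.dirname(res)
    # res is now '/data/so'
    tel_str = 'obs_1713962395_satp1_0000100'.split('_')[2]   # 'satp1'
    res = os.path.join(res, tel_str)
    # res is now '/data/so/satp1', which get_obs_info extends with 'oper/<timecode>/<oid>/Z_smurf'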


Expand All @@ -59,9 +68,6 @@ class DetCalCfg:
Path to the root of the results directory.
context_path: str
Path to the context file to use.
data_root: Optional[str]
Root path of L3 data. If this is not specified, will automatically
determine it based on the context.
raise_exceptions: bool
If Exceptions should be raised in the get_cal_resset function.
Defaults to False.
@@ -106,7 +112,6 @@ def __init__(
root_dir: str,
context_path: str,
*,
data_root: Optional[str] = None,
raise_exceptions: bool = False,
apply_cal_correction: bool = True,
index_path: str = "det_cal.sqlite",
@@ -125,8 +130,6 @@ def __init__(
self.root_dir = root_dir
self.context_path = os.path.expandvars(context_path)
ctx = core.Context(self.context_path)
if data_root is None:
self.data_root = get_data_root(ctx)
self.raise_exceptions = raise_exceptions
self.apply_cal_correction = apply_cal_correction
self.cache_failed_obsids = cache_failed_obsids
@@ -248,6 +251,8 @@ class CalInfo:
Current responsivity of the TES [1/V] computed using bias steps at the
bias point. This is based on the naive bias step estimation without
using any additional corrections.
bandpass: str
Detector bandpass, computed from bias group information.
"""

readout_id: str = ""
@@ -269,6 +274,7 @@ def dtype(cls) -> List[Tuple[str, Any]]:
naive_r_frac: float = np.nan
naive_p_bias: float = np.nan
naive_s_i: float = np.nan
bandpass: str = "NC"

@classmethod
def dtype(cls) -> List[Tuple[str, Any]]:
@@ -277,6 +283,9 @@ def dtype(cls) -> List[Tuple[str, Any]]:
for field in fields(cls):
if field.name == "readout_id":
dt: Tuple[str, Any] = ("dets:readout_id", "<U40")
elif field.name == 'bandpass':
# Our bandpass str is max 4 characters
dt: Tuple[str, Any] = ("bandpass", "<U4")
else:
dt = (field.name, field.type)
dtype.append(dt)
@@ -357,7 +366,7 @@ def get_obs_info(cfg: DetCalCfg, obs_id: str) -> ObsInfoResult:
if oid is not None:
timecode = oid.split("_")[1][:5]
zsmurf_dir = os.path.join(
cfg.data_root, "oper", timecode, oid, f"Z_smurf"
get_data_root(ctx, obs_id), "oper", timecode, oid, f"Z_smurf"
)
for f in os.listdir(zsmurf_dir):
if "iv" in f:
@@ -377,7 +386,7 @@ def get_obs_info(cfg: DetCalCfg, obs_id: str) -> ObsInfoResult:
if oid is not None:
timecode = oid.split("_")[1][:5]
zsmurf_dir = os.path.join(
cfg.data_root, "oper", timecode, oid, f"Z_smurf"
get_data_root(ctx, obs_id), "oper", timecode, oid, f"Z_smurf"
)
for f in os.listdir(zsmurf_dir):
if "bias_step" in f:
@@ -617,6 +626,13 @@ def find_correction_results(band, chan, dset):
else:
cal.phase_to_pW = pA_per_phi0 / (2 * np.pi) / cal.s_i * cal.polarity

# Add bandpass information from bias group
tube_flavor = am.obs_info.tube_flavor
if cal.bg in BGS['lb']:
cal.bandpass = BAND_STR[tube_flavor]['lb']
elif cal.bg in BGS['hb']:
cal.bandpass = BAND_STR[tube_flavor]['hb']

res.result_set = np.array([astuple(c) for c in cals], dtype=CalInfo.dtype())
res.success = True
except Exception as e:
@@ -716,7 +732,6 @@ def run_update_site(cfg: DetCalCfg) -> None:

logger.info(f"Processing {len(obs_ids)} obsids...")

mp.set_start_method(cfg.multiprocess_start_method)
with mp.Pool(cfg.nprocs_result_set) as pool:
for oid in tqdm(obs_ids, disable=(not cfg.show_pb)):
res = get_obs_info(cfg, oid)
@@ -754,6 +769,7 @@ def run_update_nersc(cfg: DetCalCfg) -> None:
obs_ids = get_obsids_to_run(cfg)
# obs_ids = ['obs_1713962395_satp1_0000100']
# obs_ids = ['obs_1713758716_satp1_1000000']
# obs_ids = ['obs_1701383445_satp3_1000000']
logger.info(f"Processing {len(obs_ids)} obsids...")

pb = tqdm(total=len(obs_ids), disable=(not cfg.show_pb))
@@ -769,7 +785,6 @@ def errback(e):
# We split into multiple pools because:
# - we don't want to overload sqlite files with too much concurrent access
# - we want to be able to continue getting the next obs_info data while ressets are being computed
mp.set_start_method(cfg.multiprocess_start_method)
pool1 = mp.Pool(cfg.nprocs_obs_info)
pool2 = mp.Pool(cfg.nprocs_result_set)
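The two-pool arrangement described in the comment above can be sketched as follows; the exact wiring in run_update_nersc is not shown in this diff, so the callback plumbing and argument lists below are illustrative only (get_obs_info, get_cal_resset and errback are the module's own functions):

    import multiprocessing as mp

    def _sketch_run(cfg, obs_ids):
        pool1 = mp.Pool(cfg.nprocs_obs_info)     # fetches obs_info records
        pool2 = mp.Pool(cfg.nprocs_result_set)   # computes cal result sets

        def on_obs_info(res):
            # Hand each finished obs_info straight to the second pool so that
            # result-set computation overlaps with the next obs_info fetch.
            # (Arguments to get_cal_resset are assumed for this sketch.)
            pool2.apply_async(get_cal_resset, args=(cfg, res), error_callback=errback)

        for oid in obs_ids:
            pool1.apply_async(get_obs_info, args=(cfg, oid),
                              callback=on_obs_info, error_callback=errback)
        pool1.close(); pool1.join()
        pool2.close(); pool2.join()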

@@ -847,4 +862,8 @@ def get_parser(
if __name__ == "__main__":
parser = get_parser()
args = parser.parse_args()
cfg = DetCalCfg.from_yaml(args.config_file)
# This needs to run here in __main__, or else multiprocessing will raise a
# "context has already been set" RuntimeError.
mp.set_start_method(cfg.multiprocess_start_method)
main(config_file=args.config_file)
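As a minimal sketch of the bias-group-to-bandpass conversion added above (the tube flavor and bias group values are made up for illustration):

    BGS = {'lb': [0, 1, 4, 5, 8, 9], 'hb': [2, 3, 6, 7, 10, 11]}
    BAND_STR = {'mf': {'lb': 'f090', 'hb': 'f150'},
                'uhf': {'lb': 'f220', 'hb': 'f280'},
                'lf': {'lb': 'f030', 'hb': 'f040'}}

    tube_flavor, bg = 'mf', 2    # e.g. from am.obs_info.tube_flavor and cal.bg
    bandpass = 'NC'              # default when the bias group matches neither list
    if bg in BGS['lb']:
        bandpass = BAND_STR[tube_flavor]['lb']
    elif bg in BGS['hb']:
        bandpass = BAND_STR[tube_flavor]['hb']
    # bandpass == 'f150'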
45 changes: 42 additions & 3 deletions tests/test_context.py
@@ -311,6 +311,19 @@ def test_120_load_fields(self):
self.assertCountEqual(tod.ondisk._fields.keys(), ['disk1', 'subaman'])
self.assertCountEqual(tod.ondisk.subaman._fields.keys(), ['disk2'])

def test_load_fields_resultset(self):
dataset_sim = DatasetSim()
obs_id = dataset_sim.obss['obs_id'][1]
ctx = dataset_sim.get_context(with_resultset_ondisk=True)
tod = ctx.get_obs(obs_id)
# Make sure that 'band1' was loaded into det_info
self.assertTrue('band1' in tod.det_info._fields)
# Make sure that 'band1' loads the same information as 'band'
self.assertTrue((tod.det_info.band1 == tod.det_info.band).all())
# Make sure that only 'pol_code' was loaded into ondisk_resultset
self.assertCountEqual(tod.ondisk_resultset._fields.keys(), ['pol_code'])
self.assertTrue((tod.det_info.pol_code == tod.ondisk_resultset.pol_code).all())

def test_200_load_metadata(self):
"""Test the simple metadata load wrapper."""
dataset_sim = DatasetSim()
@@ -376,8 +389,8 @@ def __init__(self):
self.sample_count = 100

class _TestML(metadata.LoaderInterface):
def from_loadspec(_self, load_params):
return self.metadata_loader(load_params)
def from_loadspec(_self, load_params, **load_kwargs):
return self.metadata_loader(load_params, **load_kwargs)

OBSLOADER_REGISTRY['unittest_loader'] = self.tod_loader
metadata.SuperLoader.register_metadata('unittest_loader', _TestML)
@@ -388,6 +401,7 @@ def get_context(self, with_detdb=True, with_metadata=True,
with_incomplete_metadata=False,
with_inconcatable=False,
with_axisman_ondisk=False,
with_resultset_ondisk=False,
on_missing='trim'):
"""Args:
with_detdb: if False, no detdb is included.
@@ -636,10 +650,35 @@ def _db_multi_dataset(filename, detsets=['neard', 'fard']):
output['subaman'].wrap_new(f'disk{i}', shape=('dets',))
output.save(filename, 'xyz')

if with_resultset_ondisk:
_scheme = metadata.ManifestScheme() \
.add_exact_match('obs:obs_id') \
.add_data_field('dataset')
ondisk_db = metadata.ManifestDb(scheme=_scheme)
ondisk_db._tempdir = tempfile.TemporaryDirectory()
filename = os.path.join(ondisk_db._tempdir.name,
'ondisk_resultset_metadata.h5')
write_dataset(self.dets, filename, 'obs_number_12')
ondisk_db.add_entry(
{'obs:obs_id': 'obs_number_12',
'dataset': 'obs_number_12'},
filename)
ctx['metadata'].insert(0, {
'db': ondisk_db,
'det_info': True,
'load_fields': [{'band': 'dets:band1'}, {'readout_id': 'dets:readout_id'}],
})

ctx['metadata'].insert(0, {
'db': ondisk_db,
'unpack': 'ondisk_resultset',
'load_fields': [{'readout_id': 'dets:readout_id'}, 'pol_code'],
})

return ctx


def metadata_loader(self, kw):
def metadata_loader(self, kw, **load_kwargs):
# For Superloader.
filename = os.path.split(kw['filename'])[1]
if filename == 'bands.h5':