Skip to content

Commit b6f1c04

Browse files
mmccrackanMichael McCrackanMichael McCrackanMichael McCrackan
authored
Rework detector selection for preprocess loading (#1116)
* rework det selection for preprocess loading * add select=False when running pipelines in load functions * add in_place for select * remove unneeded import * fix missing in_places * base in_place should be true * fix missing in_place * update invvar flag select * fix type --------- Co-authored-by: Michael McCrackan <[email protected]> Co-authored-by: Michael McCrackan <[email protected]> Co-authored-by: Michael McCrackan <[email protected]>
1 parent 8f3182b commit b6f1c04

File tree

3 files changed

+105
-50
lines changed

3 files changed

+105
-50
lines changed

sotodlib/preprocess/pcore.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,8 @@ def save(self, proc_aman, *args):
9898
if self.save_cfgs is None:
9999
return
100100
raise NotImplementedError
101-
102-
def select(self, meta, proc_aman=None):
101+
102+
def select(self, meta, proc_aman=None, in_place=True):
103103
""" This function runs any desired data selection of the preprocessing
104104
pipeline results. Assumes the pipeline has already been run and that the
105105
resulting proc_aman is now saved under the ``preprocess`` key in the
@@ -111,12 +111,18 @@ def select(self, meta, proc_aman=None):
111111
112112
Arguments
113113
---------
114-
meta : AxisManager
114+
meta : AxisManager
115115
Metadata related to the specific observation
116+
proc_aman : AxisManager
117+
Optional. Any information generated by previous elements in the
118+
preprocessing pipeline.
119+
in_place : bool
120+
Optional. Apply selection and return restricted axis manager if
121+
True, else return the flag array.
116122
117123
Returns
118124
-------
119-
meta : AxisManager
125+
meta : AxisManager
120126
Metadata where non-selected detectors have been removed
121127
"""
122128

sotodlib/preprocess/preprocess_util.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -319,8 +319,13 @@ def load_preprocess_det_select(obs_id, configs, context=None,
319319
pipe = Pipeline(configs["process_pipe"], logger=logger)
320320

321321
meta = context.get_meta(obs_id, dets=dets, meta=meta)
322-
logger.info(f"Cutting on the last process: {pipe[-1].name}")
323-
pipe[-1].select(meta)
322+
logger.info("Restricting detectors on all processes")
323+
keep_all = np.ones(meta.dets.count,dtype=bool)
324+
for process in pipe[:]:
325+
keep = process.select(meta, in_place=False)
326+
if isinstance(keep, np.ndarray):
327+
keep_all &= keep
328+
meta.restrict("dets", meta.dets.vals[keep_all])
324329
return meta
325330

326331

@@ -367,7 +372,7 @@ def load_and_preprocess(obs_id, configs, context=None, dets=None, meta=None,
367372
else:
368373
pipe = Pipeline(configs["process_pipe"], logger=logger)
369374
aman = context.get_obs(meta, no_signal=no_signal)
370-
pipe.run(aman, aman.preprocess)
375+
pipe.run(aman, aman.preprocess, select=False)
371376
return aman
372377

373378

@@ -436,21 +441,29 @@ def multilayer_load_and_preprocess(obs_id, configs_init, configs_proc,
436441

437442
if check_cfg_match(aman_cfgs_ref, meta_proc.preprocess['pcfg_ref'],
438443
logger=logger):
444+
pipe_proc = Pipeline(configs_proc["process_pipe"], logger=logger)
439445

446+
logger.info("Restricting detectors on all proc pipeline processes")
447+
keep_all = np.ones(meta_proc.dets.count, dtype=bool)
448+
for process in pipe_proc[:]:
449+
keep = process.select(meta_proc, in_place=False)
450+
if isinstance(keep, np.ndarray):
451+
keep_all &= keep
452+
meta_proc.restrict("dets", meta_proc.dets.vals[keep_all])
440453
meta_init.restrict('dets', meta_proc.dets.vals)
454+
441455
aman = context_init.get_obs(meta_init, no_signal=no_signal)
442456
logger.info("Running initial pipeline")
443-
pipe_init.run(aman, aman.preprocess)
457+
pipe_init.run(aman, aman.preprocess, select=False)
444458
if init_only:
445459
return aman
446460

447-
pipe_proc = Pipeline(configs_proc["process_pipe"], logger=logger)
448461
logger.info("Running dependent pipeline")
449462
proc_aman = context_proc.get_meta(obs_id, meta=aman)
450463

451464
aman.preprocess.merge(proc_aman.preprocess)
452465

453-
pipe_proc.run(aman, aman.preprocess)
466+
pipe_proc.run(aman, aman.preprocess, select=False)
454467

455468
return aman
456469
else:

sotodlib/preprocess/processes.py

Lines changed: 76 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,17 @@ def save(self, proc_aman, dbc_aman):
6868
if self.save_cfgs:
6969
proc_aman.wrap("det_bias_flags", dbc_aman)
7070

71-
def select(self, meta, proc_aman=None):
71+
def select(self, meta, proc_aman=None, in_place=True):
7272
if self.select_cfgs is None:
7373
return meta
7474
if proc_aman is None:
7575
proc_aman = meta.preprocess
76-
keep = ~proc_aman.det_bias_flags.det_bias_flags
77-
meta.restrict("dets", meta.dets.vals[has_all_cut(keep)])
78-
return meta
76+
keep = ~has_all_cut(proc_aman.det_bias_flags.det_bias_flags)
77+
if in_place:
78+
meta.restrict("dets", meta.dets.vals[keep])
79+
return meta
80+
else:
81+
return keep
7982

8083
def plot(self, aman, proc_aman, filename):
8184
if self.plot_cfgs is None:
@@ -132,7 +135,7 @@ def save(self, proc_aman, trend_aman):
132135
if self.save_cfgs:
133136
proc_aman.wrap("trends", trend_aman)
134137

135-
def select(self, meta, proc_aman=None):
138+
def select(self, meta, proc_aman=None, in_place=True):
136139
if self.select_cfgs is None:
137140
return meta
138141
if proc_aman is None:
@@ -144,8 +147,11 @@ def select(self, meta, proc_aman=None):
144147
else:
145148
raise ValueError(f"Entry '{self.select_cfgs['kind']}' not"
146149
"understood. Expect 'any' or 'all'")
147-
meta.restrict("dets", meta.dets.vals[keep])
148-
return meta
150+
if in_place:
151+
meta.restrict("dets", meta.dets.vals[keep])
152+
return meta
153+
else:
154+
return keep
149155

150156
def plot(self, aman, proc_aman, filename):
151157
if self.plot_cfgs is None:
@@ -210,7 +216,7 @@ def save(self, proc_aman, glitch_aman):
210216
if self.save_cfgs:
211217
proc_aman.wrap(self.glitch_name, glitch_aman)
212218

213-
def select(self, meta, proc_aman=None):
219+
def select(self, meta, proc_aman=None, in_place=True):
214220
if self.select_cfgs is None:
215221
return meta
216222
if proc_aman is None:
@@ -220,8 +226,11 @@ def select(self, meta, proc_aman=None):
220226
)
221227
n_cut = count_cuts(flag)
222228
keep = n_cut <= self.select_cfgs["max_n_glitch"]
223-
meta.restrict("dets", meta.dets.vals[keep])
224-
return meta
229+
if in_place:
230+
meta.restrict("dets", meta.dets.vals[keep])
231+
return meta
232+
else:
233+
return keep
225234

226235
def plot(self, aman, proc_aman, filename):
227236
if self.plot_cfgs is None:
@@ -325,7 +334,7 @@ def save(self, proc_aman, jump_aman):
325334
name = self.save_cfgs.get('jumps_name', 'jumps')
326335
proc_aman.wrap(name, jump_aman)
327336

328-
def select(self, meta, proc_aman=None):
337+
def select(self, meta, proc_aman=None, in_place=True):
329338
if self.select_cfgs is None:
330339
return meta
331340
if proc_aman is None:
@@ -334,8 +343,10 @@ def select(self, meta, proc_aman=None):
334343

335344
n_cut = count_cuts(proc_aman[name].jump_flag)
336345
keep = n_cut <= self.select_cfgs["max_n_jumps"]
337-
meta.restrict("dets", meta.dets.vals[keep])
338-
return meta
346+
if in_place:
347+
meta.restrict("dets", meta.dets.vals[keep])
348+
return meta
349+
return keep
339350

340351
def plot(self, aman, proc_aman, filename):
341352
if self.plot_cfgs is None:
@@ -619,7 +630,7 @@ def save(self, proc_aman, noise):
619630
else:
620631
proc_aman.wrap(self.save_cfgs['wrap_name'], noise)
621632

622-
def select(self, meta, proc_aman=None):
633+
def select(self, meta, proc_aman=None, in_place=True):
623634
if self.select_cfgs is None:
624635
return meta
625636

@@ -648,8 +659,11 @@ def select(self, meta, proc_aman=None):
648659
keep &= (wn >= np.float64(self.select_cfgs["min_noise"]))
649660
if fk is not None and "max_fknee" in self.select_cfgs.keys():
650661
keep &= (fk <= np.float64(self.select_cfgs["max_fknee"]))
651-
meta.restrict("dets", meta.dets.vals[keep])
652-
return meta
662+
if in_place:
663+
meta.restrict("dets", meta.dets.vals[keep])
664+
return meta
665+
else:
666+
return keep
653667

654668
class Calibrate(_Preprocess):
655669
"""Calibrate the timestreams based on some provided information.
@@ -1244,14 +1258,17 @@ def save(self, proc_aman, dark_aman):
12441258
if self.save_cfgs:
12451259
proc_aman.wrap("darks", dark_aman)
12461260

1247-
def select(self, meta, proc_aman=None):
1261+
def select(self, meta, proc_aman=None, in_place=True):
12481262
if self.select_cfgs is None:
12491263
return meta
12501264
if proc_aman is None:
12511265
proc_aman = meta.preprocess
12521266
keep = ~has_all_cut(proc_aman.darks.darks)
1253-
meta.restrict("dets", meta.dets.vals[keep])
1254-
return meta
1267+
if in_place:
1268+
meta.restrict("dets", meta.dets.vals[keep])
1269+
return meta
1270+
else:
1271+
return keep
12551272

12561273
class SourceFlags(_Preprocess):
12571274
"""Calculate the source flags in the data.
@@ -1315,7 +1332,7 @@ def save(self, proc_aman, source_aman):
13151332
if self.save_cfgs:
13161333
proc_aman.wrap("source_flags", source_aman)
13171334

1318-
def select(self, meta, proc_aman=None):
1335+
def select(self, meta, proc_aman=None, in_place=True):
13191336
if self.select_cfgs is None:
13201337
return meta
13211338
if proc_aman is None:
@@ -1330,12 +1347,20 @@ def select(self, meta, proc_aman=None):
13301347
if len(source_list) == 0:
13311348
raise ValueError("No tags match source list")
13321349

1350+
keep_all = np.ones(meta.dets.count, dtype=bool)
1351+
13331352
for source in source_list:
13341353
if source in source_flags._fields:
13351354
keep = ~has_all_cut(source_flags[source])
1336-
meta.restrict("dets", meta.dets.vals[keep])
1337-
source_flags.restrict("dets", source_flags.dets.vals[keep])
1338-
return meta
1355+
if in_place:
1356+
meta.restrict("dets", meta.dets.vals[keep])
1357+
source_flags.restrict("dets", source_flags.dets.vals[keep])
1358+
else:
1359+
keep_all &= keep
1360+
if in_place:
1361+
return meta
1362+
else:
1363+
return keep_all
13391364

13401365
class HWPAngleModel(_Preprocess):
13411366
"""Apply hwp angle model to the TOD.
@@ -1545,14 +1570,17 @@ def save(self, proc_aman, pca_aman):
15451570
if self.save_cfgs:
15461571
proc_aman.wrap(self.run_name, pca_aman)
15471572

1548-
def select(self, meta, proc_aman=None):
1573+
def select(self, meta, proc_aman=None, in_place=True):
15491574
if self.select_cfgs is None:
15501575
return meta
15511576
if proc_aman is None:
15521577
proc_aman = meta.preprocess
15531578
keep = ~proc_aman[self.run_name]['pca_det_mask']
1554-
meta.restrict("dets", meta.dets.vals[keep])
1555-
return meta
1579+
if in_place:
1580+
meta.restrict("dets", meta.dets.vals[keep])
1581+
return meta
1582+
else:
1583+
return keep
15561584

15571585
def plot(self, aman, proc_aman, filename):
15581586
if self.plot_cfgs is None:
@@ -1701,14 +1729,17 @@ def save(self, proc_aman, ptp_aman):
17011729
if self.save_cfgs:
17021730
proc_aman.wrap("ptp_flags", ptp_aman)
17031731

1704-
def select(self, meta, proc_aman=None):
1732+
def select(self, meta, proc_aman=None, in_place=True):
17051733
if self.select_cfgs is None:
17061734
return meta
17071735
if proc_aman is None:
17081736
proc_aman = meta.preprocess
17091737
keep = ~has_all_cut(proc_aman.ptp_flags.ptp_flags)
1710-
meta.restrict("dets", meta.dets.vals[keep])
1711-
return meta
1738+
if in_place:
1739+
meta.restrict("dets", meta.dets.vals[keep])
1740+
return meta
1741+
else:
1742+
return keep
17121743

17131744
class InvVarFlags(_Preprocess):
17141745
"""Find detectors with too high inverse variance.
@@ -1741,14 +1772,16 @@ def save(self, proc_aman, inv_var_aman):
17411772
if self.save_cfgs:
17421773
proc_aman.wrap("inv_var_flags", inv_var_aman)
17431774

1744-
def select(self, meta, proc_aman=None):
1775+
def select(self, meta, proc_aman=None, in_place=True):
17451776
if self.select_cfgs is None:
17461777
return meta
17471778
if proc_aman is None:
17481779
proc_aman = meta.preprocess
17491780
keep = ~has_all_cut(proc_aman.inv_var_flags.inv_var_flags)
1750-
meta.restrict("dets", meta.dets.vals[keep])
1751-
return meta
1781+
if in_place:
1782+
meta.restrict("dets", meta.dets.vals[keep])
1783+
return meta
1784+
return keep
17521785

17531786
class EstimateT2P(_Preprocess):
17541787
"""Estimate T to P leakage coefficients.
@@ -2011,15 +2044,17 @@ def save(self, proc_aman, fp_aman):
20112044
return
20122045
if self.save_cfgs:
20132046
proc_aman.wrap("fp_flags", fp_aman)
2014-
2015-
def select(self, meta, proc_aman=None):
2047+
2048+
def select(self, meta, proc_aman=None, in_place=True):
20162049
if self.select_cfgs is None:
20172050
return meta
20182051
if proc_aman is None:
20192052
proc_aman = meta.preprocess
20202053
keep = ~has_all_cut(proc_aman.fp_flags.fp_nans)
2021-
meta.restrict("dets", meta.dets.vals[keep])
2022-
return meta
2054+
if in_place:
2055+
meta.restrict("dets", meta.dets.vals[keep])
2056+
return meta
2057+
return keep
20232058

20242059
class PointingModel(_Preprocess):
20252060
"""Apply pointing model to the TOD.
@@ -2086,15 +2121,16 @@ def save(self, proc_aman, calc_aman, name):
20862121
if self.save_cfgs:
20872122
proc_aman.wrap(name, calc_aman)
20882123

2089-
def select(self, meta, proc_aman=None):
2124+
def select(self, meta, proc_aman=None, in_place=True):
20902125
if self.select_cfgs is None:
20912126
return meta
20922127
if proc_aman is None:
20932128
proc_aman = meta.preprocess
2094-
if hasattr(proc_aman.noisy_dets_flags, "valid_dets"):
2095-
keep = proc_aman.noisy_dets_flags.valid_dets
2129+
keep = proc_aman.noisy_dets_flags.valid_dets
2130+
if in_place:
20962131
meta.restrict('dets', proc_aman.dets.vals[keep])
2097-
return meta
2132+
return meta
2133+
return keep
20982134

20992135
class CorrectIIRParams(_Preprocess):
21002136
"""Correct missing iir_params by default values.

0 commit comments

Comments
 (0)