Skip to content

Commit 1838e78

Browse files
author
Yoshinori Sueno
committed
Modify mainly source_flags so that we can use multiple sources and taua
1 parent 71e9a91 commit 1838e78

File tree

3 files changed

+172
-25
lines changed

3 files changed

+172
-25
lines changed

sotodlib/coords/planets.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,17 @@
3939
('QSO_J2253+1608', 343.4952422, 16.14301323),
4040
('galactic_center', -93.5833, -29.0078)]
4141

42+
def get_source_list_fromstr(target_source):
43+
"""Get a source_list from SOURCE_LIST by name, or raise ValueError if not found.
44+
"""
45+
for isource in SOURCE_LIST:
46+
if isinstance(isource, str):
47+
isource_name = isource
48+
elif isinstance(isource, tuple):
49+
isource_name = isource[0]
50+
if isource_name.lower() == target_source.lower():
51+
return isource
52+
raise ValueError(f'Source "{target_source}" not found in {SOURCE_LIST}.')
4253

4354
class SlowSource:
4455
"""Class to track the time-dependent position of a slow-moving source,
@@ -119,7 +130,11 @@ def get_scan_q(tod, planet, boresight_offset=None, refq=None):
119130
el = np.median(tod.boresight.el[::10])
120131
az = np.median(tod.boresight.az[::10])
121132
t = (tod.timestamps[0] + tod.timestamps[-1]) / 2
122-
if isinstance(planet, str):
133+
if isinstance(planet, (list, tuple)):
134+
planet_name, ra, dec = planet
135+
planet = SlowSource(t, float(ra) * coords.DEG,
136+
float(dec) * coords.DEG)
137+
else:
123138
planet = SlowSource.for_named_source(planet, t)
124139

125140
def scan_q_model(t, az, el, planet):

sotodlib/core/flagman.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,61 @@ def has_all_cut(flag):
278278
def count_cuts(flag):
279279
return np.array([len(x.ranges()) for x in flag], dtype='int')
280280

281+
def has_ratio_cuts(flag, ratio):
282+
"""Determine if the ratio of flag samples to total samples in each flag
283+
exceeds a given threshold.
284+
285+
Args:
286+
flag (RangesMatrix): so3g.proj.RangesMatrix
287+
ratio (float or int): The threshold ratio above which a flag is
288+
considered to have too many cuts.
289+
"""
290+
291+
if not isinstance(ratio, (float, int)):
292+
raise ValueError('Ratio must be float or int')
293+
len_samps = flag.shape[-1]
294+
return np.array([np.sum(np.diff(x.ranges()))/len_samps > ratio for x in flag])
295+
296+
def flag_cut_select(flags, cut, kind):
297+
"""Determine which detectors to select based on flag conditions.
298+
299+
Args:
300+
flags (RangesMatrix): An instance of so3g.proj.RangesMatrix indicating flagged time ranges.
301+
cut (bool): If True, returns detectors to be excluded (cut); if False, returns detectors to be selected.
302+
kind (str or float): One of the following:
303+
- 'any': Select/cut detectors with any flagged samples.
304+
- 'all': Select/cut detectors with all samples flagged.
305+
- float: A threshold ratio (0.0–1.0); selects/cuts detectors whose flagged ratio exceeds the threshold.
306+
307+
Returns:
308+
np.ndarray: Boolean array indicating which detectors to **keep** (True) or **drop** (False),
309+
depending on `cut` and `kind`.
310+
311+
Examples:
312+
1. cut=True, kind='any' → Select detectors with **no** True flags (e.g., for Moon cut).
313+
2. cut=False, kind='any' → Select detectors with **any** True flags (e.g., for planet selection).
314+
3. cut=True, kind=0.4 → Select detectors with <40% of True flags.
315+
"""
316+
if cut:
317+
if kind == 'any':
318+
return ~has_any_cuts(flags)
319+
elif kind == 'all':
320+
return ~has_all_cut(flags)
321+
elif isinstance(kind, float):
322+
return ~has_ratio_cuts(flags, ratio=kind)
323+
else:
324+
raise ValueError("kind must be 'any', 'all', or a float between 0.0 and 1.0")
325+
else:
326+
if kind == 'any':
327+
return has_any_cuts(flags)
328+
elif kind == 'all':
329+
return has_all_cut(flags)
330+
elif isinstance(kind, float):
331+
return has_ratio_cuts(flags, ratio=kind)
332+
else:
333+
raise ValueError("kind must be 'any', 'all', or a float between 0.0 and 1.0")
334+
335+
281336
def sparse_to_ranges_matrix(arr, buffer=0, close_gaps=0, val=True):
282337
"""Convert a csr sparse array into a ranges matrix
283338
@@ -299,3 +354,23 @@ def sparse_to_ranges_matrix(arr, buffer=0, close_gaps=0, val=True):
299354
x[i].close_gaps(close_gaps)
300355
return x
301356

357+
def find_common_edge_idx(flags):
358+
"""Find the common valid range across multiple RangesMatrix objects.
359+
360+
Args:
361+
flags (RangesMatrix): An instance of so3g.proj.RangesMatrix indicating flagged time ranges.
362+
Returns:
363+
min_idx, max_idx: minmum and maximum indices that has False flag across all detectros.
364+
"""
365+
max_val = max(arr.complement().ranges()[:, 1].max() for arr in flags)
366+
common_mask = np.ones(max_val, dtype=bool)
367+
for arr in flags:
368+
mask = np.zeros(max_val, dtype=bool)
369+
for start, end in arr.complement().ranges():
370+
mask[start:end] = True
371+
common_mask &= mask
372+
373+
valid_indices = np.where(common_mask)[0]
374+
if len(valid_indices) == 0:
375+
raise ValueError("No common valid range found across all flags.")
376+
return valid_indices[0], valid_indices[-1]

sotodlib/preprocess/processes.py

Lines changed: 81 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import sotodlib.coords.planets as planets
1313

1414
from sotodlib.core.flagman import (has_any_cuts, has_all_cut,
15-
count_cuts,
15+
count_cuts, flag_cut_select,
1616
sparse_to_ranges_matrix)
1717

1818
from .pcore import _Preprocess, _FracFlaggedMixIn
@@ -245,7 +245,6 @@ def plot(self, aman, proc_aman, filename):
245245
plot_ds_factor=self.plot_cfgs.get("plot_ds_factor", 50), filename=filename.replace('{name}', f'{ufm}_glitch_signal_diff'))
246246
plot_flag_stats(aman, proc_aman[self.glitch_name], flag_type='glitches', filename=filename.replace('{name}', f'{ufm}_glitch_stats'))
247247

248-
249248
class FixJumps(_Preprocess):
250249
"""
251250
Repairs the jump heights given a set of jump flags and heights.
@@ -1281,6 +1280,7 @@ class SourceFlags(_Preprocess):
12811280
Example config block::
12821281
12831282
- name : "source_flags"
1283+
source_flags_name: "my_source_flags"
12841284
calc:
12851285
mask: {'shape': 'circle',
12861286
'xyr': [0, 0, 1.]}
@@ -1290,80 +1290,113 @@ class SourceFlags(_Preprocess):
12901290
distance: 0 # max distance of footprint from source in degrees
12911291
save: True
12921292
select: True # optional
1293+
select_source: 'jupiter' # list of str or str. If not provided, all sources from center_on are selected.
1294+
cut: True # True performs cut based on flags
1295+
kind: 'any' # 'any', 'all', or 'cut'
1296+
Examples:
1297+
1. cut=True, kind='any' → Select detectors with **no** True flags (e.g., for Moon cut).
1298+
2. cut=False, kind='any' → Select detectors with **any** True flags (e.g., for planet selection).
1299+
3. cut=True, kind=0.4 → Select detectors with <40% of True flags.
12931300
12941301
.. autofunction:: sotodlib.tod_ops.flags.get_source_flags
12951302
"""
12961303
name = "source_flags"
1304+
def __init__(self, step_cfgs):
1305+
self.source_flags_name = step_cfgs.get('source_flags_name', 'source_flags')
1306+
super().__init__(step_cfgs)
12971307

12981308
def calc_and_save(self, aman, proc_aman):
1299-
source_list = np.atleast_1d(self.calc_cfgs.get('center_on', 'planet'))
1300-
if source_list == ['planet']:
1301-
from sotodlib.coords.planets import SOURCE_LIST
1309+
from sotodlib.coords.planets import SOURCE_LIST
1310+
source_name = np.atleast_1d(self.calc_cfgs.get('center_on', 'planet'))
1311+
if 'planet' in source_name:
13021312
source_list = [x for x in aman.tags if x in SOURCE_LIST]
13031313
if len(source_list) == 0:
13041314
raise ValueError("No tags match source list")
1315+
else:
1316+
source_list = [planets.get_source_list_fromstr(isource) for isource in source_name]
13051317

13061318
# find if source is within footprint + distance
13071319
positions = planets.get_nearby_sources(tod=aman, source_list=source_list,
13081320
distance=self.calc_cfgs.get('distance', 0))
1309-
13101321
source_aman = core.AxisManager(aman.dets, aman.samps)
13111322
for p in positions:
1323+
center_on = planets.get_source_list_fromstr(p[0])
13121324
source_flags = tod_ops.flags.get_source_flags(aman,
13131325
merge=self.calc_cfgs.get('merge', False),
13141326
overwrite=self.calc_cfgs.get('overwrite', True),
13151327
source_flags_name=self.calc_cfgs.get('source_flags_name', None),
13161328
mask=self.calc_cfgs.get('mask', None),
1317-
center_on=p[0],
1329+
center_on=center_on,
13181330
res=self.calc_cfgs.get('res', None),
13191331
max_pix=self.calc_cfgs.get('max_pix', None))
13201332

13211333
source_aman.wrap(p[0], source_flags, [(0, 'dets'), (1, 'samps')])
13221334

1335+
# if inv_flag is set, add inverse of source flags
1336+
if self.calc_cfgs.get('inv_flag'):
1337+
source_aman.wrap(p[0]+ '_inv',
1338+
RangesMatrix.from_mask(~source_flags.mask()),
1339+
[(0, 'dets'), (1, 'samps')])
1340+
13231341
# add sources that were not nearby from source list
1324-
for source in source_list:
1342+
for source in source_name:
13251343
if source not in source_aman._fields:
13261344
source_aman.wrap(source, RangesMatrix.zeros([aman.dets.count, aman.samps.count]),
13271345
[(0, 'dets'), (1, 'samps')])
1346+
1347+
if self.calc_cfgs.get('inv_flag'):
1348+
source_aman.wrap(source + '_inv',
1349+
RangesMatrix.ones([aman.dets.count, aman.samps.count]),
1350+
[(0, 'dets'), (1, 'samps')])
13281351

13291352
self.save(proc_aman, source_aman)
13301353

13311354
def save(self, proc_aman, source_aman):
13321355
if self.save_cfgs is None:
13331356
return
13341357
if self.save_cfgs:
1335-
proc_aman.wrap("source_flags", source_aman)
1358+
proc_aman.wrap(self.source_flags_name, source_aman)
13361359

13371360
def select(self, meta, proc_aman=None, in_place=True):
13381361
if self.select_cfgs is None:
13391362
return meta
13401363
if proc_aman is None:
13411364
source_flags = meta.preprocess.source_flags
13421365
else:
1343-
source_flags = proc_aman.source_flags
1366+
source_flags = proc_aman[self.source_flags_name]
13441367

1345-
source_list = np.atleast_1d(self.calc_cfgs.get('center_on', 'planet'))
1346-
if source_list == ['planet']:
1368+
select_list = np.atleast_1d(self.select_cfgs.get("select_source", self.calc_cfgs.get('center_on', 'planet')))
1369+
if 'planet' in select_list:
13471370
from sotodlib.coords.planets import SOURCE_LIST
1348-
source_list = [x for x in aman.tags if x in SOURCE_LIST]
1349-
if len(source_list) == 0:
1371+
select_list = [x for x in meta.tags if x in SOURCE_LIST]
1372+
if len(select_list) == 0:
13501373
raise ValueError("No tags match source list")
13511374

1375+
cuts = self.select_cfgs.get("cut", True) # default of True is for backward compatibility
1376+
if isinstance(cuts, bool):
1377+
cuts = [cuts]*len(select_list)
1378+
elif len(cuts) != len(select_list):
1379+
raise ValueError("Length of cuts must match length of select_source, or just bool")
1380+
1381+
kinds = self.select_cfgs.get("kind", 'all') # default of 'all' is for backward compatibility
1382+
if isinstance(kinds, (str, float)):
1383+
kinds = [kinds]*len(select_list)
1384+
elif len(kinds) != len(select_list):
1385+
raise ValueError("Length of kinds must match length of select_source, or just str")
1386+
13521387
keep_all = np.ones(meta.dets.count, dtype=bool)
13531388

1354-
for source in source_list:
1389+
for source, kind, cut in zip(select_list, kinds, cuts):
13551390
if source in source_flags._fields:
1356-
keep = ~has_all_cut(source_flags[source])
1357-
if in_place:
1358-
meta.restrict("dets", meta.dets.vals[keep])
1359-
source_flags.restrict("dets", source_flags.dets.vals[keep])
1360-
else:
1361-
keep_all &= keep
1391+
keep_all &= flag_cut_select(source_flags[source], cut, kind)
1392+
13621393
if in_place:
1394+
meta.restrict("dets", meta.dets.vals[keep_all])
1395+
source_flags.restrict("dets", source_flags.dets.vals[keep_all])
13631396
return meta
13641397
else:
13651398
return keep_all
1366-
1399+
13671400
class HWPAngleModel(_Preprocess):
13681401
"""Apply hwp angle model to the TOD.
13691402
@@ -1689,7 +1722,7 @@ def __init__(self, step_cfgs):
16891722

16901723
super().__init__(step_cfgs)
16911724

1692-
def process(self, aman, proc_aman):
1725+
def process(self, aman, proc_aman, sim=False):
16931726
n_modes = self.process_cfgs.get('n_modes')
16941727
signal = aman.get(self.signal)
16951728
flags = aman.flags.get(self.process_cfgs.get('source_flags'))
@@ -2146,10 +2179,33 @@ class CorrectIIRParams(_Preprocess):
21462179
"""
21472180
name = "correct_iir_params"
21482181

2149-
def process(self, aman, proc_aman):
2182+
def process(self, aman, proc_aman, sim=False):
21502183
from sotodlib.obs_ops import correct_iir_params
21512184
correct_iir_params(aman)
21522185

2186+
class TrimFlagEdge(_Preprocess):
2187+
"""Trim edge until given flags of all detectors are False
2188+
To find first and last sample id that has False (i.e., no flags applied) for all detectors.
2189+
This is for avoiding glitchfill problem for data whose edge has flags of True.
2190+
2191+
Example config block::
2192+
2193+
- name: "trim_flag_edge"
2194+
process:
2195+
flags: "pca_exclude"
2196+
2197+
.. autofunction:: sotodlib.core.flagman.find_common_edge_idx
2198+
"""
2199+
name = 'trim_flag_edge'
2200+
2201+
def process(self, aman, proc_aman, sim=False):
2202+
flags = aman.flags.get(self.process_cfgs.get('flags'))
2203+
trimst, trimen = core.flagman.find_common_edge_idx(flags)
2204+
aman.restrict('samps', (aman.samps.offset + trimst,
2205+
aman.samps.offset + trimen))
2206+
proc_aman.restrict('samps', (proc_aman.samps.offset + trimst,
2207+
proc_aman.samps.offset + trimen))
2208+
21532209
_Preprocess.register(SplitFlags)
21542210
_Preprocess.register(SubtractT2P)
21552211
_Preprocess.register(EstimateT2P)
@@ -2193,3 +2249,4 @@ def process(self, aman, proc_aman):
21932249
_Preprocess.register(PointingModel)
21942250
_Preprocess.register(BadSubscanFlags)
21952251
_Preprocess.register(CorrectIIRParams)
2252+
_Preprocess.register(TrimFlagEdge)

0 commit comments

Comments
 (0)