1+ """
2+ Functions for producing source maps.
3+ """
4+
15import glob
26import logging
37import os
48import time
9+ from argparse import Namespace
10+ from logging import Logger
11+ from typing import Optional , cast
512
613import numpy as np
7- from pixell import enmap
14+ import sqlalchemy as sqy
15+ from astropy .wcs import WCS
16+ from mpi4py import MPI
17+ from pixell import bunch , enmap
818from so3g .proj import RangesMatrix
919from sotodlib import mapmaking , tod_ops
1020from sotodlib .coords import planets as cp
21+ from sotodlib .coords import pmat
22+ from sotodlib .core import AxisManager
23+ from sotodlib .site_pipeline import jobdb
1124
1225from .beam_utils import estimate_cent
1326from .utils import log_lvl , set_tag
1427
1528
16- def make_cuts (aman , source_flags , n_modes , job , logger , cfg ):
29+ def make_cuts (
30+ aman : AxisManager ,
31+ source_flags : RangesMatrix ,
32+ n_modes : int ,
33+ job : jobdb .Job ,
34+ logger : Logger ,
35+ cfg : Namespace ,
36+ ) -> Optional [RangesMatrix ]:
37+ """
38+ Compute cuts on a source TOD before mapping.
39+ This filters for the source and then calculates the peak SNR of each detector
40+ by dividing the max of the region in `source_flags` by the standard deviation of the region outside of `source_flags`.
41+ Any detector with an SNR less than `cfg.min_snr` is cut.
42+
43+ Parameters
44+ ----------
45+ aman : AxisManager
46+ The loaded data ready to be mapped.
47+ source_flags : RangesMatrix
48+ RangesMatrix with all samples within some radius of the source flagged.
49+ n_modes : int
50+ The number of modes to use when filtering.
51+ job : jobdb.Job
52+ The job associated with making this map.
53+ logger : Logger
54+ The logger to log to.
55+ cfg : Namespace
56+ The loaded configuration.
57+ See `lat_beams.utils.config` for details.
58+
59+ Returns
60+ -------
61+ cuts : Optional[RangesMatrix]
62+ The calculated cuts.
63+ If the number of uncut detectors is less than `cfg.min_dets` then `None` is returned.
64+ """
1765 sig_filt = cp .filter_for_sources (
1866 tod = aman ,
1967 signal = aman .signal .copy (),
@@ -41,45 +89,105 @@ def make_cuts(aman, source_flags, n_modes, job, logger, cfg):
4189 msg = f"Not enough detectors after source flag cuts!"
4290 logger .error ("\t %s" , msg )
4391 set_tag (job , "message" , msg )
44- job .jstate = "failed"
92+ job .jstate = cast (sqy .Column [str ], jobdb .JState .failed )
93+
4594 return None
4695 return cuts
4796
4897
4998def make_map (
50- aman ,
51- src_to_map ,
52- res ,
53- cuts ,
54- source_flags ,
55- comps ,
56- n_modes ,
57- pixsize ,
58- filename ,
59- min_det_secs ,
60- info ,
61- job ,
62- map_str ,
63- logger ,
64- cfg ,
65- ):
99+ aman : AxisManager ,
100+ src_to_map : str ,
101+ res : float ,
102+ cuts : RangesMatrix ,
103+ source_flags : RangesMatrix ,
104+ comps : str ,
105+ n_modes : int ,
106+ pixsize : float , # TODO: This doesn't need to exist
107+ filename : str ,
108+ min_det_secs : float ,
109+ info : dict [str , str ],
110+ job : jobdb .Job ,
111+ map_str : str ,
112+ logger : Logger ,
113+ cfg : Namespace ,
114+ ) -> tuple [Optional [dict ], Optional [tuple [int , int ]]]:
115+ """
116+ Make a filter-bin map of a source and estimate the center.
117+ The map will be in source-scan coordinates and uses `domdir` threading.
118+ If the mapping fails the associated job will have its message updated with an explanation.
119+ If mapping succeeds then paths to the map's files will be added to the job.
120+
121+ Once mapped the center of the map and its SNR are estimated.
122+ If the SNR is below `cfg.min_snr` and `cfg.del_map` is set then the files
123+ associated with the map are deleted.
124+
125+ Parameters
126+ ----------
127+ aman : AxisManager
128+ The loaded data ready to be mapped.
129+ src_to_map : str
130+ The name of the source to map.
131+ res : float
132+ The desired map resolution in radians.
133+ cuts : RangesMatrix
134+ The output of `make_cuts`.
135+ source_flags : RangesMatrix
136+ RangesMatrix with all samples within some radius of the source flagged.
137+ comps : str
138+ The maps to compute, should be `T` or `TQU`.
139+ n_modes : int
140+ The number of modes to use when filtering.
141+ pixsize : float
142+ `res` in arcseconds
143+ filename : str
144+ The pattern for the output map filename.
145+ See `sotodlib.coords.planets.make_map` for details.
146+ min_det_secs : float
147+ The minimum number of detector seconds.
148+ If we are below this a map is not made.
149+ info : dict[str, str]
150+ The information used to fill in `filename`.
151+ See `sotodlib.coords.planets.make_map` for details.
152+ job : jobdb.Job
153+ The job associated with making this map.
154+ map_str : str
155+ A short string to describe the map in the logs and job (ie. "initial").
156+ logger : Logger
157+ The logger to log to.
158+ cfg : Namespace
159+ The loaded configuration.
160+ See `lat_beams.utils.config` for details.
161+
162+ Returns
163+ -------
164+ outmap : Optional[dict]
165+ The output of `sotodlib.coords.planets.make_map`.
166+ `None` is returned if we are below `min_det_secs` or `cfg.min_snr`.
167+ cent : Optional[tuple[int, int]]]
168+ The estimated center of the map.
169+ `None` is returned if we are below `min_det_secs` or `cfg.min_snr`.
170+ """
66171 # Get time on source
67- det_secs = np .sum ((source_flags * ~ cuts ).get_stats ()["samples" ]) * np .mean (
68- np .diff (aman .timestamps )
172+ sf_uncut = source_flags * ~ cuts
173+ if sf_uncut is None :
174+ raise ValueError ("RangesMatrix somehow became none..." )
175+ det_secs = np .sum (sf_uncut .get_stats ()["samples" ]) * np .mean (
176+ np .diff (np .array (aman .timestamps ))
69177 )
70178 logger .debug ("\t %s detector seconds on source in %s mask" , det_secs , map_str )
71179 if det_secs < min_det_secs :
72180 msg = f"\t Not enough time on source in { map_str } mask."
73181 logger .error ("\t %s" , msg )
74182 set_tag (job , "message" , msg )
75- job .jstate = " failed"
183+ job .jstate = cast ( sqy . Column [ str ], jobdb . JState . failed )
76184 return None , None
77185
78186 # Initial map
79187 with log_lvl (logger , logging .WARNING ):
80188 out = cp .make_map (
81189 aman .copy (),
82- thread_algo = "domdir" ,
190+ thread_algo = "domdir" , # type: ignore
83191 center_on = src_to_map ,
84192 res = res ,
85193 cuts = cuts ,
@@ -102,7 +210,7 @@ def make_map(
102210 msg = f"{ map_str .title ()} map SNR too low."
103211 logger .error ("\t %s" , msg )
104212 set_tag (job , "message" , msg )
105- job .jstate = " failed"
213+ job .jstate = cast ( sqy . Column [ str ], jobdb . JState . failed )
106214 if cfg .del_map and filename is not None :
107215 logger .debug ("\t Deleting map files" )
108216 glob_path = os .path .splitext (filename )[0 ] + "*.*"
@@ -116,7 +224,26 @@ def make_map(
116224 return out , cent
117225
118226
119- def get_passes (cfg ):
227+ def get_passes (cfg : Namespace ) -> list [bunch .Bunch ]:
228+ """
229+ Setup passes for making an ML mapmaker.
230+ The output will have `cfg.mlpass` elements with all of them
231+ having a downsampling factor of 1.
232+ The last pass will hade bilinear interpolations and the rest will be nearest neighbor.
233+ The i'th pass will have `max(1, cfg.cgiters//2, cfg.cgiters//(i + 1))` CG iters.
234+
235+ Parameters
236+ ----------
237+ cfg : Namespace
238+ The loaded configuration.
239+ See `lat_beams.utils.config` for details.
240+
241+ Returns
242+ -------
243+ passes : list[bunch.Bunch]
244+ The passes for the sotodlib ML mapmaker.
245+ See this function's docstring for details on the contents.
246+ """
120247 passes = []
121248 if cfg .mlpass > 0 :
122249 dsstr = "1"
@@ -133,12 +260,58 @@ def get_passes(cfg):
133260
134261
135262def add_obs_to_mapmaker (
136- aman , sub_id , mapmaker , ipass , passinfo , P , guess , eval_prev , mapmaker_prev , logger
263+ aman : AxisManager ,
264+ sub_id : str ,
265+ mapmaker : mapmaking .MLMapmaker ,
266+ ipass : int ,
267+ passinfo : bunch .Bunch ,
268+ P : pmat .P ,
269+ guess : Optional [enmap .ndmap ],
270+ eval_prev : Optional [mapmaking .MLEvaluator ],
271+ mapmaker_prev : Optional [mapmaking .MLMapmaker ],
272+ logger : Logger ,
137273):
274+ """
275+ Add an observation to the MLMapmaker.
276+ This makes sure that the correct weather and site are included and will use
277+ an input map or a previous iteration of the mapmaker to estimate the signal
278+ when computing the noise model if they are provided.
279+
280+ Parameters
281+ ----------
282+ aman : AxisManager
283+ The TOD after preprocessing.
284+ The cut samples should be in the `glitch` flag.
285+ sub_id : str
286+ The `sub_id` of the TOD.
287+ Should have format `{obs_id}:{ws}:{band}`.
288+ mapmaker : mapmaking.MLMapmaker
289+ The mapmaker instance to add the observation to.
290+ ipass : int
291+ The pass number.
292+ passinfo : bunch.Bunch
293+ The pass we are on. See `setup_passes` for details.
294+ P : pmat.P
295+ The `sotodlib` projectioe matrix.
296+ guess : Optional[enmap.ndmap]
297+ A guess at what the map is.
298+ If this is not `None` and `ipass == 0` then this is
299+ used to estimate the signal when building the noise model.
300+ eval_prev : Optional[mapmakin.MLEvaluator]
301+ Evaluator for the previous pass of the mapmaker.
302+ If `ipass > 0` and both this and `mapmaker_prev` are not `None`
303+ then they ary used to estimate the signal when building the noise model.
304+ mapmaker_prev : Optional[mapmakin.MLMapmaker]
305+ Mapmaker instance for the previous pass of the mapmaker.
306+ If `ipass > 0` and both this and `eval_prev` are not `None`
307+ then they ary used to estimate the signal when building the noise model.
308+ logger : Logger
309+ The logger to log to.
310+ """
138311 if passinfo .downsample != 1 :
139312 aman = mapmaking .downsample_obs (aman , passinfo .downsample )
140313 raise ValueError ("downsampling not properly implemented currently" )
141- aman .signal = aman .signal .astype (np .float32 )
314+ aman .signal = np . array ( aman .signal ) .astype (np .float32 )
142315 if "weather" not in aman :
143316 aman .wrap ("weather" , np .full (1 , "typical" ))
144317 if "site" not in aman :
@@ -162,8 +335,59 @@ def add_obs_to_mapmaker(
162335
163336
164337def make_ml_map (
165- amans , passes , shape , wcs , prefix , out_dir , comm , logger , cfg , guess = None
166- ):
338+ amans : dict [str , tuple [AxisManager , pmat .P ]],
339+ passes : list [bunch .Bunch ],
340+ shape : tuple [int , int ],
341+ wcs : WCS ,
342+ prefix : str ,
343+ out_dir : str ,
344+ comm : MPI .Comm ,
345+ logger : Logger ,
346+ cfg : Namespace ,
347+ guess : Optional [enmap .ndmap ] = None ,
348+ ) -> tuple [Optional [enmap .ndmap ], tuple [str , str , str , str ]]:
349+ """
350+ Make a multipass ML source map using the sotodlib ML mapmaker.
351+ May be worth adding a `sogma` option down the line.
352+
353+ Parameters
354+ ----------
355+ amans : dict[str, tuple[AxisManager, pmat.P]]
356+ The AxisManagers to map.
357+ Each entry should map a sub_id (with format `{obs_id}:{ws}:{band}`)
358+ to a tuple consisting of an AxisManager that is preprocessed and ready to map
359+ and the corresponding projection matrix.
360+ passes : list[bunch.Bunch]
361+ The output of `setup_passes`.
362+ shape : tuple[int, int]
363+ The desired shape of the output map.
364+ wcs : WCS
365+ The WCS to be used for the output map.
366+ prefix : str
367+ The prefix to be preprended to the map filenames.
368+ out_dir : str
369+ The directory to save maps to.
370+ comm : MPI.Comm
371+ The MPI communicator to use when mapmaking.
372+ All processes must have a non-zero sized `amans`.
373+ logger : Logger
374+ The logger to log to.
375+ cfg : Namespace
376+ The loaded configuration.
377+ See `lat_beams.utils.config` for details.
378+ guess : Optional[enmap.ndmap], default: None
379+ A map to use to estimate the signal when constructing
380+ the noise model in the first pass of mapmaking.
381+ Pass `None` if you don't have a good starting guess.
382+
383+ Returns
384+ -------
385+ outmap : Optional[enmap.ndmap]
386+ The signal map from the final pass.
387+ Will be `None` if no passes are run.
388+ paths : tuple[str, str, str, str]
389+ The paths to the signal, rhs, div, and bin maps from the final pass.
390+ """
167391 mlmap_path = ""
168392 rhs_path = ""
169393 div_path = ""
@@ -191,7 +415,7 @@ def make_ml_map(
191415 wcs ,
192416 comm ,
193417 comps = cfg .comps ,
194- dtype = np .float64 ,
418+ dtype = np .float64 , # type: ignore
195419 tiled = False ,
196420 interpol = passinfo .interpol ,
197421 )
@@ -229,7 +453,11 @@ def make_ml_map(
229453 logger .debug ("\t Wrote rhs, div, bin" )
230454
231455 # Set up initial condition
232- x0 = None if ipass == 0 else mapmaker .translate (mapmaker_prev , eval_prev .x_zip )
456+ x0 = (
457+ None
458+ if ipass == 0 or eval_prev is None or mapmaker_prev is None
459+ else mapmaker .translate (mapmaker_prev , eval_prev .x_zip )
460+ )
233461
234462 # Solve
235463 t1 = time .time ()
@@ -253,10 +481,15 @@ def make_ml_map(
253481 raise ValueError ("Mapmaker ran 0 steps!" )
254482 for signal , val in zip (signals , step .x ):
255483 if signal .output :
256- outmap = val
484+ outmap = cast ( enmap . ndmap , val )
257485 mlmap_path = signal .write (pass_prefix , "map" , val , unit = "pW" )
258486
259487 mapmaker_prev = mapmaker
260488 eval_prev = mapmaker .evaluator (step .x_zip )
261489
490+ mlmap_path = "" if mlmap_path is None else mlmap_path
491+ rhs_path = "" if rhs_path is None else rhs_path
492+ div_path = "" if div_path is None else div_path
493+ bin_path = "" if bin_path is None else bin_path
494+
262495 return outmap , (mlmap_path , rhs_path , div_path , bin_path )
0 commit comments