Skip to content

Commit f692597

Browse files
authored
Merge pull request #19 from simonsobs/docs2
Docs2
2 parents c740529 + 5db74c0 commit f692597

4 files changed

Lines changed: 434 additions & 61 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
__pycache__
2+
LAT_beams.egg-info

lat_beams/mapmaking.py

Lines changed: 265 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,67 @@
1+
"""
2+
Functions for producing source maps.
3+
"""
4+
15
import glob
26
import logging
37
import os
48
import time
9+
from argparse import Namespace
10+
from logging import Logger
11+
from typing import Optional, cast
512

613
import numpy as np
7-
from pixell import enmap
14+
import sqlalchemy as sqy
15+
from astropy.wcs import WCS
16+
from mpi4py import MPI
17+
from pixell import bunch, enmap
818
from so3g.proj import RangesMatrix
919
from sotodlib import mapmaking, tod_ops
1020
from sotodlib.coords import planets as cp
21+
from sotodlib.coords import pmat
22+
from sotodlib.core import AxisManager
23+
from sotodlib.site_pipeline import jobdb
1124

1225
from .beam_utils import estimate_cent
1326
from .utils import log_lvl, set_tag
1427

1528

16-
def make_cuts(aman, source_flags, n_modes, job, logger, cfg):
29+
def make_cuts(
30+
aman: AxisManager,
31+
source_flags: RangesMatrix,
32+
n_modes: int,
33+
job: jobdb.Job,
34+
logger: Logger,
35+
cfg: Namespace,
36+
) -> Optional[RangesMatrix]:
37+
"""
38+
Compute cuts on a source TOD before mapping.
39+
This filters for the source and then calculates the peak SNR of each detector
40+
by dividing the max of the region in `source_flags` by the standard deviation of the region outside of `source_flags`.
41+
Any detector with an SNR less than `cfg.min_snr` is cut.
42+
43+
Parameters
44+
----------
45+
aman : AxisManager
46+
The loaded data ready to be mapped.
47+
source_flags : RangesMatrix
48+
RangesMatrix with all samples within some radius of the source flagged.
49+
n_modes : int
50+
The number of modes to use when filtering.
51+
job : jobdb.Job
52+
The job associated with making this map.
53+
logger : Logger
54+
The logger to log to.
55+
cfg : Namespace
56+
The loaded configuration.
57+
See `lat_beams.utils.config` for details.
58+
59+
Returns
60+
-------
61+
cuts : Optional[RangesMatrix]
62+
The calculated cuts.
63+
If the number of uncut detectors is less than `cfg.min_dets` then `None` is returned.
64+
"""
1765
sig_filt = cp.filter_for_sources(
1866
tod=aman,
1967
signal=aman.signal.copy(),
@@ -41,45 +89,105 @@ def make_cuts(aman, source_flags, n_modes, job, logger, cfg):
4189
msg = f"Not enough detectors after source flag cuts!"
4290
logger.error("\t%s", msg)
4391
set_tag(job, "message", msg)
44-
job.jstate = "failed"
92+
job.jstate = cast(sqy.Column[str], jobdb.JState.failed)
93+
4594
return None
4695
return cuts
4796

4897

4998
def make_map(
50-
aman,
51-
src_to_map,
52-
res,
53-
cuts,
54-
source_flags,
55-
comps,
56-
n_modes,
57-
pixsize,
58-
filename,
59-
min_det_secs,
60-
info,
61-
job,
62-
map_str,
63-
logger,
64-
cfg,
65-
):
99+
aman: AxisManager,
100+
src_to_map: str,
101+
res: float,
102+
cuts: RangesMatrix,
103+
source_flags: RangesMatrix,
104+
comps: str,
105+
n_modes: int,
106+
pixsize: float, # TODO: This doesn't need to exist
107+
filename: str,
108+
min_det_secs: float,
109+
info: dict[str, str],
110+
job: jobdb.Job,
111+
map_str: str,
112+
logger: Logger,
113+
cfg: Namespace,
114+
) -> tuple[Optional[dict], Optional[tuple[int, int]]]:
115+
"""
116+
Make a filter-bin map of a source and estimate the center.
117+
The map will be in source-scan coordinates and uses `domdir` threading.
118+
If the mapping fails the associated job will have its message updated with an explanation.
119+
If mapping succeeds then paths to the map's files will be added to the job.
120+
121+
Once mapped the center of the map and its SNR are estimated.
122+
If the SNR is below `cfg.min_snr` and `cfg.del_map` is set then the files
123+
associated with the map are deleted.
124+
125+
Parameters
126+
----------
127+
aman : AxisManager
128+
The loaded data ready to be mapped.
129+
src_to_map : str
130+
The name of the source to map.
131+
res : float
132+
The desired map resolution in radians.
133+
cuts : RangesMatrix
134+
The output of `make_cuts`.
135+
source_flags : RangesMatrix
136+
RangesMatrix with all samples within some radius of the source flagged.
137+
comps : str
138+
The maps to compute, should be `T` or `TQU`.
139+
n_modes : int
140+
The number of modes to use when filtering.
141+
pixsize : float
142+
`res` in arcseconds
143+
filename : str
144+
The pattern for the output map filename.
145+
See `sotodlib.coords.planets.make_map` for details.
146+
min_det_secs : float
147+
The minimum number of detector seconds.
148+
If we are below this a map is not made.
149+
info : dict[str, str]
150+
The information used to fill in `filename`.
151+
See `sotodlib.coords.planets.make_map` for details.
152+
job : jobdb.Job
153+
The job associated with making this map.
154+
map_str : str
155+
A short string to describe the map in the logs and job (ie. "initial").
156+
logger : Logger
157+
The logger to log to.
158+
cfg : Namespace
159+
The loaded configuration.
160+
See `lat_beams.utils.config` for details.
161+
162+
Returns
163+
-------
164+
outmap : Optional[dict]
165+
The output of `sotodlib.coords.planets.make_map`.
166+
`None` is returned if we are below `min_det_secs` or `cfg.min_snr`.
167+
cent : Optional[tuple[int, int]]]
168+
The estimated center of the map.
169+
`None` is returned if we are below `min_det_secs` or `cfg.min_snr`.
170+
"""
66171
# Get time on source
67-
det_secs = np.sum((source_flags * ~cuts).get_stats()["samples"]) * np.mean(
68-
np.diff(aman.timestamps)
172+
sf_uncut = source_flags * ~cuts
173+
if sf_uncut is None:
174+
raise ValueError("RangesMatrix somehow became none...")
175+
det_secs = np.sum(sf_uncut.get_stats()["samples"]) * np.mean(
176+
np.diff(np.array(aman.timestamps))
69177
)
70178
logger.debug("\t%s detector seconds on source in %s mask", det_secs, map_str)
71179
if det_secs < min_det_secs:
72180
msg = f"\tNot enough time on source in {map_str} mask."
73181
logger.error("\t%s", msg)
74182
set_tag(job, "message", msg)
75-
job.jstate = "failed"
183+
job.jstate = cast(sqy.Column[str], jobdb.JState.failed)
76184
return None, None
77185

78186
# Initial map
79187
with log_lvl(logger, logging.WARNING):
80188
out = cp.make_map(
81189
aman.copy(),
82-
thread_algo="domdir",
190+
thread_algo="domdir", # type: ignore
83191
center_on=src_to_map,
84192
res=res,
85193
cuts=cuts,
@@ -102,7 +210,7 @@ def make_map(
102210
msg = f"{map_str.title()} map SNR too low."
103211
logger.error("\t%s", msg)
104212
set_tag(job, "message", msg)
105-
job.jstate = "failed"
213+
job.jstate = cast(sqy.Column[str], jobdb.JState.failed)
106214
if cfg.del_map and filename is not None:
107215
logger.debug("\tDeleting map files")
108216
glob_path = os.path.splitext(filename)[0] + "*.*"
@@ -116,7 +224,26 @@ def make_map(
116224
return out, cent
117225

118226

119-
def get_passes(cfg):
227+
def get_passes(cfg: Namespace) -> list[bunch.Bunch]:
228+
"""
229+
Setup passes for making an ML mapmaker.
230+
The output will have `cfg.mlpass` elements with all of them
231+
having a downsampling factor of 1.
232+
The last pass will hade bilinear interpolations and the rest will be nearest neighbor.
233+
The i'th pass will have `max(1, cfg.cgiters//2, cfg.cgiters//(i + 1))` CG iters.
234+
235+
Parameters
236+
----------
237+
cfg : Namespace
238+
The loaded configuration.
239+
See `lat_beams.utils.config` for details.
240+
241+
Returns
242+
-------
243+
passes : list[bunch.Bunch]
244+
The passes for the sotodlib ML mapmaker.
245+
See this function's docstring for details on the contents.
246+
"""
120247
passes = []
121248
if cfg.mlpass > 0:
122249
dsstr = "1"
@@ -133,12 +260,58 @@ def get_passes(cfg):
133260

134261

135262
def add_obs_to_mapmaker(
136-
aman, sub_id, mapmaker, ipass, passinfo, P, guess, eval_prev, mapmaker_prev, logger
263+
aman: AxisManager,
264+
sub_id: str,
265+
mapmaker: mapmaking.MLMapmaker,
266+
ipass: int,
267+
passinfo: bunch.Bunch,
268+
P: pmat.P,
269+
guess: Optional[enmap.ndmap],
270+
eval_prev: Optional[mapmaking.MLEvaluator],
271+
mapmaker_prev: Optional[mapmaking.MLMapmaker],
272+
logger: Logger,
137273
):
274+
"""
275+
Add an observation to the MLMapmaker.
276+
This makes sure that the correct weather and site are included and will use
277+
an input map or a previous iteration of the mapmaker to estimate the signal
278+
when computing the noise model if they are provided.
279+
280+
Parameters
281+
----------
282+
aman : AxisManager
283+
The TOD after preprocessing.
284+
The cut samples should be in the `glitch` flag.
285+
sub_id : str
286+
The `sub_id` of the TOD.
287+
Should have format `{obs_id}:{ws}:{band}`.
288+
mapmaker : mapmaking.MLMapmaker
289+
The mapmaker instance to add the observation to.
290+
ipass : int
291+
The pass number.
292+
passinfo : bunch.Bunch
293+
The pass we are on. See `setup_passes` for details.
294+
P : pmat.P
295+
The `sotodlib` projectioe matrix.
296+
guess : Optional[enmap.ndmap]
297+
A guess at what the map is.
298+
If this is not `None` and `ipass == 0` then this is
299+
used to estimate the signal when building the noise model.
300+
eval_prev : Optional[mapmakin.MLEvaluator]
301+
Evaluator for the previous pass of the mapmaker.
302+
If `ipass > 0` and both this and `mapmaker_prev` are not `None`
303+
then they ary used to estimate the signal when building the noise model.
304+
mapmaker_prev : Optional[mapmakin.MLMapmaker]
305+
Mapmaker instance for the previous pass of the mapmaker.
306+
If `ipass > 0` and both this and `eval_prev` are not `None`
307+
then they ary used to estimate the signal when building the noise model.
308+
logger : Logger
309+
The logger to log to.
310+
"""
138311
if passinfo.downsample != 1:
139312
aman = mapmaking.downsample_obs(aman, passinfo.downsample)
140313
raise ValueError("downsampling not properly implemented currently")
141-
aman.signal = aman.signal.astype(np.float32)
314+
aman.signal = np.array(aman.signal).astype(np.float32)
142315
if "weather" not in aman:
143316
aman.wrap("weather", np.full(1, "typical"))
144317
if "site" not in aman:
@@ -162,8 +335,59 @@ def add_obs_to_mapmaker(
162335

163336

164337
def make_ml_map(
165-
amans, passes, shape, wcs, prefix, out_dir, comm, logger, cfg, guess=None
166-
):
338+
amans: dict[str, tuple[AxisManager, pmat.P]],
339+
passes: list[bunch.Bunch],
340+
shape: tuple[int, int],
341+
wcs: WCS,
342+
prefix: str,
343+
out_dir: str,
344+
comm: MPI.Comm,
345+
logger: Logger,
346+
cfg: Namespace,
347+
guess: Optional[enmap.ndmap] = None,
348+
) -> tuple[Optional[enmap.ndmap], tuple[str, str, str, str]]:
349+
"""
350+
Make a multipass ML source map using the sotodlib ML mapmaker.
351+
May be worth adding a `sogma` option down the line.
352+
353+
Parameters
354+
----------
355+
amans : dict[str, tuple[AxisManager, pmat.P]]
356+
The AxisManagers to map.
357+
Each entry should map a sub_id (with format `{obs_id}:{ws}:{band}`)
358+
to a tuple consisting of an AxisManager that is preprocessed and ready to map
359+
and the corresponding projection matrix.
360+
passes : list[bunch.Bunch]
361+
The output of `setup_passes`.
362+
shape : tuple[int, int]
363+
The desired shape of the output map.
364+
wcs : WCS
365+
The WCS to be used for the output map.
366+
prefix : str
367+
The prefix to be preprended to the map filenames.
368+
out_dir : str
369+
The directory to save maps to.
370+
comm : MPI.Comm
371+
The MPI communicator to use when mapmaking.
372+
All processes must have a non-zero sized `amans`.
373+
logger : Logger
374+
The logger to log to.
375+
cfg : Namespace
376+
The loaded configuration.
377+
See `lat_beams.utils.config` for details.
378+
guess : Optional[enmap.ndmap], default: None
379+
A map to use to estimate the signal when constructing
380+
the noise model in the first pass of mapmaking.
381+
Pass `None` if you don't have a good starting guess.
382+
383+
Returns
384+
-------
385+
outmap : Optional[enmap.ndmap]
386+
The signal map from the final pass.
387+
Will be `None` if no passes are run.
388+
paths : tuple[str, str, str, str]
389+
The paths to the signal, rhs, div, and bin maps from the final pass.
390+
"""
167391
mlmap_path = ""
168392
rhs_path = ""
169393
div_path = ""
@@ -191,7 +415,7 @@ def make_ml_map(
191415
wcs,
192416
comm,
193417
comps=cfg.comps,
194-
dtype=np.float64,
418+
dtype=np.float64, # type: ignore
195419
tiled=False,
196420
interpol=passinfo.interpol,
197421
)
@@ -229,7 +453,11 @@ def make_ml_map(
229453
logger.debug("\tWrote rhs, div, bin")
230454

231455
# Set up initial condition
232-
x0 = None if ipass == 0 else mapmaker.translate(mapmaker_prev, eval_prev.x_zip)
456+
x0 = (
457+
None
458+
if ipass == 0 or eval_prev is None or mapmaker_prev is None
459+
else mapmaker.translate(mapmaker_prev, eval_prev.x_zip)
460+
)
233461

234462
# Solve
235463
t1 = time.time()
@@ -253,10 +481,15 @@ def make_ml_map(
253481
raise ValueError("Mapmaker ran 0 steps!")
254482
for signal, val in zip(signals, step.x):
255483
if signal.output:
256-
outmap = val
484+
outmap = cast(enmap.ndmap, val)
257485
mlmap_path = signal.write(pass_prefix, "map", val, unit="pW")
258486

259487
mapmaker_prev = mapmaker
260488
eval_prev = mapmaker.evaluator(step.x_zip)
261489

490+
mlmap_path = "" if mlmap_path is None else mlmap_path
491+
rhs_path = "" if rhs_path is None else rhs_path
492+
div_path = "" if div_path is None else div_path
493+
bin_path = "" if bin_path is None else bin_path
494+
262495
return outmap, (mlmap_path, rhs_path, div_path, bin_path)

0 commit comments

Comments
 (0)