Skip to content

Commit 1c4c526

Browse files
committed
Extend trace model to NIRSpec and MIRI slit and slitless modes
1 parent 22fedad commit 1c4c526

9 files changed

Lines changed: 505 additions & 112 deletions

File tree

jwst/adaptive_trace_model/adaptive_trace_model_step.py

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -95,24 +95,58 @@ def process(self, input_data):
9595
else:
9696
models = [output_model]
9797

98+
# Update each model in place
9899
for model in models:
99-
if not isinstance(model, datamodels.IFUImageModel):
100-
log.warning("The adaptive_trace_model step is only implemented for IFU data.")
101-
log.warning("Skipping processing for datamodel type %s.", str(output_model))
100+
log.info("Fitting trace model for %s", model.meta.filename)
101+
if isinstance(model, datamodels.MultiSlitModel):
102+
results = None
103+
if self.save_intermediate_results:
104+
new_model = datamodels.MultiSlitModel()
105+
new_model.update(model, only="PRIMARY")
106+
if self.oversample == 1.0:
107+
results = [None, new_model, new_model.copy(), None, None]
108+
else:
109+
results = [
110+
None,
111+
new_model,
112+
new_model.copy(),
113+
new_model.copy(),
114+
new_model.copy(),
115+
]
116+
for slit in model.slits:
117+
log.info(f"Working on slit {slit.name}")
118+
log.debug(f"Slit is of type {type(slit)}")
119+
120+
slit_results = fit_and_oversample(
121+
slit,
122+
fit_threshold=self.fit_threshold,
123+
slope_limit=self.slope_limit,
124+
oversample_factor=self.oversample,
125+
psf_optimal=self.psf_optimal,
126+
return_intermediate_models=self.save_intermediate_results,
127+
)
128+
if self.save_intermediate_results:
129+
for i, intermediate_model in enumerate(slit_results[1:]):
130+
if intermediate_model is not None:
131+
results[i + 1].slits.append(intermediate_model)
132+
133+
elif isinstance(model, (datamodels.SlitModel, datamodels.IFUImageModel)):
134+
results = fit_and_oversample(
135+
model,
136+
fit_threshold=self.fit_threshold,
137+
slope_limit=self.slope_limit,
138+
oversample_factor=self.oversample,
139+
psf_optimal=self.psf_optimal,
140+
return_intermediate_models=self.save_intermediate_results,
141+
)
142+
else:
143+
log.warning(
144+
"The adaptive_trace_model step is not implemented for %s.", str(output_model)
145+
)
146+
log.warning("Skipping processing.")
102147
model.meta.cal_step.adaptive_trace_model = "SKIPPED"
103148
continue
104149

105-
# Update the model in place
106-
log.info("Fitting trace model for %s", model.meta.filename)
107-
results = fit_and_oversample(
108-
model,
109-
fit_threshold=self.fit_threshold,
110-
slope_limit=self.slope_limit,
111-
oversample_factor=self.oversample,
112-
psf_optimal=self.psf_optimal,
113-
return_intermediate_models=self.save_intermediate_results,
114-
)
115-
116150
model.meta.cal_step.adaptive_trace_model = "COMPLETE"
117151
if self.save_intermediate_results:
118152
_, full_spline, used_spline, linear, residual = results

jwst/adaptive_trace_model/tests/helpers.py

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,16 @@
22
from stdatamodels.jwst import datamodels
33

44
from jwst.assign_wcs.assign_wcs_step import AssignWcsStep
5-
from jwst.assign_wcs.tests.test_miri import create_hdul
5+
from jwst.assign_wcs.tests.test_miri import (
6+
create_datamodel_cube,
7+
create_hdul,
8+
create_hdul_lrs_slitless,
9+
)
610
from jwst.assign_wcs.tests.test_nirspec import create_nirspec_fs_file, create_nirspec_ifu_file
711
from jwst.extract_2d.extract_2d_step import Extract2dStep
812

913
__all__ = [
14+
"miri_lrs_slitless_model_with_source",
1015
"miri_mrs_model",
1116
"miri_mrs_model_with_source",
1217
"nirspec_ifu_model",
@@ -16,6 +21,39 @@
1621
]
1722

1823

24+
def miri_lrs_slitless_model_with_source():
25+
"""
26+
Create a mock MIRI LRS slitless model with a simple spectral source in the data array.
27+
28+
Returns
29+
-------
30+
model : `~stdatamodels.jwst.datamodels.SlitModel`
31+
The LRS slitless datamodel.
32+
"""
33+
shape = (5, 416, 72)
34+
hdul = create_hdul_lrs_slitless()
35+
cube_model = create_datamodel_cube(hdul, shape)
36+
hdul.close()
37+
38+
model = datamodels.SlitModel(cube_model)
39+
cube_model.close()
40+
41+
model.data = np.full(shape, np.nan)
42+
model.err = np.zeros(shape)
43+
model.dq = np.zeros(shape, dtype=np.uint32)
44+
model.var_poisson = np.zeros(shape)
45+
model.var_rnoise = np.zeros(shape)
46+
47+
ysize, xsize = shape[-2:]
48+
x, y = np.meshgrid(np.arange(xsize), np.arange(ysize))
49+
_, _, lam = model.meta.wcs(x, y)
50+
51+
region_map = (~np.isnan(lam)).astype(int)
52+
_add_source(model, region_map, along_x=False)
53+
54+
return model
55+
56+
1957
def miri_mrs_model(detector="MIRIFUSHORT", channel="12", band="SHORT", shape=(1024, 1032)):
2058
"""
2159
Create a mock MIRI MRS model.
@@ -141,9 +179,9 @@ def nirspec_ifu_model_with_source(wcs_style="coordinates"):
141179
return model
142180

143181

144-
def nirspec_slit_model_with_source():
182+
def nirspec_slit_model():
145183
"""
146-
Create a mock NIRSpec FS model with a simple spectral source in the data array.
184+
Create a mock NIRSpec FS model with no source in the data array.
147185
148186
Calls assign_wcs and extract_2d.
149187
@@ -157,8 +195,8 @@ def nirspec_slit_model_with_source():
157195
hdul.close()
158196

159197
shape = (2048, 2048)
160-
model.data = np.full(shape, np.nan)
161-
model.err = np.zeros(shape)
198+
model.data = np.ones(shape)
199+
model.err = model.data * 0.01
162200
model.dq = np.zeros(shape, dtype=np.uint32)
163201
model.var_poisson = np.zeros(shape)
164202
model.var_rnoise = np.zeros(shape)
@@ -172,6 +210,23 @@ def nirspec_slit_model_with_source():
172210
slit.meta.bunit_data = "MJy"
173211
else:
174212
slit.meta.bunit_data = "MJy/sr"
213+
214+
return model
215+
216+
217+
def nirspec_slit_model_with_source():
218+
"""
219+
Create a mock NIRSpec FS model with a simple spectral source in the data array.
220+
221+
Calls assign_wcs and extract_2d.
222+
223+
Returns
224+
-------
225+
model : `~stdatamodels.jwst.datamodels.MultiSlitModel`
226+
The FS datamodel.
227+
"""
228+
model = nirspec_slit_model()
229+
for slit in model.slits:
175230
region_map = (~np.isnan(slit.wavelength)).astype(int)
176231
_add_source(slit, region_map)
177232

jwst/adaptive_trace_model/tests/test_adaptive_trace_model_step.py

Lines changed: 112 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,20 @@ def nirspec_ifu_slice_wcs():
3737
model.close()
3838

3939

40+
@pytest.fixture(scope="module")
41+
def nirspec_slit_model():
42+
model = helpers.nirspec_slit_model()
43+
yield model
44+
model.close()
45+
46+
47+
@pytest.fixture(scope="module")
48+
def nirspec_slit_model_with_source():
49+
model = helpers.nirspec_slit_model_with_source()
50+
yield model
51+
model.close()
52+
53+
4054
@pytest.fixture()
4155
def asn_input(tmp_path, miri_mrs_model):
4256
"""
@@ -74,6 +88,13 @@ def asn_input(tmp_path, miri_mrs_model):
7488
return asn
7589

7690

91+
@pytest.fixture(scope="module")
92+
def miri_lrs_slitless_model_with_source():
93+
model = helpers.miri_lrs_slitless_model_with_source()
94+
yield model
95+
model.close()
96+
97+
7798
@pytest.mark.parametrize("dataset", ["miri_mrs_model", "nirspec_ifu_model_with_source"])
7899
def test_adaptive_trace_model_step_success(request, dataset):
79100
model = request.getfixturevalue(dataset)
@@ -239,7 +260,7 @@ def test_adaptive_trace_model_step_with_container(miri_mrs_model):
239260
def test_adaptive_trace_model_unsupported_model(caplog):
240261
model = datamodels.ImageModel()
241262
result = AdaptiveTraceModelStep.call(model)
242-
assert "only implemented for IFU" in caplog.text
263+
assert "not implemented for <ImageModel>" in caplog.text
243264

244265
assert result is not model
245266
assert result.meta.cal_step.adaptive_trace_model == "SKIPPED"
@@ -290,9 +311,11 @@ def test_save_container_asn_id_missing(tmp_path, asn_input):
290311
assert output.exists()
291312

292313

314+
@pytest.mark.parametrize("dataset", ["miri_mrs_model", "nirspec_slit_model"])
293315
@pytest.mark.parametrize("oversample", [1.0, 2.0])
294-
def test_adaptive_trace_model_step_save_intermediate(tmp_path, miri_mrs_model, oversample):
295-
model = miri_mrs_model
316+
def test_adaptive_trace_model_step_save_intermediate(tmp_path, request, dataset, oversample):
317+
model = request.getfixturevalue(dataset).copy()
318+
model.meta.filename = "test_input_cal.fits"
296319
AdaptiveTraceModelStep.call(
297320
model,
298321
oversample=oversample,
@@ -304,20 +327,98 @@ def test_adaptive_trace_model_step_save_intermediate(tmp_path, miri_mrs_model, o
304327

305328
# Check for expected files
306329
expected = [
307-
"test12SHORT_atm.fits",
308-
"test12SHORT_spline_full.fits",
309-
"test12SHORT_spline_used.fits",
330+
"test_input_atm.fits",
331+
"test_input_spline_full.fits",
332+
"test_input_spline_used.fits",
310333
]
334+
# For a model with no source, the spline models are expected to be empty
311335
expect_empty = [False, True, True]
312336
if oversample > 1:
313337
# Extra files expected if oversampling is done
314-
expected.extend(["test12SHORT_linear_interp.fits", "test12SHORT_spline_residual.fits"])
338+
expected.extend(["test_input_linear_interp.fits", "test_input_spline_residual.fits"])
315339
expect_empty.extend([False, True])
316340
for filename, is_empty in zip(expected, expect_empty):
317341
assert (tmp_path / filename).exists()
318-
with datamodels.open(str(tmp_path / filename)) as model:
319-
assert isinstance(model, datamodels.IFUImageModel)
342+
with datamodels.open(str(tmp_path / filename)) as new_model:
343+
assert isinstance(new_model, type(model))
344+
if isinstance(new_model, datamodels.MultiSlitModel):
345+
data = new_model.slits[0].data
346+
else:
347+
data = new_model.data
320348
if is_empty:
321-
assert np.all(np.isnan(model.data))
349+
assert np.all(np.isnan(data))
322350
else:
323-
assert not np.all(np.isnan(model.data))
351+
assert not np.all(np.isnan(data))
352+
353+
354+
def test_adaptive_trace_model_step_tso(miri_lrs_slitless_model_with_source):
355+
model = miri_lrs_slitless_model_with_source
356+
result = AdaptiveTraceModelStep.call(model, oversample=1)
357+
assert result.meta.cal_step.adaptive_trace_model == "COMPLETE"
358+
359+
# data is unchanged with oversample=1
360+
np.testing.assert_equal(result.data, model.data)
361+
362+
# trace_model is attached, contains non-NaN trace from the median image
363+
assert result.trace_model.shape == result.data.shape[-2:]
364+
indx = ~np.isnan(result.data[0]) & ~np.isnan(result.trace_model)
365+
assert np.all(np.isnan(result.trace_model[~indx]))
366+
assert np.all(~np.isnan(result.trace_model[indx]))
367+
368+
# fit trace is a reasonable model of the slit but not perfect
369+
atol = 0.25 * np.nanmax(model.data)
370+
np.testing.assert_allclose(result.data[0, indx], result.trace_model[indx], atol=atol)
371+
372+
result.close()
373+
374+
375+
def test_adaptive_trace_model_step_tso_oversample(miri_lrs_slitless_model_with_source):
376+
model = miri_lrs_slitless_model_with_source
377+
with pytest.raises(ValueError, match="Oversampling is not supported for TSO data"):
378+
AdaptiveTraceModelStep.call(model, oversample=2)
379+
380+
381+
def test_adaptive_trace_model_step_oversample_slit(nirspec_slit_model_with_source):
382+
model = nirspec_slit_model_with_source
383+
result = AdaptiveTraceModelStep.call(model, oversample=2, slope_limit=0.05, fit_threshold=0.0)
384+
assert result.meta.cal_step.adaptive_trace_model == "COMPLETE"
385+
assert isinstance(result, datamodels.MultiSlitModel)
386+
387+
# data is twice the size of the input along the x axis
388+
extnames = ["data", "dq", "err", "var_poisson", "var_rnoise", "var_flat"]
389+
input_models = model.slits
390+
output_models = result.slits
391+
for input_model, output_model in zip(input_models, output_models, strict=True):
392+
for extname in extnames:
393+
# check for extension presence
394+
if not input_model.hasattr(extname):
395+
assert not output_model.hasattr(extname)
396+
continue
397+
398+
# check that shape is expected
399+
input_ext = getattr(input_model, extname)
400+
output_ext = getattr(output_model, extname)
401+
assert output_ext.shape == (input_ext.shape[0] * 2, input_ext.shape[1])
402+
403+
# trace_model is attached, contains non-NaN trace for the one bright slit only,
404+
# and only over the core of the source trace
405+
assert output_model.trace_model.shape == output_model.data.shape
406+
407+
indx = ~np.isnan(output_model.wavelength)
408+
assert np.all(np.isnan(output_model.trace_model[~indx]))
409+
assert np.sum(~np.isnan(output_model.trace_model[indx])) > 0.15 * np.sum(indx)
410+
411+
# fit trace is a reasonable model of the slice but not identical -
412+
# the slice is mostly linearly interpolated
413+
valid = indx & ~np.isnan(output_model.data) & ~np.isnan(output_model.trace_model)
414+
atol = 0.25 * np.nanmax(input_model.data)
415+
if output_model.meta.bunit_data == "MJy":
416+
# for flux density units, we need to account for flux conservation,
417+
# due to pixel size change
418+
factor = 2.0
419+
else:
420+
factor = 1.0
421+
np.testing.assert_allclose(
422+
output_model.data[valid] * factor, output_model.trace_model[valid], atol=atol
423+
)
424+
result.close()

0 commit comments

Comments
 (0)