Skip to content

Commit 94a16e2

Browse files
committed
Embed ASDF data into FITS extension
1 parent 48c2b88 commit 94a16e2

File tree

3 files changed

+78
-35
lines changed

3 files changed

+78
-35
lines changed

astrocut/asdf_cutout.py

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from astropy.utils.decorators import deprecated_renamed_argument
1818
from astropy.wcs import WCS
1919
from s3path import S3Path
20+
from stdatamodels import asdf_in_fits
2021

2122
from . import log, __version__
2223
from .image_cutout import ImageCutout
@@ -106,17 +107,21 @@ def fits_cutouts(self) -> List[fits.HDUList]:
106107
"""
107108
if not self._fits_cutouts:
108109
fits_cutouts = []
109-
for file, cutouts in self.cutouts_by_file.items():
110-
# TODO: Create a FITS object with ASDF extension
111-
# Create a primary FITS header to hold data and WCS
110+
for i, (file, cutouts) in enumerate(self.cutouts_by_file.items()):
112111
cutout = cutouts[0]
113-
primary_hdu = fits.PrimaryHDU(data=cutout.data, header=cutout.wcs.to_header(relax=True))
112+
if self._lite:
113+
tree = self._get_lite_tree(file, cutout, self._gwcs_objects[i])
114+
else:
115+
tree = self._asdf_trees[i]
114116

115-
# Add original file to header
116-
primary_hdu.header['ORIG_FLE'] = str(file)
117+
# Create a primary FITS header to hold data and WCS
118+
primary_hdu = fits.PrimaryHDU(data=cutout.data, header=cutout.wcs.to_header(relax=True))
119+
primary_hdu.header['ORIG_FLE'] = file # Add original file to header
120+
hdul = fits.HDUList([primary_hdu])
117121

118-
# Write to HDUList
119-
fits_cutouts.append(fits.HDUList([primary_hdu]))
122+
# Embed ASDF into FITS
123+
hdul_embed = asdf_in_fits.to_hdulist(tree, hdul)
124+
fits_cutouts.append(hdul_embed)
120125
self._fits_cutouts = fits_cutouts
121126
return self._fits_cutouts
122127

@@ -130,13 +135,7 @@ def asdf_cutouts(self) -> List[asdf.AsdfFile]:
130135
for i, (file, cutouts) in enumerate(self.cutouts_by_file.items()):
131136
cutout = cutouts[0]
132137
if self._lite:
133-
tree = {
134-
self._mission_kwd: {
135-
'meta': {'wcs': self._slice_gwcs(cutout, self._gwcs_objects[i]),
136-
'orig_file': str(file)},
137-
'data': cutout.data
138-
}
139-
}
138+
tree = self._get_lite_tree(file, cutout, self._gwcs_objects[i])
140139
else:
141140
tree = self._asdf_trees[i]
142141

@@ -156,6 +155,32 @@ def asdf_cutouts(self) -> List[asdf.AsdfFile]:
156155

157156
self._asdf_cutouts = asdf_cutouts
158157
return self._asdf_cutouts
158+
159+
def _get_lite_tree(self, file: str, cutout: Cutout2D, gwcs: gwcs.wcs.WCS) -> dict:
    """
    Helper function to create an ASDF tree in lite mode.

    Parameters
    ----------
    file : str
        The input filename. The value is passed through ``str()`` before it is
        stored, so a path-like object is recorded as its string form.
    cutout : `~astropy.nddata.Cutout2D`
        The cutout object.
    gwcs : gwcs.wcs.WCS
        The original GWCS object.

    Returns
    -------
    tree : dict
        The ASDF tree in lite mode. The tree contains only the cutout data and the sliced GWCS,
        plus the original filename, nested under the mission keyword.
    """
    # Record the filename as a plain string so the tree never carries a path
    # object; ``str()`` is a no-op when ``file`` is already a string.
    orig_file = str(file)

    return {
        self._mission_kwd: {
            'meta': {'wcs': self._slice_gwcs(cutout, gwcs),
                     'orig_file': orig_file},
            'data': cutout.data
        }
    }
159184

160185
def _get_cloud_http(self, input_file: Union[str, S3Path]) -> str:
161186
"""
@@ -208,7 +233,7 @@ def _load_file_data(self, input_file: Union[str, Path, S3Path]) -> dict:
208233

209234
return tree
210235

211-
def _make_cutout(self, array, position, wcs):
236+
def _make_cutout(self, array: np.ndarray, position: tuple, wcs: WCS) -> Cutout2D:
212237
"""
213238
Helper to generate a Cutout2D and return plain ndarray data.
214239

astrocut/tests/test_asdf_cutout.py

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from astropy.io import fits
1515
from gwcs import wcs, coordinate_frames
1616
from PIL import Image
17+
from stdatamodels import asdf_in_fits
1718

1819
from astrocut.asdf_cutout import ASDFCutout, asdf_cut, get_center_pixel
1920
from astrocut.exceptions import DataWarning, InvalidInputError, InvalidQueryError
@@ -158,24 +159,28 @@ def test_asdf_cutout(test_images, center_coord, cutout_size):
158159

159160

160161
def test_asdf_cutout_write_to_file(test_images, center_coord, cutout_size, tmpdir):
162+
def check_asdf_metadata(af, original_file, cutout_data):
    """Assert that ``af`` holds the expected arrays and metadata under the ``roman`` key."""
    assert 'roman' in af
    roman_node = af['roman']
    assert 'meta' in roman_node

    # Each array entry must be present and equal to the cutout data
    for array_name in ('data', 'dq', 'err', 'context'):
        assert array_name in roman_node
        assert np.all(roman_node[array_name] == cutout_data)

    # Check the WCS pixel shape, the product metadata, and the recorded original file path
    metadata = roman_node['meta']
    assert metadata['wcs'].pixel_shape == (10, 10)
    assert metadata['product_type'] == 'l2'
    assert metadata['file_date'] == Time('2023-10-01T00:00:00', format='isot')
    assert metadata['origin'] == 'STSCI/SOC'
    assert metadata['orig_file'] == original_file.as_posix()
176+
161177
# Write cutouts to ASDF files on disk
162178
cutout = ASDFCutout(test_images, center_coord, cutout_size)
163179
asdf_files = cutout.write_as_asdf(output_dir=tmpdir)
164180
assert len(asdf_files) == 3
165181
for i, asdf_file in enumerate(asdf_files):
166182
with asdf.open(asdf_file) as af:
167-
assert 'roman' in af
168-
assert 'meta' in af['roman']
169-
# Check cutout data and metadata
170-
for key in ['data', 'dq', 'err', 'context']:
171-
assert key in af['roman']
172-
assert np.all(af['roman'][key] == cutout.cutouts[i].data)
173-
meta = af['roman']['meta']
174-
assert meta['wcs'].pixel_shape == (10, 10)
175-
assert meta['product_type'] == 'l2'
176-
assert meta['file_date'] == Time('2023-10-01T00:00:00', format='isot')
177-
assert meta['origin'] == 'STSCI/SOC'
178-
assert meta['orig_file'] == test_images[i].as_posix()
183+
check_asdf_metadata(af, test_images[i], cutout.cutouts[i].data)
179184
# Check file size is smaller than original
180185
assert Path(asdf_file).stat().st_size < Path(test_images[i]).stat().st_size
181186

@@ -191,6 +196,10 @@ def test_asdf_cutout_write_to_file(test_images, center_coord, cutout_size, tmpdi
191196
assert hdul[0].header['ORIG_FLE'] == test_images[i].as_posix()
192197
assert Path(fits_file).stat().st_size < Path(test_images[i]).stat().st_size
193198

199+
# Check ASDF extension contents
200+
with asdf_in_fits.open(fits_file) as af:
201+
check_asdf_metadata(af, test_images[i], cutout.cutouts[i].data)
202+
194203

195204
@pytest.mark.parametrize('output_format', ['.asdf', '.fits'])
196205
def test_asdf_cutout_write_to_zip(tmpdir, test_images, center_coord, cutout_size, output_format):
@@ -215,7 +224,7 @@ def test_asdf_cutout_write_to_zip(tmpdir, test_images, center_coord, cutout_size
215224
else:
216225
with fits.open(io.BytesIO(data)) as hdul:
217226
assert isinstance(hdul, fits.HDUList)
218-
assert len(hdul) == 1
227+
assert len(hdul) == 2 # primary + embedded ASDF extension
219228
assert hdul[0].data.shape == (cutout_size, cutout_size)
220229

221230

@@ -226,10 +235,9 @@ def test_asdf_cutout_write_to_zip_invalid_format(tmpdir, test_images, center_coo
226235
cutout.write_as_zip(output_dir=tmpdir, output_format='.invalid')
227236

228237

229-
def test_asdf_cutout_lite(test_images, center_coord, cutout_size, tmpdir):
230-
# Write cutouts to ASDF objects in lite mode
231-
cutout = ASDFCutout(test_images, center_coord, cutout_size, lite=True)
232-
for af in cutout.asdf_cutouts:
238+
def test_asdf_cutout_lite(test_images, center_coord, cutout_size):
239+
def check_lite_metadata(af):
240+
"""Check that ASDF file contains only lite metadata"""
233241
assert 'roman' in af
234242
assert 'data' in af['roman']
235243
assert 'meta' in af['roman']
@@ -238,11 +246,21 @@ def test_asdf_cutout_lite(test_images, center_coord, cutout_size, tmpdir):
238246
assert len(af['roman']) == 2 # only data and meta
239247
assert len(af['roman']['meta']) == 2 # only wcs and original filename
240248

249+
# Write cutouts to ASDF objects in lite mode
250+
cutout = ASDFCutout(test_images, center_coord, cutout_size, lite=True)
251+
for af in cutout.asdf_cutouts:
252+
check_lite_metadata(af)
253+
241254
# Write cutouts to HDUList objects in lite mode
242255
cutout = ASDFCutout(test_images, center_coord, cutout_size, lite=True)
243256
for hdul in cutout.fits_cutouts:
244-
assert len(hdul) == 1 # primary HDU only
257+
assert len(hdul) == 2 # primary HDU + embedded ASDF extension
245258
assert hdul[0].name == 'PRIMARY'
259+
assert hdul[1].name == 'ASDF'
260+
261+
# Check ASDF extension contents
262+
with asdf_in_fits.open(hdul) as af:
263+
check_lite_metadata(af)
246264

247265

248266
def test_asdf_cutout_partial(test_images, center_coord, cutout_size):

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@ install_requires =
2222
fsspec[http]>=2022.8.2 # for remote cutouts
2323
s3fs>=2022.8.2 # for remote cutouts
2424
s3path>=0.5.7 # for remote file paths
25-
roman_datamodels>=0.19.0 # for roman file support
2625
requests>=2.32.3 # for making HTTP requests
2726
spherical_geometry>=1.3.0
2827
gwcs>=0.21.0
28+
stdatamodels>=4.1.0
2929
scipy
3030
Pillow
3131

0 commit comments

Comments
 (0)