Skip to content

Commit 872d8f9

Browse files
committed
Clean up tests, comments
1 parent 622a560 commit 872d8f9

File tree

3 files changed

+86
-625
lines changed

3 files changed

+86
-625
lines changed

astrocut/asdf_cutouts.py

Lines changed: 52 additions & 266 deletions
Original file line numberDiff line numberDiff line change
@@ -1,272 +1,59 @@
11
# Licensed under a 3-clause BSD style license - see LICENSE.rst
22

33
"""This module implements cutout functionality similar to fitscut, but for the ASDF file format."""
4-
import copy
5-
import pathlib
6-
from typing import Union, Tuple
7-
import requests
4+
from pathlib import Path
5+
from typing import List, Union
86

9-
import asdf
107
import astropy
118
import gwcs
129
import numpy as np
13-
import s3fs
10+
from astropy.utils.decorators import deprecated_renamed_argument
1411
from s3path import S3Path
1512

16-
from astropy.coordinates import SkyCoord
17-
from astropy.modeling import models
18-
19-
from . import log
20-
from .utils.utils import _handle_verbose
21-
22-
23-
def _get_cloud_http(s3_uri: Union[str, S3Path], key: str = None, secret: str = None,
24-
token: str = None, verbose: bool = False) -> str:
25-
"""
26-
Get the HTTP URI of a cloud resource from an S3 URI.
27-
28-
Parameters
29-
----------
30-
s3_uri : string | S3Path
31-
the S3 URI of the cloud resource
32-
key : string
33-
Default None. Access key ID for S3 file system.
34-
secret : string
35-
Default None. Secret access key for S3 file system.
36-
token : string
37-
Default None. Security token for S3 file system.
38-
verbose : bool
39-
Default False. If true intermediate information is printed.
40-
"""
41-
42-
# check if public or private by sending an HTTP request
43-
s3_path = S3Path.from_uri(s3_uri) if isinstance(s3_uri, str) else s3_uri
44-
url = f'https://{s3_path.bucket}.s3.amazonaws.com/{s3_path.key}'
45-
resp = requests.head(url, timeout=10)
46-
is_anon = False if resp.status_code == 403 else True
47-
if not is_anon:
48-
log.debug('Attempting to access private S3 bucket: %s', s3_path.bucket)
49-
50-
# create file system and get URL of file
51-
fs = s3fs.S3FileSystem(anon=is_anon, key=key, secret=secret, token=token)
52-
with fs.open(s3_uri, 'rb') as f:
53-
return f.url()
13+
from .ASDFCutout import ASDFCutout
5414

5515

5616
def get_center_pixel(gwcsobj: gwcs.wcs.WCS, ra: float, dec: float) -> tuple:
5717
"""
58-
Get the center pixel from a Roman 2D science image.
59-
60-
For an input RA, Dec sky coordinate, get the closest pixel location
61-
on the input Roman image.
18+
Get the closest pixel location on an input image for a given set of coordinates.
6219
6320
Parameters
6421
----------
6522
gwcsobj : gwcs.wcs.WCS
66-
The Roman GWCS object.
23+
The GWCS object.
6724
ra : float
68-
The input right ascension.
25+
The right ascension of the input coordinates.
6926
dec : float
70-
The input declination.
27+
The declination of the input coordinates.
7128
7229
Returns
7330
-------
74-
tuple
75-
The pixel position, FITS wcs object
31+
pixel_position
32+
The pixel position of the input coordinates.
33+
wcs_updated : `~astropy.wcs.WCS`
34+
The approximated FITS WCS object.
7635
"""
77-
78-
# Convert the gwcs object to an astropy FITS WCS header
79-
header = gwcsobj.to_fits_sip()
80-
81-
# Update WCS header with some keywords that it's missing.
82-
# Otherwise, it won't work with astropy.wcs tools (TODO: Figure out why. What are these keywords for?)
83-
for k in ['cpdis1', 'cpdis2', 'det2im1', 'det2im2', 'sip']:
84-
if k not in header:
85-
header[k] = 'na'
86-
87-
# New WCS object with updated header
88-
wcs_updated = astropy.wcs.WCS(header)
89-
90-
# Turn input RA, Dec into a SkyCoord object
91-
coordinates = SkyCoord(ra, dec, unit='deg')
92-
93-
# Map the coordinates to a pixel's location on the Roman 2d array (row, col)
94-
row, col = gwcsobj.invert(coordinates)
95-
96-
return (row, col), wcs_updated
97-
98-
99-
def _get_cutout(data: asdf.tags.core.ndarray.NDArrayType, coords: Union[tuple, SkyCoord],
100-
wcs: astropy.wcs.wcs.WCS = None, size: int = 20, outfile: str = "example_roman_cutout.fits",
101-
write_file: bool = True, fill_value: Union[int, float] = np.nan,
102-
gwcsobj: gwcs.wcs.WCS = None) -> astropy.nddata.Cutout2D:
103-
"""
104-
Get a Roman image cutout.
105-
106-
Cut out a square section from the input image data array. The ``coords`` can either be a tuple of x, y
107-
pixel coordinates or an astropy SkyCoord object, in which case, a wcs is required. Writes out a
108-
new output file containing the image cutout of the specified ``size``. Default is 20 pixels.
109-
110-
Parameters
111-
----------
112-
data : asdf.tags.core.ndarray.NDArrayType
113-
the input Roman image data array
114-
coords : Union[tuple, SkyCoord]
115-
the input pixel or sky coordinates
116-
wcs : astropy.wcs.wcs.WCS, Optional
117-
the astropy FITS wcs object
118-
size : int, optional
119-
the image cutout pizel size, by default 20
120-
outfile : str, optional
121-
the name of the output cutout file, by default "example_roman_cutout.fits"
122-
write_file : bool, by default True
123-
Flag to write the cutout to a file or not
124-
fill_value: int | float, by default np.nan
125-
The fill value for pixels outside the original image.
126-
gwcsobj : gwcs.wcs.WCS, Optional
127-
the original gwcs object for the full image, needed only when writing cutout as asdf file
128-
129-
Returns
130-
-------
131-
astropy.nddata.Cutout2D:
132-
an image cutout object
133-
134-
Raises
135-
------
136-
ValueError:
137-
when a wcs is not present when coords is a SkyCoord object
138-
RuntimeError:
139-
when the requested cutout does not overlap with the original image
140-
ValueError:
141-
when no gwcs object is provided when writing to an asdf file
36+
return ASDFCutout.get_center_pixel(gwcsobj, ra, dec)
37+
38+
39+
@deprecated_renamed_argument('output_file', None, '1.0.0', warning_type=DeprecationWarning,
40+
message='`output_file` is non-operational and will be removed in a future version.')
41+
def asdf_cut(input_files: List[Union[str, Path, S3Path]],
42+
ra: float,
43+
dec: float,
44+
cutout_size: int = 25,
45+
output_file: Union[str, Path] = "example_roman_cutout.fits",
46+
write_file: bool = True,
47+
fill_value: Union[int, float] = np.nan,
48+
output_dir: Union[str, Path] = '.',
49+
output_format: str = '.asdf',
50+
key: str = None,
51+
secret: str = None,
52+
token: str = None,
53+
verbose: bool = False) -> astropy.nddata.Cutout2D:
14254
"""
143-
144-
# check for correct inputs
145-
if isinstance(coords, SkyCoord) and not wcs:
146-
raise ValueError('wcs must be input if coords is a SkyCoord.')
147-
148-
# create the cutout
149-
try:
150-
cutout = astropy.nddata.Cutout2D(data, position=coords, wcs=wcs, size=(size, size), mode='partial',
151-
fill_value=fill_value)
152-
except astropy.nddata.utils.NoOverlapError as e:
153-
raise RuntimeError('Could not create 2d cutout. The requested cutout does not overlap with the '
154-
'original image.') from e
155-
156-
# check if the data is a quantity and get the array data
157-
if isinstance(cutout.data, astropy.units.Quantity):
158-
data = cutout.data.value
159-
else:
160-
data = cutout.data
161-
162-
# write the cutout to the output file
163-
if write_file:
164-
# check the output file type
165-
out = pathlib.Path(outfile)
166-
write_as = out.suffix or '.fits'
167-
outfile = outfile if out.suffix else str(out) + write_as
168-
169-
# write out the file
170-
if write_as == '.fits':
171-
_write_fits(cutout, outfile)
172-
elif write_as == '.asdf':
173-
if not gwcsobj:
174-
raise ValueError('The original gwcs object is needed when writing to asdf file.')
175-
_write_asdf(cutout, gwcsobj, outfile)
176-
177-
return cutout
178-
179-
180-
def _write_fits(cutout: astropy.nddata.Cutout2D, outfile: str = "example_roman_cutout.fits"):
181-
"""
182-
Write cutout as FITS file.
183-
184-
Parameters
185-
----------
186-
cutout : astropy.nddata.Cutout2D
187-
the 2d cutout
188-
outfile : str, optional
189-
the name of the output cutout file, by default "example_roman_cutout.fits"
190-
"""
191-
# check if the data is a quantity and get the array data
192-
if isinstance(cutout.data, astropy.units.Quantity):
193-
data = cutout.data.value
194-
else:
195-
data = cutout.data
196-
197-
astropy.io.fits.writeto(outfile, data=data, header=cutout.wcs.to_header(relax=True), overwrite=True)
198-
199-
200-
def _slice_gwcs(gwcsobj: gwcs.wcs.WCS, slices: Tuple[slice, slice]) -> gwcs.wcs.WCS:
201-
"""
202-
Slice the original gwcs object.
203-
204-
"Slices" the original gwcs object down to the cutout shape. This is a hack
205-
until proper gwcs slicing is in place a la fits WCS slicing. The ``slices``
206-
keyword input is a tuple with the x, y cutout boundaries in the original image
207-
array, e.g. ``cutout.slices_original``. Astropy Cutout2D slices are in the form
208-
((ymin, ymax, None), (xmin, xmax, None))
209-
210-
Parameters
211-
----------
212-
gwcsobj : gwcs.wcs.WCS
213-
the original gwcs from the input image
214-
slices : Tuple[slice, slice]
215-
the cutout x, y slices as ((ymin, ymax), (xmin, xmax))
216-
217-
Returns
218-
-------
219-
gwcs.wcs.WCS
220-
The sliced gwcs object
221-
"""
222-
tmp = copy.deepcopy(gwcsobj)
223-
224-
# get the cutout array bounds and create a new shift transform to the cutout
225-
# add the new transform to the gwcs
226-
xmin, xmax = slices[1].start, slices[1].stop
227-
ymin, ymax = slices[0].start, slices[0].stop
228-
shape = (ymax - ymin, xmax - xmin)
229-
offsets = models.Shift(xmin, name='cutout_offset1') & models.Shift(ymin, name='cutout_offset2')
230-
tmp.insert_transform('detector', offsets, after=True)
231-
232-
# modify the gwcs bounding box to the cutout shape
233-
tmp.bounding_box = ((0, shape[0] - 1), (0, shape[1] - 1))
234-
tmp.pixel_shape = shape[::-1]
235-
tmp.array_shape = shape
236-
return tmp
237-
238-
239-
def _write_asdf(cutout: astropy.nddata.Cutout2D, gwcsobj: gwcs.wcs.WCS, outfile: str = "example_roman_cutout.asdf"):
240-
"""
241-
Write cutout as ASDF file.
242-
243-
Parameters
244-
----------
245-
cutout : astropy.nddata.Cutout2D
246-
the 2d cutout
247-
gwcsobj : gwcs.wcs.WCS
248-
the original gwcs object for the full image
249-
outfile : str, optional
250-
the name of the output cutout file, by default "example_roman_cutout.asdf"
251-
"""
252-
# slice the origial gwcs to the cutout
253-
sliced_gwcs = _slice_gwcs(gwcsobj, cutout.slices_original)
254-
255-
# create the asdf tree
256-
tree = {'roman': {'meta': {'wcs': sliced_gwcs}, 'data': cutout.data}}
257-
af = asdf.AsdfFile(tree)
258-
259-
# Write the data to a new file
260-
af.write_to(outfile)
261-
262-
263-
def asdf_cut(input_file: Union[str, pathlib.Path, S3Path], ra: float, dec: float, cutout_size: int = 25,
264-
output_file: Union[str, pathlib.Path] = "example_roman_cutout.fits",
265-
write_file: bool = True, fill_value: Union[int, float] = np.nan, key: str = None,
266-
secret: str = None, token: str = None, verbose: bool = False) -> astropy.nddata.Cutout2D:
267-
"""
268-
Takes a single ASDF input file (`input_file`) and generates a cutout of designated size `cutout_size`
269-
around the given coordinates (`coordinates`).
55+
Takes one of more ASDF input files (`input_files`) and generates a cutout of designated size `cutout_size`
56+
around the given coordinates (`coordinates`). The cutout is written to a file or returned as an object.
27057
27158
Parameters
27259
----------
@@ -283,10 +70,17 @@ def asdf_cut(input_file: Union[str, pathlib.Path, S3Path], ra: float, dec: float
28370
pixel grid.
28471
output_file : str | Path
28572
Optional, default "example_roman_cutout.fits". The name of the output cutout file.
73+
This parameter is deprecated and will be removed in a future version.
28674
write_file : bool
28775
Optional, default True. Flag to write the cutout to a file or not.
28876
fill_value: int | float
28977
Optional, default `np.nan`. The fill value for pixels outside the original image.
78+
output_dir : str | Path
79+
Optional, default ".". The directory to write the cutout file(s) to.
80+
output_format : str
81+
Optional, default ".asdf". The format of the output cutout file. If `write_file` is False,
82+
then cutouts will be returned as `asdf.AsdfFile` objects if `output_format` is ".asdf" or
83+
as `astropy.io.fits.HDUList` objects if `output_format` is ".fits".
29084
key : string
29185
Default None. Access key ID for S3 file system. Only applicable if `input_file` is a
29286
cloud resource.
@@ -301,25 +95,17 @@ def asdf_cut(input_file: Union[str, pathlib.Path, S3Path], ra: float, dec: float
30195
30296
Returns
30397
-------
304-
astropy.nddata.Cutout2D:
305-
An image cutout object.
98+
response : str | list
99+
A list of cutout file paths if `write_file` is True, otherwise a list of cutout objects.
306100
"""
307-
# Log messages based on verbosity
308-
_handle_verbose(verbose)
309-
310-
# if file comes from AWS cloud bucket, get HTTP URL to open with asdf
311-
file = input_file
312-
if (isinstance(input_file, str) and input_file.startswith('s3://')) or isinstance(input_file, S3Path):
313-
file = _get_cloud_http(input_file, key, secret, token, verbose)
314-
315-
# get the 2d image data
316-
with asdf.open(file) as f:
317-
data = f['roman']['data']
318-
gwcsobj = f['roman']['meta']['wcs']
319-
320-
# get the center pixel
321-
pixel_coordinates, wcs = get_center_pixel(gwcsobj, ra, dec)
322-
323-
# create the 2d image cutout
324-
return _get_cutout(data, pixel_coordinates, wcs, size=cutout_size, outfile=output_file,
325-
write_file=write_file, fill_value=fill_value, gwcsobj=gwcsobj)
101+
return ASDFCutout(input_files=input_files,
102+
coordinates=(ra, dec),
103+
cutout_size=cutout_size,
104+
fill_value=fill_value,
105+
memory_only=not write_file,
106+
output_dir=output_dir,
107+
output_format=output_format,
108+
key=key,
109+
secret=secret,
110+
token=token,
111+
verbose=verbose).cutout()

0 commit comments

Comments
 (0)