Skip to content

Commit bb65ecd

Browse files
committed
Generalization classes for FITS cutouts
1 parent 7b204f0 commit bb65ecd

File tree

4 files changed

+1132
-475
lines changed

4 files changed

+1132
-475
lines changed

astrocut/Cutout.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
from abc import abstractmethod, ABC
2+
from pathlib import Path
3+
from typing import List, Union, Tuple
4+
5+
from astropy import wcs
6+
import astropy.units as u
7+
from s3path import S3Path
8+
from astropy.coordinates import SkyCoord
9+
import numpy as np
10+
11+
from astrocut.exceptions import InvalidInputError, InvalidQueryError
12+
13+
from . import log
14+
from .utils.utils import _handle_verbose, parse_size_input
15+
16+
17+
class Cutout(ABC):
18+
"""
19+
Abstract class for creating cutouts. This class defines attributes and methods that are common to all
20+
cutout classes.
21+
22+
Attributes
23+
----------
24+
input_files : list
25+
List of input image files.
26+
coordinates : str | `~astropy.coordinates.SkyCoord`
27+
Coordinates of the center of the cutout.
28+
cutout_size : int | array | list | tuple | `~astropy.units.Quantity`
29+
Size of the cutout array.
30+
fill_value : int | float
31+
Value to fill the cutout with if the cutout is outside the image.
32+
memory_only : bool
33+
If True, the cutout is written to memory instead of disk.
34+
output_dir : str | Path
35+
Directory to write the cutout file(s) to.
36+
limit_rounding_method : str
37+
Method to use for rounding the cutout limits. Options are 'round', 'ceil', and 'floor'.
38+
verbose : bool
39+
If True, log messages are printed to the console.
40+
41+
Methods
42+
-------
43+
get_cutout_limits(img_wcs)
44+
Returns the x and y pixel limits for the cutout.
45+
cutout()
46+
Generate the cutouts.
47+
"""
48+
49+
def __init__(self, input_files: List[Union[str, Path, S3Path]], coordinates: Union[SkyCoord, str],
50+
cutout_size: Union[int, np.ndarray, u.Quantity, List[int], Tuple[int]] = 25,
51+
fill_value: Union[int, float] = np.nan, memory_only: bool = False,
52+
output_dir: Union[str, Path] = '.', limit_rounding_method: str = 'round', verbose: bool = True):
53+
54+
# Log messages according to verbosity
55+
_handle_verbose(verbose)
56+
57+
# Ensure that input files are in a list
58+
if isinstance(input_files, str) or isinstance(input_files, Path):
59+
input_files = [input_files]
60+
self._input_files = input_files
61+
62+
# Get coordinates as a SkyCoord object
63+
if coordinates and not isinstance(coordinates, SkyCoord):
64+
coordinates = SkyCoord(coordinates, unit='deg')
65+
self._coordinates = coordinates
66+
log.debug('Coordinates: %s', self._coordinates)
67+
68+
# Turning the cutout size into an array of two values
69+
self._cutout_size = parse_size_input(cutout_size)
70+
log.debug('Cutout size: %s', self._cutout_size)
71+
72+
# Assigning other attributes
73+
valid_rounding = ['round', 'ceil', 'floor']
74+
if not isinstance(limit_rounding_method, str) or limit_rounding_method.lower() not in valid_rounding:
75+
raise InvalidInputError(f'Limit rounding method {limit_rounding_method} is not recognized. '
76+
'Valid options are {valid_rounding}.')
77+
self._limit_rounding_method = limit_rounding_method
78+
self._fill_value = fill_value
79+
self._memory_only = memory_only
80+
self._output_dir = output_dir
81+
self._verbose = verbose
82+
83+
def _get_cutout_limits(self, img_wcs: wcs.WCS) -> np.ndarray:
84+
"""
85+
Returns the x and y pixel limits for the cutout.
86+
87+
Note: This function does no bounds checking, so the returned limits are not
88+
guaranteed to overlap the original image.
89+
90+
Parameters
91+
----------
92+
img_wcs : `~astropy.wcs.WCS`
93+
The WCS for the image that the cutout is being cut from.
94+
95+
Returns
96+
-------
97+
response : `numpy.array`
98+
The cutout pixel limits in an array of the form [[xmin,xmax],[ymin,ymax]]
99+
"""
100+
# Calculate pixel corresponding to coordinate
101+
try:
102+
center_pixel = self._coordinates.to_pixel(img_wcs)
103+
except wcs.NoConvergence: # If wcs can't converge, center coordinate is far from the footprint
104+
raise InvalidQueryError("Cutout location is not in image footprint!")
105+
106+
# Sometimes, we may get nans without a NoConvergence error
107+
if np.isnan(center_pixel).any():
108+
raise InvalidQueryError("Cutout location is not in image footprint!")
109+
110+
lims = np.zeros((2, 2), dtype=int)
111+
for axis, size in enumerate(self._cutout_size):
112+
113+
if not isinstance(size, u.Quantity): # assume pixels
114+
dim = size / 2
115+
elif size.unit == u.pixel: # also pixels
116+
dim = size.value / 2
117+
elif size.unit.physical_type == 'angle': # angular size
118+
pixel_scale = u.Quantity(wcs.utils.proj_plane_pixel_scales(img_wcs)[axis],
119+
img_wcs.wcs.cunit[axis])
120+
dim = (size / pixel_scale).decompose() / 2
121+
else:
122+
raise InvalidInputError(f'Cutout size units {size.unit} are not supported.')
123+
124+
# Round the limits according to the requested method
125+
rounding_funcs = {
126+
'round': np.round,
127+
'ceil': np.ceil,
128+
'floor': np.floor
129+
}
130+
round_func = rounding_funcs[self._limit_rounding_method]
131+
132+
lims[axis, 0] = int(round_func(center_pixel[axis] - dim))
133+
lims[axis, 1] = int(round_func(center_pixel[axis] + dim))
134+
135+
# The case where the requested area is so small it rounds to zero
136+
if self._limit_rounding_method == 'round' and lims[axis, 0] == lims[axis, 1]:
137+
lims[axis, 0] = int(np.floor(center_pixel[axis] - 1))
138+
lims[axis, 1] = lims[axis, 0] + 1
139+
140+
return lims
141+
142+
@abstractmethod
143+
def cutout(self):
144+
"""
145+
Generate the cutout(s).
146+
147+
This method is abstract and should be defined in subclasses.
148+
"""
149+
pass

0 commit comments

Comments
 (0)