diff --git a/astrocut/ASDFCutout.py b/astrocut/ASDFCutout.py new file mode 100644 index 00000000..27c4620f --- /dev/null +++ b/astrocut/ASDFCutout.py @@ -0,0 +1,468 @@ +import copy +from pathlib import Path +from time import monotonic +from typing import List, Tuple, Union, Optional +import warnings + +import asdf +import gwcs +import numpy as np +import requests +import s3fs +from astropy.coordinates import SkyCoord +from astropy.io import fits +from astropy.modeling import models +from astropy.nddata.utils import Cutout2D, NoOverlapError +from astropy.units import Quantity +from astropy.wcs import WCS +from s3path import S3Path + +from . import log +from .ImageCutout import ImageCutout +from .exceptions import DataWarning, InvalidInputError + + +class ASDFCutout(ImageCutout): + """ + Class for creating cutouts from ASDF files. + + Args + ---- + input_files : list + List of input image files. + coordinates : str | `~astropy.coordinates.SkyCoord` + Coordinates of the center of the cutout. + cutout_size : int | array | list | tuple | `~astropy.units.Quantity` + Size of the cutout array. + fill_value : int | float + Value to fill the cutout with if the cutout is outside the image. + limit_rounding_method : str + Method to use for rounding the cutout limits. Options are 'round', 'ceil', and 'floor'. + key : str + Optional, default None. Access key ID for S3 file system. + secret : str + Optional, default None. Secret access key for S3 file system. + token : str + Optional, default None. Security token for S3 file system. + verbose : bool + If True, log messages are printed to the console. + + Attributes + ---------- + cutouts : list + The cutouts as a list of `astropy.nddata.Cutout2D` objects. + cutouts_by_file : dict + The cutouts as `astropy.nddata.Cutout2D` objects stored by input filename. + fits_cutouts : list + The cutouts as a list `astropy.io.fits.HDUList` objects. + asdf_cutouts : list + The cutouts as a list of `asdf.AsdfFile` objects. 
+ + Methods + ------- + _get_cloud_http(input_file) + Get the HTTP URL of a cloud resource from an S3 URI. + _load_file_data(input_file) + Load the data from an input file. + _get_cutout_data(data, wcs, pixel_coords) + Get the cutout data from the input image. + _slice_gwcs(cutout, gwcs) + Slice the original gwcs object to fit the cutout. + _cutout_file(file) + Create a cutout from an input file. + cutout() + Generate cutouts from a list of input images. + _write_as_format(output_format, output_dir) + Write the cutout to disk or memory in the specified format. + write_as_fits(output_dir) + Write the cutouts to disk or memory in FITS format. + write_as_asdf(output_dir) + Write the cutouts to disk or memory in ASDF format. + get_center_pixel(gwcsobj, ra, dec) + Get the closest pixel location on an input image for a given set of coordinates. + """ + + def __init__(self, input_files: List[Union[str, Path, S3Path]], coordinates: Union[SkyCoord, str], + cutout_size: Union[int, np.ndarray, Quantity, List[int], Tuple[int]] = 25, + fill_value: Union[int, float] = np.nan, limit_rounding_method: str = 'round', + key: Optional[str] = None, secret: Optional[str] = None, + token: Optional[str] = None, verbose: bool = False): + # Superclass constructor + super().__init__(input_files, coordinates, cutout_size, fill_value, limit_rounding_method, verbose=verbose) + + # Assign AWS credential attributes + self._key = key + self._secret = secret + self._token = token + self._mission_kwd = 'roman' + + self.cutouts = [] # Public attribute to hold `Cutout2D` objects + self._asdf_cutouts = None # Store ASDF objects + self._fits_cutouts = None # Store FITS objects + self._gwcs_objects = [] # Store original GWCS objects + + # Make cutouts + self.cutout() + + @property + def fits_cutouts(self) -> List[fits.HDUList]: + """ + Return the cutouts as a list `astropy.io.fits.HDUList` objects. 
+ """ + if not self._fits_cutouts: + fits_cutouts = [] + for cutout in self.cutouts: + # TODO: Create a FITS object with ASDF extension + # Create a primary FITS header to hold data and WCS + primary_hdu = fits.PrimaryHDU(data=cutout.data, header=cutout.wcs.to_header(relax=True)) + + # Write to HDUList + fits_cutouts.append(fits.HDUList([primary_hdu])) + self._fits_cutouts = fits_cutouts + return self._fits_cutouts + + @property + def asdf_cutouts(self) -> List[asdf.AsdfFile]: + """ + Return the cutouts as a list of `asdf.AsdfFile` objects. + """ + if not self._asdf_cutouts: + asdf_cutouts = [] + for i, cutout in enumerate(self.cutouts): + # Slice the original gwcs to the cutout + sliced_gwcs = self._slice_gwcs(cutout, self._gwcs_objects[i]) + + # Create the asdf tree + tree = {self._mission_kwd: {'meta': {'wcs': sliced_gwcs}, 'data': cutout.data}} + asdf_cutouts.append(asdf.AsdfFile(tree)) + self._asdf_cutouts = asdf_cutouts + return self._asdf_cutouts + + def _get_cloud_http(self, input_file: Union[str, S3Path]) -> str: + """ + Get the HTTP URL of a cloud resource from an S3 URI. + + Parameters + ---------- + input_file : str | S3Path + The input file S3 URI. + + Returns + ------- + str + The HTTP URL of the cloud resource. + """ + # Check if public or private by sending an HTTP request + s3_path = S3Path.from_uri(input_file) if isinstance(input_file, str) else input_file + url = f'https://{s3_path.bucket}.s3.amazonaws.com/{s3_path.key}' + resp = requests.head(url, timeout=10) + is_anon = False if resp.status_code == 403 else True + if not is_anon: + log.debug('Attempting to access private S3 bucket: %s', s3_path.bucket) + + # Create file system and get URL of file + fs = s3fs.S3FileSystem(anon=is_anon, key=self._key, secret=self._secret, token=self._token) + with fs.open(input_file, 'rb') as f: + return f.url() + + def _load_file_data(self, input_file: Union[str, Path, S3Path]) -> Tuple[np.ndarray, gwcs.wcs.WCS]: + """ + Load relevant data from an input file. 
+ + Parameters + ---------- + input_file : str | Path | S3Path + The input file to load data from. + + Returns + ------- + data : array + The image data. + gwcs : `~gwcs.wcs.WCS` + The GWCS of the image. + """ + # If file comes from AWS cloud bucket, get HTTP URL to open with asdf + if (isinstance(input_file, str) and input_file.startswith('s3://')) or isinstance(input_file, S3Path): + input_file = self._get_cloud_http(input_file) + + # Get data and GWCS object from ASDF input file + with asdf.open(input_file) as af: + data = af[self._mission_kwd]['data'] + gwcs = af[self._mission_kwd]['meta'].get('wcs', None) + + return (data, gwcs) + + def _get_cutout_data(self, data: np.ndarray, wcs: WCS, pixel_coords: Tuple[int, int]) -> Cutout2D: + """ + Get the cutout data from the input image. + + Parameters + ---------- + data : array + The input image data. + wcs : `~astropy.wcs.WCS` + The approximated WCS of the input image. + pixel_coords : tuple + The pixel coordinates closest to the center of the cutout. + + Returns + ------- + img_cutout : `~astropy.nddata.Cutout2D` + The cutout object. + """ + log.debug('Original image shape: %s', data.shape) + + # Using `~astropy.nddata.Cutout2D` to get the cutout data and handle WCS + # Passing in pixel coordinates that were calculated using the original GWCS object, + # so the approximate WCS object will not be used to calculate the pixel coordinates + # of the cutout center. Approximate WCS will be used in calculation of cutout bounds + # if cutout size is given in angular units. + img_cutout = Cutout2D(data, + position=pixel_coords, + wcs=wcs, + size=(self._cutout_size[1], self._cutout_size[0]), + mode='partial', + fill_value=self._fill_value) + + log.debug('Image cutout shape: %s', img_cutout.shape) + + return img_cutout + + def _slice_gwcs(self, cutout: Cutout2D, gwcs: gwcs.wcs.WCS) -> gwcs.wcs.WCS: + """ + Slice the original gwcs object. + + "Slices" the original gwcs object down to the cutout shape. 
This is a hack + until proper gwcs slicing is in place a la fits WCS slicing. The ``slices`` + keyword input is a tuple with the x, y cutout boundaries in the original image + array, e.g. ``cutout.slices_original``. Astropy Cutout2D slices are in the form + ((ymin, ymax, None), (xmin, xmax, None)) + + Parameters + ---------- + cutout : astropy.nddata.Cutout2D + The cutout object. + gwcs : gwcs.wcs.WCS + The original GWCS from the input image. + + Returns + ------- + gwcs.wcs.WCS + The sliced GWCS object. + """ + # Create copy of original gwcs object + tmp = copy.deepcopy(gwcs) + + # Get the cutout array bounds and create a new shift transform to the cutout + # Add the new transform to the gwcs + slices = cutout.slices_original + xmin, xmax = slices[1].start, slices[1].stop + ymin, ymax = slices[0].start, slices[0].stop + shape = (ymax - ymin, xmax - xmin) + offsets = models.Shift(xmin, name='cutout_offset1') & models.Shift(ymin, name='cutout_offset2') + tmp.insert_transform('detector', offsets, after=True) + + # Modify the gwcs bounding box to the cutout shape + tmp.bounding_box = ((0, shape[0] - 1), (0, shape[1] - 1)) + tmp.pixel_shape = shape[::-1] + tmp.array_shape = shape + return tmp + + def _cutout_file(self, file: Union[str, Path, S3Path]): + """ + Create a cutout from a single input file. + + Parameters + ---------- + file : str | Path | S3Path + The input file to create a cutout from. + """ + # Load the data from the input file + data, gwcs = self._load_file_data(file) + + # Skip if the file does not contain a GWCS object + if gwcs is None: + warnings.warn(f'File {file} does not contain a GWCS object. 
Skipping...', + DataWarning) + return + + # Get closest pixel coordinates and approximated WCS + pixel_coords, wcs = self.get_center_pixel(gwcs, self._coordinates.ra.value, self._coordinates.dec.value) + + # Create the cutout + try: + cutout2D = self._get_cutout_data(data, wcs, pixel_coords) + except NoOverlapError: + warnings.warn(f'Cutout footprint does not overlap with data in {file}, skipping...', + DataWarning) + return + + # Check that there is data in the cutout image + if (cutout2D.data == 0).all() or (np.isnan(cutout2D.data)).all(): + warnings.warn(f'Cutout of {file} contains no data, skipping...', + DataWarning) + return + + # Convert Quantity data to ndarray + if isinstance(cutout2D.data, Quantity): + cutout2D.data = cutout2D.data.value + + # Store the Cutout2D object + self.cutouts.append(cutout2D) + + # Store the original GWCS to use if creating asdf.AsdfFile objects + self._gwcs_objects.append(gwcs) + + # Store cutout with filename + self.cutouts_by_file[file] = [cutout2D] + + def cutout(self) -> Union[str, List[str], List[fits.HDUList]]: + """ + Generate cutouts from a list of input images. + + Returns + ------- + cutout_path : Path | list + Cutouts as memory objects or path(s) to the written cutout files. + + Raises + ------ + InvalidInputError + If no cutouts contain data. + """ + # Track start time + start_time = monotonic() + + # Cutout each input file + for file in self._input_files: + self._cutout_file(file) + + # If no cutouts contain data, raise exception + if not self.cutouts: + raise InvalidInputError('Cutout contains no data! (Check image footprint.)') + + # Log total time elapsed + log.debug('Total time: %.2f sec', monotonic() - start_time) + + return self.cutouts + + def _write_as_format(self, output_format: str, output_dir: Union[str, Path] = '.') -> List[str]: + """ + Write the cutout to disk in the specified output format. + + Parameters + ---------- + output_format : str + The output format to write the cutout to. 
Options are '.fits' and '.asdf'. + output_dir : str | Path + The output directory to write the cutouts to + + Returns + ------- + cutout_paths : list + The path(s) to the cutout file(s) or the cutout memory objects. + """ + Path(output_dir).mkdir(parents=True, exist_ok=True) + cutout_paths = [] # List to store paths to cutout files + for i, file in enumerate(self.cutouts_by_file): + # Determine the output path + filename = '{}_{:.7f}_{:.7f}_{}-x-{}_astrocut{}'.format( + Path(file).stem, + self._coordinates.ra.value, + self._coordinates.dec.value, + str(self._cutout_size[0]).replace(' ', ''), + str(self._cutout_size[1]).replace(' ', ''), + output_format) + cutout_path = Path(output_dir, filename) + + if output_format == '.fits': + cutout = self.fits_cutouts[i] + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + cutout.writeto(cutout_path, overwrite=True, checksum=True) + + elif output_format == '.asdf': + cutout = self.asdf_cutouts[i] + cutout.write_to(cutout_path) + + cutout_paths.append(cutout_path.as_posix()) + + log.debug('Cutout filepaths: {}'.format(cutout_paths)) + return cutout_paths + + def write_as_fits(self, output_dir: Union[str, Path] = '.') -> List[str]: + """ + Write the cutouts to disk or memory in FITS format. + + Parameters + ---------- + output_dir : str | Path + The output directory to write the cutouts to. Defaults to the current directory. + + Returns + ------- + list + A list of paths to the cutout FITS files. + """ + return self._write_as_format(output_format='.fits', output_dir=output_dir) + + def write_as_asdf(self, output_dir: Union[str, Path] = '.') -> List[str]: + """ + Write the cutouts to disk or memory in ASDF format. + + Parameters + ---------- + output_dir : str | Path + The output directory to write the cutouts to. Defaults to the current directory. + + Returns + ------- + list + A list of paths to the cutout ASDF files. 
+ """ + return self._write_as_format(output_format='.asdf', output_dir=output_dir) + + @staticmethod + def get_center_pixel(gwcsobj: gwcs.wcs.WCS, ra: float, dec: float) -> Tuple[Tuple[int, int], WCS]: + """ + Get the closest pixel location on an input image for a given set of coordinates. + + Parameters + ---------- + gwcsobj : gwcs.wcs.WCS + The GWCS object. + ra : float + The right ascension of the input coordinates. + dec : float + The declination of the input coordinates. + + Returns + ------- + pixel_position + The pixel position of the input coordinates. + wcs_updated : `~astropy.wcs.WCS` + The approximated FITS WCS object. + """ + + # Convert the gwcs object to an astropy FITS WCS header + header = gwcsobj.to_fits_sip() + + # Update WCS header with some keywords that it's missing. + # Otherwise, it won't work with astropy.wcs tools (TODO: Figure out why. What are these keywords for?) + for k in ['cpdis1', 'cpdis2', 'det2im1', 'det2im2', 'sip']: + if k not in header: + header[k] = 'na' + + # New WCS object with updated header + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + wcs_updated = WCS(header) + + # Turn input RA, Dec into a SkyCoord object + coordinates = SkyCoord(ra, dec, unit='deg') + + # Map the coordinates to a pixel's location on the Roman 2d array (row, col) + gwcsobj.bounding_box = None + row, col = gwcsobj.invert(coordinates) + + return (row, col), wcs_updated diff --git a/astrocut/Cutout.py b/astrocut/Cutout.py index 6aada729..33704f91 100644 --- a/astrocut/Cutout.py +++ b/astrocut/Cutout.py @@ -1,6 +1,7 @@ from abc import abstractmethod, ABC from pathlib import Path from typing import List, Union, Tuple +import warnings from astropy import wcs import astropy.units as u @@ -8,10 +9,10 @@ from astropy.coordinates import SkyCoord import numpy as np -from astrocut.exceptions import InvalidInputError, InvalidQueryError +from astrocut.exceptions import InputWarning, InvalidInputError, InvalidQueryError from . 
import log -from .utils.utils import _handle_verbose, parse_size_input +from .utils.utils import _handle_verbose class Cutout(ABC): @@ -19,8 +20,8 @@ class Cutout(ABC): Abstract class for creating cutouts. This class defines attributes and methods that are common to all cutout classes. - Attributes - ---------- + Args + ---- input_files : list List of input image files. coordinates : str | `~astropy.coordinates.SkyCoord` @@ -29,10 +30,6 @@ class Cutout(ABC): Size of the cutout array. fill_value : int | float Value to fill the cutout with if the cutout is outside the image. - memory_only : bool - If True, the cutout is written to memory instead of disk. - output_dir : str | Path - Directory to write the cutout file(s) to. limit_rounding_method : str Method to use for rounding the cutout limits. Options are 'round', 'ceil', and 'floor'. verbose : bool @@ -40,7 +37,9 @@ class Cutout(ABC): Methods ------- - get_cutout_limits(img_wcs) + _parse_size_input(cutout_size) + Makes the given cutout size into a length 2 array + _get_cutout_limits(img_wcs) Returns the x and y pixel limits for the cutout. cutout() Generate the cutouts. 
@@ -48,8 +47,8 @@ class Cutout(ABC): def __init__(self, input_files: List[Union[str, Path, S3Path]], coordinates: Union[SkyCoord, str], cutout_size: Union[int, np.ndarray, u.Quantity, List[int], Tuple[int]] = 25, - fill_value: Union[int, float] = np.nan, memory_only: bool = False, - output_dir: Union[str, Path] = '.', limit_rounding_method: str = 'round', verbose: bool = False): + fill_value: Union[int, float] = np.nan, limit_rounding_method: str = 'round', + verbose: bool = False): # Log messages according to verbosity _handle_verbose(verbose) @@ -60,13 +59,13 @@ def __init__(self, input_files: List[Union[str, Path, S3Path]], coordinates: Uni self._input_files = input_files # Get coordinates as a SkyCoord object - if coordinates and not isinstance(coordinates, SkyCoord): + if not isinstance(coordinates, SkyCoord): coordinates = SkyCoord(coordinates, unit='deg') self._coordinates = coordinates log.debug('Coordinates: %s', self._coordinates) # Turning the cutout size into an array of two values - self._cutout_size = parse_size_input(cutout_size) + self._cutout_size = self._parse_size_input(cutout_size) log.debug('Cutout size: %s', self._cutout_size) # Assigning other attributes @@ -79,11 +78,57 @@ def __init__(self, input_files: List[Union[str, Path, S3Path]], coordinates: Uni if not isinstance(fill_value, int) and not isinstance(fill_value, float): raise InvalidInputError('Fill value must be an integer or a float.') self._fill_value = fill_value - - self._memory_only = memory_only - self._output_dir = output_dir + self._verbose = verbose + def _parse_size_input(self, cutout_size): + """ + Makes the given cutout size into a length 2 array. + + Parameters + ---------- + cutout_size : int, array-like, `~astropy.units.Quantity` + The size of the cutout array. If ``cutout_size`` is a scalar number or a scalar + `~astropy.units.Quantity`, then a square cutout of ``cutout_size`` will be created. + If ``cutout_size`` has two elements, they should be in ``(ny, nx)`` order. 
Scalar numbers + in ``cutout_size`` are assumed to be in units of pixels. `~astropy.units.Quantity` objects + must be in pixel or angular units. + + Returns + ------- + response : array + Length two cutout size array, in the form [ny, nx]. + """ + + # Making size into an array [ny, nx] + if np.isscalar(cutout_size): + cutout_size = np.repeat(cutout_size, 2) + + if isinstance(cutout_size, u.Quantity): + cutout_size = np.atleast_1d(cutout_size) + if len(cutout_size) == 1: + cutout_size = np.repeat(cutout_size, 2) + elif not isinstance(cutout_size, np.ndarray): + cutout_size = np.array(cutout_size) + + if len(cutout_size) > 2: + warnings.warn('Too many dimensions in cutout size, only the first two will be used.', + InputWarning) + cutout_size = cutout_size[:2] + + + for dim in cutout_size: + # Raise error if either dimension is not a positive number + if dim <= 0: + raise InvalidInputError('Cutout size dimensions must be greater than zero. ' + f'Provided size: ({cutout_size[0]}, {cutout_size[1]})') + + # Raise error if either dimension is not a pixel or angular Quantity + if isinstance(dim, u.Quantity) and dim.unit != u.pixel and dim.unit.physical_type != 'angle': + raise InvalidInputError(f'Cutout size unit {dim.unit.aliases[0]} is not supported.') + + return cutout_size + def _get_cutout_limits(self, img_wcs: wcs.WCS) -> np.ndarray: """ Returns the x and y pixel limits for the cutout. diff --git a/astrocut/FITSCutout.py b/astrocut/FITSCutout.py index 5a14a16f..5d2e9e23 100644 --- a/astrocut/FITSCutout.py +++ b/astrocut/FITSCutout.py @@ -1,5 +1,6 @@ from datetime import date from pathlib import Path +from time import monotonic from typing import List, Literal, Optional, Tuple, Union import warnings @@ -12,7 +13,7 @@ import numpy as np from s3path import S3Path -from .exceptions import DataWarning, InputWarning, InvalidQueryError +from .exceptions import DataWarning, InvalidInputError, InvalidQueryError from .ImageCutout import ImageCutout from . 
import __version__, log @@ -21,8 +22,8 @@ class FITSCutout(ImageCutout): """ Class for creating cutouts from FITS files. - Attributes - ---------- + Args + ---- input_files : list List of input image files. coordinates : str | `~astropy.coordinates.SkyCoord` @@ -31,34 +32,8 @@ class FITSCutout(ImageCutout): Size of the cutout array. fill_value : int | float Value to fill the cutout with if the cutout is outside the image. - memory_only : bool - If True, the cutout is written to memory instead of disk. - output_dir : str | Path - Directory to write the cutout file(s) to. limit_rounding_method : str Method to use for rounding the cutout limits. Options are 'round', 'ceil', and 'floor'. - stretch : str - Optional, default 'asinh'. The stretch to apply to the image array. - Valid values are: asinh, sinh, sqrt, log, linear. - minmax_percent : list - Optional. Interval based on a keeping a specified fraction of pixels (can be asymmetric) - when scaling the image. The format is [lower percentile, upper percentile], where pixel - values below the lower percentile and above the upper percentile are clipped. - Only one of minmax_percent and minmax_value should be specified. - minmax_value : list - Optional. Interval based on user-specified pixel values when scaling the image. - The format is [min value, max value], where pixel values below the min value and above - the max value are clipped. - Only one of minmax_percent and minmax_value should be specified. - invert : bool - Optional, default False. If True the image is inverted (light pixels become dark and vice versa). - colorize : bool - Optional, default False. If True a single color image is produced as output, and it is expected - that three files are given as input. - output_format : str - Optional, default '.jpg'. The format of the output image file. - cutout_prefix : str - Optional, default 'cutout'. The prefix to use for the output file name. extension : int | list | 'all' Optional, default None. 
The extension(s) to cutout from. If None, the first extension with data is used. single_outfile : bool @@ -66,45 +41,42 @@ class FITSCutout(ImageCutout): verbose : bool If True, log messages are printed to the console. + Attributes + ---------- + cutouts_by_file : dict + The cutouts as a list of `FITSCutout.CutoutInstance` objects stored by input filename. + fits_cutouts : list + The cutouts as a list of `astropy.io.fits.HDUList` objects. + hdu_cutouts_by_file : dict + The cutouts as `astropy.io.fits.ImageHDU` objects stored by input filename. + Methods ------- - _parse_extensions() + _construct_fits_from_hdus(cutout_hdus) + Make one or more cutout HDUs into a single HDUList object. + _parse_extensions(input_file, infile_exts) Determine which extension(s) to cutout from. - _load_file_data() + _load_file_data(input_file) Load the data from an input file. - _get_img_wcs() + _get_img_wcs(hdu_header) Get the WCS for an image. - _get_cutout_data() - Get the cutout data from an image. - _get_cutout_wcs() - Get the WCS for a cutout. - _hducut() + _hducut(cutout_data, img_wcs, hdu_header, no_sip, ind, primary_filename, is_empty) Create a cutout HDU from an image HDU. - _cutout_file() + _cutout_file(file) Cutout an image file. - _construct_fits_from_hdus() - Make one or more cutout HDUs into a single HDUList object. - _write_to_memory() - Write the cutouts to memory. - _write_as_fits() - Write the cutouts to a file in FITS format. - _write_as_asdf() - Write the cutouts to a file in ASDF format. + cutout() + Generate cutouts from a list of input images. + write_as_fits(output_dir, cutout_prefix) + Write the cutouts to files in FITS format. 
""" def __init__(self, input_files: List[Union[str, Path, S3Path]], coordinates: Union[SkyCoord, str], cutout_size: Union[int, np.ndarray, Quantity, List[int], Tuple[int]] = 25, - fill_value: Union[int, float] = np.nan, memory_only: bool = False, - output_dir: Union[str, Path] = '.', limit_rounding_method: str = 'round', - stretch: Optional[str] = None, minmax_percent: Optional[List[int]] = None, - minmax_value: Optional[List[int]] = None, invert: Optional[bool] = None, - colorize: Optional[bool] = None, output_format: str = '.fits', - cutout_prefix: str = 'cutout', extension: Optional[Union[int, List[int], Literal['all']]] = None, + fill_value: Union[int, float] = np.nan, limit_rounding_method: str = 'round', + extension: Optional[Union[int, List[int], Literal['all']]] = None, single_outfile: bool = True, verbose: bool = False): # Superclass constructor - super().__init__(input_files, coordinates, cutout_size, fill_value, memory_only, output_dir, - limit_rounding_method, stretch, minmax_percent, minmax_value, invert, colorize, - output_format, cutout_prefix, verbose) + super().__init__(input_files, coordinates, cutout_size, fill_value, limit_rounding_method, verbose) # If a single extension is given, make it a list if isinstance(extension, int): @@ -113,6 +85,57 @@ def __init__(self, input_files: List[Union[str, Path, S3Path]], coordinates: Uni # Assigning other attributes self._single_outfile = single_outfile + self._fits_cutouts = None + self.hdu_cutouts_by_file = {} + + # Make the cutouts upon initialization + self.cutout() + + def _construct_fits_from_hdus(self, cutout_hdus: List[fits.ImageHDU]) -> fits.HDUList: + """ + Make one or more cutout HDUs into a single HDUList object. + + Parameters + ---------- + cutout_hdus : list + The `~astropy.io.fits.hdu.image.ImageHDU` object(s) to be written to the fits file. + + Returns + ------- + response : `~astropy.io.fits.HDUList` + The HDUList object. 
+ """ + # Setting up the Primary HDU + keywords = dict() + if self._coordinates: + keywords = {'RA_OBJ': (self._coordinates.ra.deg, '[deg] right ascension'), + 'DEC_OBJ': (self._coordinates.dec.deg, '[deg] declination')} + + # Build the primary HDU with keywords + primary_hdu = fits.PrimaryHDU() + primary_hdu.header.extend([('ORIGIN', 'STScI/MAST', 'institution responsible for creating this file'), + ('DATE', str(date.today()), 'file creation date'), + ('PROCVER', __version__, 'software version')]) + for kwd in keywords: + primary_hdu.header[kwd] = keywords[kwd] + + return fits.HDUList([primary_hdu] + cutout_hdus) + + @property + def fits_cutouts(self): + """ + Return the cutouts as a list `astropy.io.fits.HDUList` objects. + """ + if not self._fits_cutouts: + fits_cutouts = [] + if self._single_outfile: # one output file for all input files + cutout_hdus = [x for file in self.hdu_cutouts_by_file for x in self.hdu_cutouts_by_file[file]] + fits_cutouts = [self._construct_fits_from_hdus(cutout_hdus)] + else: # one output file per input file + for file, cutout_list in self.hdu_cutouts_by_file.items(): + fits_cutouts.append(self._construct_fits_from_hdus(cutout_list)) + self._fits_cutouts = fits_cutouts + return self._fits_cutouts def _parse_extensions(self, input_file: Union[str, Path, S3Path], infile_exts: np.ndarray) -> List[int]: """ @@ -217,127 +240,34 @@ def _get_img_wcs(self, hdu_header: fits.Header) -> Tuple[WCS, bool]: astropy_log.log(log_rec.levelno, log_rec.msg, extra={'origin': log_rec.name}) return (img_wcs, no_sip) - - def _get_cutout_data(self, data: fits.Section, wcs: WCS) -> np.ndarray: - """ - Get the cutout data from an image. - - Parameters - ---------- - data : `~astropy.io.fits.Section` - The data for the image. - wcs : `~astropy.wcs.WCS` - The WCS for the image. - Returns - -------- - cutout_data : `numpy.ndarray` - The cutout data. 
- """ - log.debug('Original image shape: %s', data.shape) - - # Get the limits for the cutout - # These limits are not guaranteed to be within the image footprint - cutout_lims = self._get_cutout_limits(wcs) - xmin, xmax = cutout_lims[0] - ymin, ymax = cutout_lims[1] - ymax_img, xmax_img = data.shape - - # Check the cutout is on the image - if (xmax <= 0) or (xmin >= xmax_img) or (ymax <= 0) or (ymin >= ymax_img): - raise InvalidQueryError('Cutout location is not in image footprint!') - - # Adjust limits and figure out the padding - padding = np.zeros((2, 2), dtype=int) - if xmin < 0: - padding[1, 0] = -xmin - xmin = 0 - if ymin < 0: - padding[0, 0] = -ymin - ymin = 0 - if xmax > xmax_img: - padding[1, 1] = xmax - xmax_img - xmax = xmax_img - if ymax > ymax_img: - padding[0, 1] = ymax - ymax_img - ymax = ymax_img - img_cutout = data[ymin:ymax, xmin:xmax] - - # Adding padding to the cutout so that it's the expected size - if padding.any(): # only do if we need to pad - img_cutout = np.pad(img_cutout, padding, 'constant', constant_values=self._fill_value) - - log.debug('Image cutout shape: %s', img_cutout.shape) - - return img_cutout - - def _get_cutout_wcs(self, img_wcs: WCS, cutout_lims: np.ndarray) -> WCS: - """ - Starting with the full image WCS and adjusting it for the cutout WCS. - Adjusts CRPIX values and adds physical WCS keywords. - - Parameters - ---------- - img_wcs : `~astropy.wcs.WCS` - WCS for the image the cutout is being cut from. - cutout_lims : `numpy.ndarray` - The cutout pixel limits in an array of the form [[ymin,ymax],[xmin,xmax]] - - Returns - -------- - response : `~astropy.wcs.WCS` - The cutout WCS object including SIP distortions if present. 
- """ - # relax = True is important when the WCS has sip distortions, otherwise it has no effect - wcs_header = img_wcs.to_header(relax=True) - - # Adjusting the CRPIX values - wcs_header['CRPIX1'] -= cutout_lims[0, 0] - wcs_header['CRPIX2'] -= cutout_lims[1, 0] - - # Adding the physical WCS keywords - wcs_header.set('WCSNAMEP', 'PHYSICAL', 'name of world coordinate system alternate P') - wcs_header.set('WCSAXESP', 2, 'number of WCS physical axes') - wcs_header.set('CTYPE1P', 'RAWX', 'physical WCS axis 1 type CCD col') - wcs_header.set('CUNIT1P', 'PIXEL', 'physical WCS axis 1 unit') - wcs_header.set('CRPIX1P', 1, 'reference CCD column') - wcs_header.set('CRVAL1P', cutout_lims[0, 0] + 1, 'value at reference CCD column') - wcs_header.set('CDELT1P', 1.0, 'physical WCS axis 1 step') - wcs_header.set('CTYPE2P', 'RAWY', 'physical WCS axis 2 type CCD col') - wcs_header.set('CUNIT2P', 'PIXEL', 'physical WCS axis 2 unit') - wcs_header.set('CRPIX2P', 1, 'reference CCD row') - wcs_header.set('CRVAL2P', cutout_lims[1, 0] + 1, 'value at reference CCD row') - wcs_header.set('CDELT2P', 1.0, 'physical WCS axis 2 step') - - return WCS(wcs_header) - - def _hducut(self, img_hdu: fits.ImageHDU, img_wcs: WCS, hdu_header: fits.Header, no_sip: bool) -> fits.ImageHDU: + def _hducut(self, cutout_data: np.ndarray, cutout_wcs: WCS, hdu_header: fits.Header, no_sip: bool, + ind: int, primary_filename: fits.Header, is_empty: bool) -> fits.ImageHDU: """ Create a cutout HDU from an image HDU. Parameters ---------- - img_hdu : `~astropy.io.fits.ImageHDU` - The image HDU to cutout from. + cutout_data : `numpy.ndarray` + The cutout data. img_wcs : `~astropy.wcs.WCS` The WCS for the image. hdu_header : `~astropy.io.fits.Header` The header for the image HDU. no_sip : bool Whether the image WCS has no SIP information. + ind : int + The index of the extension in the original file. + primary_filename : str + The filename in the header of the primary HDU. 
+ is_empty : bool + Indicates if the cutout has no image data. Returns ------- response : `~astropy.io.fits.ImageHDU` The cutout HDU. """ - # Get the data for the cutout - img_cutout = self._get_cutout_data(img_hdu.section, img_wcs) - - # Get the cutout WCS - # cutout_wcs = img_cutout.wcs - cutout_wcs = self._get_cutout_wcs(img_wcs, self._get_cutout_limits(img_wcs)) - # Updating the header with the new wcs info if no_sip: hdu_header.update(cutout_wcs.to_header(relax=False)) @@ -353,11 +283,18 @@ def _hducut(self, img_hdu: fits.ImageHDU, img_wcs: WCS, hdu_header: fits.Header, hdu_header.remove('FILENAME', ignore_missing=True) # Check that there is data in the cutout image - if (img_cutout == 0).all() or (np.isnan(img_cutout)).all(): + if is_empty: hdu_header['EMPTY'] = (True, 'Indicates no data in cutout image.') - self._num_empty += 1 - return fits.ImageHDU(header=hdu_header, data=img_cutout) + # Create the cutout HDU + cutout_hdu = fits.ImageHDU(header=hdu_header, data=cutout_data) + + # Adding a few more keywords + cutout_hdu.header['ORIG_EXT'] = (ind, 'Extension in original file.') + if not cutout_hdu.header.get('ORIG_FLE') and primary_filename: + cutout_hdu.header['ORIG_FLE'] = primary_filename + + return cutout_hdu def _cutout_file(self, file: Union[str, Path, S3Path]): """ @@ -371,151 +308,292 @@ def _cutout_file(self, file: Union[str, Path, S3Path]): # Load data hdulist, cutout_inds = self._load_file_data(file) + if not len(cutout_inds): # No image extensions with data were found + hdulist.close() + return + # Create HDU cutouts cutouts = [] - self._num_cutouts += len(cutout_inds) + fits_cutouts = [] + num_empty = 0 for ind in cutout_inds: try: # Get HDU, header, and WCS img_hdu = hdulist[ind] hdu_header = fits.Header(img_hdu.header, copy=True) img_wcs, no_sip = self._get_img_wcs(hdu_header) + primary_filename = hdulist[0].header.get('FILENAME') - if self._output_format == '.fits': - # Make a cutout hdu - cutout = self._hducut(img_hdu, img_wcs, hdu_header, 
no_sip) + # Create the cutout + # Eventually, this will be replaced by a call to Cutout2D + cutout = self.CutoutInstance(img_hdu.section, img_wcs, self) - # Adding a few more keywords - cutout.header['ORIG_EXT'] = (ind, 'Extension in original file.') - if not cutout.header.get('ORIG_FLE') and hdulist[0].header.get('FILENAME'): - cutout.header['ORIG_FLE'] = hdulist[0].header.get('FILENAME') + # Save the cutout data to use when outputting as an image + # Eventually, the values here will be a list of Cutout2D objects + is_empty = (cutout.data == 0).all() or (np.isnan(cutout.data)).all() + if is_empty: + num_empty += 1 else: - # We only need the data array for images - cutout = self._get_cutout_data(img_hdu.section, img_wcs) + cutouts.append(cutout) - # Apply the appropriate normalization parameters - cutout = self.normalize_img(cutout, self._stretch, self._minmax_percent, self._minmax_value, - self._invert) + # Also save the cutouts as ImageHDU objects for FITS output + fits_cutouts.append(self._hducut(cutout.data, cutout.wcs, hdu_header, no_sip, ind, + primary_filename, is_empty)) - if (cutout == 0).all(): - continue - - cutouts.append(cutout) except OSError as err: warnings.warn(f'Error {err} encountered when performing cutout on {file}, ' f'extension {ind}, skipping...', DataWarning) - self._num_empty += 1 + num_empty += 1 except NoOverlapError: warnings.warn(f'Cutout footprint does not overlap with data in {file}, ' f'extension {ind}, skipping...', DataWarning) - self._num_empty += 1 + num_empty += 1 except ValueError as err: if 'Input position contains invalid values' in str(err): warnings.warn(f'Cutout footprint does not overlap with data in {file}, ' f'extension {ind}, skipping...', DataWarning) - self._num_empty += 1 + num_empty += 1 else: raise # Close HDUList hdulist.close() - # Save cutouts - self._cutout_dict[file] = cutouts + if num_empty == len(cutout_inds): # No extensions have cutout data + warnings.warn(f'Cutout of {file} contains no data, 
skipping...', DataWarning) + else: # At least one extension has cutout data + # Save cutouts + self.cutouts_by_file[file] = cutouts + self.hdu_cutouts_by_file[file] = fits_cutouts - def _construct_fits_from_hdus(self, cutout_hdus: List[fits.ImageHDU]) -> fits.HDUList: + def cutout(self) -> Union[str, List[str], List[fits.HDUList]]: """ - Make one or more cutout HDUs into a single HDUList object. - - Parameters - ---------- - cutout_hdus : list - The `~astropy.io.fits.hdu.image.ImageHDU` object(s) to be written to the fits file. + Generate cutouts from a list of input images. Returns ------- - response : `~astropy.io.fits.HDUList` - The HDUList object. - """ - # Setting up the Primary HDU - keywords = dict() - if self._coordinates: - keywords = {'RA_OBJ': (self._coordinates.ra.deg, '[deg] right ascension'), - 'DEC_OBJ': (self._coordinates.dec.deg, '[deg] declination')} + cutout_path : Path | list + Cutouts as memory objects or path(s) to the written cutout files. - # Build the primary HDU with keywords - primary_hdu = fits.PrimaryHDU() - primary_hdu.header.extend([('ORIGIN', 'STScI/MAST', 'institution responsible for creating this file'), - ('DATE', str(date.today()), 'file creation date'), - ('PROCVER', __version__, 'software version')]) - for kwd in keywords: - primary_hdu.header[kwd] = keywords[kwd] + Raises + ------ + InvalidInputError + If no cutouts contain data. + """ + # Track start time + start_time = monotonic() - return fits.HDUList([primary_hdu] + cutout_hdus) + # Cutout each input file + for file in self._input_files: + self._cutout_file(file) + + # If no cutouts contain data, raise exception + if not self.cutouts_by_file: + raise InvalidInputError('Cutout contains no data! 
(Check image footprint.)') - def _write_as_fits(self) -> Union[str, List[str]]: + # Log total time elapsed + log.debug('Total time: %.2f sec', monotonic() - start_time) + + return self.fits_cutouts + + def write_as_fits(self, output_dir: Union[str, Path] = '.', cutout_prefix: str = 'cutout') -> List[str]: """ Write the cutouts to memory or to a file in FITS format. Returns ------- - cutout_paths : str | list - The path(s) to the cutout file(s). + cutout_paths : list + A list of paths to the cutout FITS files. """ + Path(output_dir).mkdir(parents=True, exist_ok=True) - if self._single_outfile: + if self._single_outfile: # one output file for all input files log.debug('Returning cutout as a single FITS file.') - cutout_hdus = [x for file in self._cutout_dict for x in self._cutout_dict[file]] - cutout_fits = self._construct_fits_from_hdus(cutout_hdus) + cutout_fits = self.fits_cutouts[0] + filename = '{}_{:.7f}_{:.7f}_{}-x-{}_astrocut.fits'.format( + cutout_prefix, + self._coordinates.ra.value, + self._coordinates.dec.value, + str(self._cutout_size[0]).replace(' ', ''), + str(self._cutout_size[1]).replace(' ', '')) + cutout_path = Path(output_dir, filename) + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + cutout_fits.writeto(cutout_path, overwrite=True, checksum=True) + # Return file path or memory object + return [cutout_path.as_posix()] + + else: # one output file per input file + log.debug('Returning cutouts as individual FITS files.') - if self._memory_only: - return [cutout_fits] - else: + cutout_paths = [] + for i, file in enumerate(self.hdu_cutouts_by_file): + cutout_fits = self.fits_cutouts[i] filename = '{}_{:.7f}_{:.7f}_{}-x-{}_astrocut.fits'.format( - self._cutout_prefix, + Path(file).stem, self._coordinates.ra.value, self._coordinates.dec.value, - str(self._cutout_size[0]).replace(' ', ''), + str(self._cutout_size[0]).replace(' ', ''), str(self._cutout_size[1]).replace(' ', '')) - cutout_path = Path(self._output_dir, filename) + 
cutout_path = Path(output_dir, filename) with warnings.catch_warnings(): - warnings.simplefilter('ignore') + warnings.simplefilter('ignore') cutout_fits.writeto(cutout_path, overwrite=True, checksum=True) - return cutout_path.as_posix() - else: # one output file per input file - log.debug('Returning cutouts as individual FITS files.') + # Append file path or memory object + cutout_paths.append(cutout_path.as_posix()) - all_cutouts = [] - for file, cutout_list in self._cutout_dict.items(): - if np.array([x.header.get('EMPTY') for x in cutout_list]).all(): - # Skip files with no data in the cutout images - warnings.warn(f'Cutout of {file} contains no data and will not be returned.', DataWarning) - continue + log.debug('Cutout filepaths: {}'.format(cutout_paths)) + return cutout_paths + + class CutoutInstance: + """ + Represents an individual cutout with its own data and WCS. Eventually, this will be replaced + by `astropy.nddata.Cutout2D` objects. - cutout_fits = self._construct_fits_from_hdus(cutout_list) + Args + ---- + img_data : `~astropy.io.fits.Section` + The data for the image. + img_wcs : `~astropy.wcs.WCS` + The WCS for the image. + parent : `FITSCutout` + The parent FITSCutout object. 
- if self._memory_only: - all_cutouts.append(cutout_fits) - else: - filename = '{}_{:.7f}_{:.7f}_{}-x-{}_astrocut.fits'.format( - Path(file).stem, - self._coordinates.ra.value, - self._coordinates.dec.value, - str(self._cutout_size[0]).replace(' ', ''), - str(self._cutout_size[1]).replace(' ', '')) - cutout_path = Path(self._output_dir, filename) - with warnings.catch_warnings(): - warnings.simplefilter('ignore') - cutout_fits.writeto(cutout_path, overwrite=True, checksum=True) - all_cutouts.append(cutout_path.as_posix()) - - if self._memory_only: - return all_cutouts - else: - return all_cutouts if len(all_cutouts) > 1 else all_cutouts[0] - - def _write_as_asdf(self): - """ASDF output is not yet implemented for FITS files.""" - warnings.warn('ASDF output is not yet implemented for FITS files.', InputWarning) - return + Attributes + ---------- + data : `numpy.ndarray` + The cutout data array. + shape : tuple + The shape of the cutout data array. + shape_input : tuple + The shape of the input image data array. + slices_original : tuple + Slices for the minimal bounding box of the cutout with respect to the original array. + xmin_original : int + The minimum x value of the cutout in the original array. + xmax_original : int + The maximum x value of the cutout in the original array. + ymin_original : int + The minimum y value of the cutout in the original array. + ymax_original : int + The maximum y value of the cutout in the original array. + wcs : `~astropy.wcs.WCS` + The WCS for the cutout. + + Methods + ------- + _get_cutout_data(data, wcs, cutout_lims, parent) + Extract the cutout data from an image. + _get_cutout_wcs(img_wcs, cutout_lims) + Get the WCS for a cutout. 
+ """ + + def __init__(self, img_data: fits.Section, img_wcs: WCS, parent: 'FITSCutout'): + # Calculate cutout limits + cutout_lims = parent._get_cutout_limits(img_wcs) + + # Extract data from Section + self.data = self._get_cutout_data(img_data, img_wcs, cutout_lims, parent) + self.shape = self.data.shape + self.shape_input = img_data.shape + + self.wcs = self._get_cutout_wcs(img_wcs, cutout_lims) + + def _get_cutout_data(self, data: fits.Section, wcs: WCS, cutout_lims: np.ndarray, + parent: 'FITSCutout') -> np.ndarray: + """ + Extract the cutout data from an image. + + Parameters + ---------- + data : `~astropy.io.fits.Section` + The data for the image. + wcs : `~astropy.wcs.WCS` + The WCS for the image. + cutout_lims : `numpy.ndarray` + The cutout pixel limits in an array of the form [[ymin,ymax],[xmin,xmax]] + parent : `FITSCutout` + The parent FITSCutout object. Needed for access to certain attributes. + + Returns + -------- + cutout_data : `numpy.ndarray` + The cutout data. + """ + log.debug('Original image shape: %s', data.shape) + + # Get the limits for the cutout + # These limits are not guaranteed to be within the image footprint + (xmin, xmax), (ymin, ymax) = cutout_lims + ymax_img, xmax_img = data.shape + + # Check the cutout is on the image + if (xmax <= 0) or (xmin >= xmax_img) or (ymax <= 0) or (ymin >= ymax_img): + raise InvalidQueryError('Cutout location is not in image footprint!') + + # Adjust limits to fit within image bounds + xmin_clipped, xmax_clipped = max(0, xmin), min(xmax_img, xmax) + ymin_clipped, ymax_clipped = max(0, ymin), min(ymax_img, ymax) + + # Compute padding required (before and after in x and y) + padding = np.array([ + (max(0, -ymin), max(0, ymax - ymax_img)), # (top, bottom) + (max(0, -xmin), max(0, xmax - xmax_img)) # (left, right) + ]) + + # Extract the cutout + img_cutout = data[ymin_clipped:ymax_clipped, xmin_clipped:xmax_clipped] + + # Assign relevant attributes + self.slices_original = slice(ymin_clipped, ymax_clipped + 
1), slice(xmin_clipped, xmax_clipped + 1) + self.xmin_original, self.xmax_original = xmin_clipped, xmax_clipped + self.ymin_original, self.ymax_original = ymin_clipped, ymax_clipped + + # Adding padding to the cutout so that it's the expected size + if padding.any(): # only do if we need to pad + img_cutout = np.pad(img_cutout, padding, 'constant', constant_values=parent._fill_value) + + log.debug('Image cutout shape: %s', img_cutout.shape) + + return img_cutout + + def _get_cutout_wcs(self, img_wcs: WCS, cutout_lims: np.ndarray) -> WCS: + """ + Starting with the full image WCS and adjusting it for the cutout WCS. + Adjusts CRPIX values and adds physical WCS keywords. + + Parameters + ---------- + img_wcs : `~astropy.wcs.WCS` + WCS for the image the cutout is being cut from. + cutout_lims : `numpy.ndarray` + The cutout pixel limits in an array of the form [[ymin,ymax],[xmin,xmax]] + + Returns + -------- + response : `~astropy.wcs.WCS` + The cutout WCS object including SIP distortions if present. 
+ """ + # relax = True is important when the WCS has sip distortions, otherwise it has no effect + wcs_header = img_wcs.to_header(relax=True) + + # Adjusting the CRPIX values + wcs_header['CRPIX1'] -= cutout_lims[0, 0] + wcs_header['CRPIX2'] -= cutout_lims[1, 0] + + # Adding the physical WCS keywords + wcs_header.set('WCSNAMEP', 'PHYSICAL', 'name of world coordinate system alternate P') + wcs_header.set('WCSAXESP', 2, 'number of WCS physical axes') + wcs_header.set('CTYPE1P', 'RAWX', 'physical WCS axis 1 type CCD col') + wcs_header.set('CUNIT1P', 'PIXEL', 'physical WCS axis 1 unit') + wcs_header.set('CRPIX1P', 1, 'reference CCD column') + wcs_header.set('CRVAL1P', cutout_lims[0, 0] + 1, 'value at reference CCD column') + wcs_header.set('CDELT1P', 1.0, 'physical WCS axis 1 step') + wcs_header.set('CTYPE2P', 'RAWY', 'physical WCS axis 2 type CCD col') + wcs_header.set('CUNIT2P', 'PIXEL', 'physical WCS axis 2 unit') + wcs_header.set('CRPIX2P', 1, 'reference CCD row') + wcs_header.set('CRVAL2P', cutout_lims[1, 0] + 1, 'value at reference CCD row') + wcs_header.set('CDELT2P', 1.0, 'physical WCS axis 2 step') + + return WCS(wcs_header) diff --git a/astrocut/ImageCutout.py b/astrocut/ImageCutout.py index 1c3ae2cb..1800412c 100644 --- a/astrocut/ImageCutout.py +++ b/astrocut/ImageCutout.py @@ -1,20 +1,19 @@ from abc import abstractmethod, ABC from pathlib import Path -from time import monotonic from typing import List, Optional, Union, Tuple import warnings from astropy.coordinates import SkyCoord -from astropy.io import fits from astropy.units import Quantity from astropy.visualization import (SqrtStretch, LogStretch, AsinhStretch, SinhStretch, LinearStretch, MinMaxInterval, ManualInterval, AsymmetricPercentileInterval) + import numpy as np from PIL import Image from s3path import S3Path from . 
import log -from .exceptions import DataWarning, InputWarning, InvalidInputError, InvalidQueryError +from .exceptions import DataWarning, InputWarning, InvalidInputError from .Cutout import Cutout @@ -23,8 +22,8 @@ class ImageCutout(Cutout, ABC): Abstract class for creating cutouts from images. This class defines attributes and methods that are common to all image cutout classes. - Attributes - ---------- + Args + ---- input_files : list List of input image files. coordinates : str | `~astropy.coordinates.SkyCoord` @@ -33,100 +32,130 @@ class ImageCutout(Cutout, ABC): Size of the cutout array. fill_value : int | float Value to fill the cutout with if the cutout is outside the image. - memory_only : bool - If True, the cutout is written to memory instead of disk. - output_dir : str | Path - Directory to write the cutout file(s) to. limit_rounding_method : str Method to use for rounding the cutout limits. Options are 'round', 'ceil', and 'floor'. - stretch : str - Optional, default 'asinh'. The stretch to apply to the image array. - Valid values are: asinh, sinh, sqrt, log, linear. - minmax_percent : list - Optional. Interval based on a keeping a specified fraction of pixels (can be asymmetric) - when scaling the image. The format is [lower percentile, upper percentile], where pixel - values below the lower percentile and above the upper percentile are clipped. - Only one of minmax_percent and minmax_value should be specified. - minmax_value : list - Optional. Interval based on user-specified pixel values when scaling the image. - The format is [min value, max value], where pixel values below the min value and above - the max value are clipped. - Only one of minmax_percent and minmax_value should be specified. - invert : bool - Optional, default False. If True the image is inverted (light pixels become dark and vice versa). - colorize : bool - Optional, default False. 
If True a single color image is produced as output, and it is expected - that three files are given as input. - output_format : str - Optional, default '.jpg'. The format of the output image file. - cutout_prefix : str - Optional, default 'cutout'. The prefix to use for the output file name. verbose : bool If True, log messages are printed to the console. + Attributes + ---------- + cutouts_by_file : dict + Dictionary containing the cutouts for each input file. + image_cutouts : list + List of `~PIL.Image` objects representing the cutouts. + Methods ------- - _get_cutout_data() - Get the cutout data from the input image. - _cutout_file() + get_image_cutouts(stretch, minmax_percent, minmax_value, invert, colorize) + Get the cutouts as `~PIL.Image` objects. + _cutout_file(file) Cutout an image file. - _write_to_memory() - Write the cutouts to memory. - _write_as_fits() - Write the cutouts to a file in FITS format. - _write_as_asdf() - Write the cutouts to a file in ASDF format. - _write_as_img() - Write the cutouts to a file in an image format. - _write_cutouts() - Write the cutouts according to the specified location and output format. cutout() Generate the cutouts. - normalize_img() + _parse_output_format(output_format) + Parse the output format string and return it in a standardized format. + _save_img_to_file(im, file_path) + Save a `~PIL.Image` object to a file. + write_as_img(stretch, minmax_percent, minmax_value, invert, colorize, output_format, output_dir, cutout_prefix) + Write the cutouts to a file in an image format. + normalize_img(stretch, minmax_percent, minmax_value, invert) Apply given stretch and scaling to an image array. 
""" def __init__(self, input_files: List[Union[str, Path, S3Path]], coordinates: Union[SkyCoord, str], cutout_size: Union[int, np.ndarray, Quantity, List[int], Tuple[int]] = 25, - fill_value: Union[int, float] = np.nan, memory_only: bool = False, - output_dir: Union[str, Path] = '.', limit_rounding_method: str = 'round', - stretch: Optional[str] = None, minmax_percent: Optional[List[int]] = None, - minmax_value: Optional[List[int]] = None, invert: Optional[bool] = None, - colorize: Optional[bool] = None, output_format: str = 'jpg', - cutout_prefix: str = 'cutout', verbose: bool = False): - super().__init__(input_files, coordinates, cutout_size, fill_value, memory_only, output_dir, - limit_rounding_method, verbose) - # Output format should be lowercase and begin with a dot - out_lower = output_format.lower() - self._output_format = f'.{out_lower}' if not output_format.startswith('.') else out_lower - - # Warn if image processing parameters are provided for FITS output - if (self._output_format == '.fits') and (stretch or minmax_percent or - minmax_value or invert or colorize): - warnings.warn('Stretch, minmax_percent, minmax_value, invert, and colorize are not supported ' - 'for FITS output and will be ignored.', InputWarning) - - # Assign attributes with defaults if not provided - stretch = stretch or 'asinh' + fill_value: Union[int, float] = np.nan, limit_rounding_method: str = 'round', + verbose: bool = False): + super().__init__(input_files, coordinates, cutout_size, fill_value, limit_rounding_method, verbose) + + # Initialize cutout dictionary and counters + self.cutouts_by_file = {} + self._image_cutouts = None + + @property + def image_cutouts(self) -> List[Image.Image]: + """ + Return the cutouts as a list of `PIL.Image` objects. + + If the image objects have not been generated yet, they will be generated with default + normalization parameters. 
+ """ + if not self._image_cutouts: + self._image_cutouts = self.get_image_cutouts() + return self._image_cutouts + + def get_image_cutouts(self, stretch: Optional[str] = 'asinh', minmax_percent: Optional[List[int]] = None, + minmax_value: Optional[List[int]] = None, invert: Optional[bool] = False, + colorize: Optional[bool] = False) -> List[Image.Image]: + """ + Get the cutouts as `~PIL.Image` objects given certain normalization parameters. This method also sets + the `image_cutouts` attribute. + + Parameters + ---------- + stretch : str + Optional, default 'asinh'. The stretch to apply to the image array. + Valid values are: asinh, sinh, sqrt, log, linear + minmax_percent : array + Optional. Interval based on a keeping a specified fraction of pixels (can be asymmetric) + when scaling the image. The format is [lower percentile, upper percentile], where pixel + values below the lower percentile and above the upper percentile are clipped. + Only one of minmax_percent and minmax_value should be specified. + minmax_value : array + Optional. Interval based on user-specified pixel values when scaling the image. + The format is [min value, max value], where pixel values below the min value and above + the max value are clipped. + Only one of minmax_percent and minmax_value should be specified. + invert : bool + Optional, default False. If True the image is inverted (light pixels become dark and vice versa). + colorize : bool + Optional, default False. If True, the first three cutouts will be combined into a single RGB image. + + Returns + ------- + image_cutouts : list + List of `~PIL.Image` objects representing the cutouts. + """ + # Validate the stretch parameter valid_stretches = ['asinh', 'sinh', 'sqrt', 'log', 'linear'] if not isinstance(stretch, str) or stretch.lower() not in valid_stretches: raise InvalidInputError(f'Stretch {stretch} is not recognized. 
Valid options are {valid_stretches}.') - self._stretch = stretch.lower() - self._invert = invert or False - self._colorize = colorize or False - self._minmax_percent = minmax_percent - self._minmax_value = minmax_value - self._cutout_prefix = cutout_prefix - - # Initialize cutout dictionary and counters - self._cutout_dict = {} - self._num_empty = 0 - self._num_cutouts = 0 + stretch = stretch.lower() # Apply default scaling for image outputs - if (self._minmax_percent is None) and (self._minmax_value is None): - self._minmax_percent = [0.5, 99.5] + if (minmax_percent is None) and (minmax_value is None): + minmax_percent = [0.5, 99.5] + + if colorize: # color cutout + all_cutouts = [x for fle in self._input_files for x in self.cutouts_by_file.get(fle, [])] + + # Check for the correct number of cutouts + if len(all_cutouts) < 3: + raise InvalidInputError(('Color cutouts require 3 input images (RGB).' + 'If you supplied 3 images one of the cutouts may have been empty.')) + if len(all_cutouts) > 3: + warnings.warn('Too many inputs for a color cutout, only the first three will be used.', InputWarning) + all_cutouts = all_cutouts[:3] + + img_arrs = [] + for cutout in all_cutouts: + # Image output, applying the appropriate normalization parameters + img_arrs.append(self.normalize_img(cutout.data, stretch, minmax_percent, minmax_value, invert)) + + # Combine the three cutouts into a single RGB image + self._image_cutouts = [Image.fromarray(np.dstack([img_arrs[0], img_arrs[1], img_arrs[2]]).astype(np.uint8))] + else: # one image per cutout + image_cutouts = [] + for file, cutout_list in self.cutouts_by_file.items(): + for i, cutout in enumerate(cutout_list): + # Apply the appropriate normalization parameters + img_arr = self.normalize_img(cutout.data, stretch, minmax_percent, minmax_value, invert) + image_cutouts.append(Image.fromarray(img_arr)) + + self._image_cutouts = image_cutouts + return self._image_cutouts + @abstractmethod def _cutout_file(self, file: Union[str, 
Path, S3Path]): """ @@ -136,24 +165,38 @@ def _cutout_file(self, file: Union[str, Path, S3Path]): """ pass - @abstractmethod - def _write_as_fits(self): + def cutout(self): """ - Write the cutouts to a file in FITS format. + Generate the cutout(s). - This method is abstract and should be defined in the subclass. + This method is abstract and should be defined in subclasses. """ pass - - @abstractmethod - def _write_as_asdf(self): + + def _parse_output_format(self, output_format: str) -> str: """ - Write the cutouts to a file in ASDF format. + Parse the output format string and return it in a standardized format. + + Parameters + ---------- + output_format : str + The output format string. - This method is abstract and should be defined in the subclass. + Returns + ------- + out_format : str + The output format string in a standardized format. """ - pass + # Put format in standard format + out_lower = output_format.lower() + output_format = f'.{out_lower}' if not output_format.startswith('.') else out_lower + + # Error if the output format is not supported + if output_format not in Image.registered_extensions().keys(): + raise InvalidInputError(f'Output format {output_format} is not supported.') + + return output_format def _save_img_to_file(self, im: Image, file_path: str) -> bool: """ @@ -175,22 +218,53 @@ def _save_img_to_file(self, im: Image, file_path: str) -> bool: im.save(file_path) return True except ValueError as e: - warnings.warn(f'Cutout could not be saved in {self._output_format} format: {e}. ' + output_format = Path(file_path).suffix + warnings.warn(f'Cutout could not be saved in {output_format} format: {e}. ' 'Please try a different output format.', DataWarning) return False except KeyError as e: - warnings.warn(f'Cutout could not be saved in {self._output_format} format due to a KeyError: {e}. ' + output_format = Path(file_path).suffix + warnings.warn(f'Cutout could not be saved in {output_format} format due to a KeyError: {e}. 
' 'Please try a different output format.', DataWarning) return False except OSError as e: warnings.warn(f'Cutout could not be saved: {e}', DataWarning) return False - def _write_as_img(self) -> Union[str, List[str]]: + def write_as_img(self, stretch: Optional[str] = 'asinh', minmax_percent: Optional[List[int]] = None, + minmax_value: Optional[List[int]] = None, invert: Optional[bool] = False, + colorize: Optional[bool] = False, output_format: str = '.jpg', + output_dir: Union[str, Path] = '.', cutout_prefix: str = 'cutout') -> Union[str, List[str]]: """ Write the cutout to memory or to a file in an image format. If colorize is set, the first 3 cutouts will be combined into a single RGB image. Otherwise, each cutout will be written to a separate file. + Parameters + ---------- + stretch : str + Optional, default 'asinh'. The stretch to apply to the image array. + Valid values are: asinh, sinh, sqrt, log, linear + minmax_percent : array + Optional. Interval based on a keeping a specified fraction of pixels (can be asymmetric) + when scaling the image. The format is [lower percentile, upper percentile], where pixel + values below the lower percentile and above the upper percentile are clipped. + Only one of minmax_percent and minmax_value shoul be specified. + minmax_value : array + Optional. Interval based on user-specified pixel values when scaling the image. + The format is [min value, max value], where pixel values below the min value and above + the max value are clipped. + Only one of minmax_percent and minmax_value should be specified. + invert : bool + Optional, default False. If True the image is inverted (light pixels become dark and vice versa). + colorize : bool + Optional, default False. If True, the first three cutouts will be combined into a single RGB image. + output_format : str + Optional, default '.jpg'. The output format for the cutout image(s). + output_dir : str | `~pathlib.Path` + Optional, default '.'. 
The directory to write the cutout image(s) to. + cutout_prefix : str + Optional, default 'cutout'. The prefix to add to the cutout image file name. + Returns ------- cutout_path : List[Path] @@ -201,129 +275,58 @@ def _write_as_img(self) -> Union[str, List[str]]: InvalidInputError If less than three inputs were provided for a colorized cutout. """ - # Set up output files and write them - if self._colorize: # Combine first three cutouts into a single RGB image - cutouts = [x for fle in self._input_files for x in self._cutout_dict.get(fle, [])] - - # Check for the correct number of cutouts - if len(cutouts) < 3: - raise InvalidInputError(('Color cutouts require 3 input images (RGB).' - 'If you supplied 3 images one of the cutouts may have been empty.')) - if len(cutouts) > 3: - warnings.warn('Too many inputs for a color cutout, only the first three will be used.', InputWarning) - cutouts = cutouts[:3] + # Parse the output format + output_format = self._parse_output_format(output_format) - im = Image.fromarray(np.dstack([cutouts[0], cutouts[1], cutouts[2]]).astype(np.uint8)) + # Get the image cutouts with the given normalization parameters + image_cutouts = self.get_image_cutouts(stretch, minmax_percent, minmax_value, invert, colorize) - if self._memory_only: - return [im] + # Create the output directory if it does not exist + Path(output_dir).mkdir(parents=True, exist_ok=True) + # Set up output files and write them + if colorize: # Combine first three cutouts into a single RGB image # Write the colorized cutout to disk - cutout_path = '{}_{:.7f}_{:.7f}_{}-x-{}_astrocut{}'.format( - self._cutout_prefix, + filename = '{}_{:.7f}_{:.7f}_{}-x-{}_astrocut{}'.format( + cutout_prefix, self._coordinates.ra.value, self._coordinates.dec.value, str(self._cutout_size[0]).replace(' ', ''), str(self._cutout_size[1]).replace(' ', ''), - self._output_format + output_format ) - cutout_path = Path(self._output_dir, cutout_path).as_posix() - success = self._save_img_to_file(im, 
cutout_path) + + # Attempt to write image to file + cutout_paths = Path(output_dir, filename).as_posix() + success = self._save_img_to_file(image_cutouts[0], cutout_paths) if not success: - return + cutout_paths = None else: # Write each cutout to a separate image file - cutout_path = [] # Store the paths of the written cutout files - for file, cutout_list in self._cutout_dict.items(): - if not cutout_list: - warnings.warn(f'Cutout of {file} contains no data and will not be written.', DataWarning) - continue - for i, cutout in enumerate(cutout_list): - - im = Image.fromarray(cutout) - if self._memory_only: - cutout_path.append(im) - continue - - # Write individual cutouts to disk - file_path = '{}_{:.7f}_{:.7f}_{}-x-{}_astrocut_{}{}'.format( - Path(file).stem, - self._coordinates.ra.value, - self._coordinates.dec.value, - str(self._cutout_size[0]).replace(' ', ''), - str(self._cutout_size[1]).replace(' ', ''), - i, - self._output_format) - file_path = Path(self._output_dir, file_path).as_posix() - success = self._save_img_to_file(im, file_path) - if success: - cutout_path.append(file_path) - - return cutout_path - - def _write_cutouts(self) -> Union[str, List]: - """ - Write the cutout to a file according to the specified output format. - - Returns - ------- - cutout_path : Path | list - Cutouts as memory objects or path(s) to the written cutout files. - - Raises - ------ - InvalidInputError - If the output format is not supported. - """ - if self._memory_only: - # Write only to memory if specified - log.info('Writing cutouts to memory only. 
No output files will be created.') - else: - # If writing to disk, ensure that output directory exists - Path(self._output_dir).mkdir(parents=True, exist_ok=True) - - if self._output_format == '.fits': - return self._write_as_fits() - elif self._output_format == '.asdf': - return self._write_as_asdf() - elif self._output_format in Image.registered_extensions().keys(): - return self._write_as_img() - else: - raise InvalidInputError(f'Output format {self._output_format} is not supported.') - - def cutout(self) -> Union[str, List[str], List[fits.HDUList]]: - """ - Generate cutouts from a list of input images. - - Returns - ------- - cutout_path : Path | list - Cutouts as memory objects or path(s) to the written cutout files. - - Raises - ------ - InvalidQueryError - If no cutouts contain data. - """ - # Track start time - start_time = monotonic() - - # Cutout each input file - for file in self._input_files: - self._cutout_file(file) - - # If no cutouts contain data, raise exception - if self._num_cutouts == self._num_empty: - raise InvalidQueryError('Cutout contains no data! 
(Check image footprint.)') - - # Write cutout(s) - cutout_path = self._write_cutouts() - - # Log cutout path and total time elapsed - log.debug('Cutout fits file(s): %s', cutout_path) - log.debug('Total time: %.2f sec', monotonic() - start_time) - - return cutout_path + cutout_paths = [] # Store the paths of the written cutout files + for i, file in enumerate(self.cutouts_by_file): + # Write individual cutouts to disk + filename = '{}_{:.7f}_{:.7f}_{}-x-{}_astrocut_{}{}'.format( + Path(file).stem, + self._coordinates.ra.value, + self._coordinates.dec.value, + str(self._cutout_size[0]).replace(' ', ''), + str(self._cutout_size[1]).replace(' ', ''), + i, + output_format) + + # Attempt to write image to file + cutout_path = Path(output_dir, filename).as_posix() + success = self._save_img_to_file(image_cutouts[i], cutout_path) + + # Append the path to the written file or the memory object + # If the image could not be written, append None + if not success: + cutout_path = None + cutout_paths.append(cutout_path) + + log.debug('Cutout filepaths: {}'.format(cutout_paths)) + return cutout_paths @staticmethod def normalize_img(img_arr: np.ndarray, stretch: str = 'asinh', minmax_percent: Optional[List[int]] = None, diff --git a/astrocut/asdf_cutouts.py b/astrocut/asdf_cutouts.py index cda9df39..2cbbb435 100644 --- a/astrocut/asdf_cutouts.py +++ b/astrocut/asdf_cutouts.py @@ -1,272 +1,63 @@ # Licensed under a 3-clause BSD style license - see LICENSE.rst """This module implements cutout functionality similar to fitscut, but for the ASDF file format.""" -import copy -import pathlib -from typing import Union, Tuple -import requests +from pathlib import Path +from typing import List, Union -import asdf import astropy import gwcs import numpy as np -import s3fs +from astropy.utils.decorators import deprecated_renamed_argument from s3path import S3Path -from astropy.coordinates import SkyCoord -from astropy.modeling import models - -from . 
import log -from .utils.utils import _handle_verbose - - -def _get_cloud_http(s3_uri: Union[str, S3Path], key: str = None, secret: str = None, - token: str = None, verbose: bool = False) -> str: - """ - Get the HTTP URI of a cloud resource from an S3 URI. - - Parameters - ---------- - s3_uri : string | S3Path - the S3 URI of the cloud resource - key : string - Default None. Access key ID for S3 file system. - secret : string - Default None. Secret access key for S3 file system. - token : string - Default None. Security token for S3 file system. - verbose : bool - Default False. If true intermediate information is printed. - """ - - # check if public or private by sending an HTTP request - s3_path = S3Path.from_uri(s3_uri) if isinstance(s3_uri, str) else s3_uri - url = f'https://{s3_path.bucket}.s3.amazonaws.com/{s3_path.key}' - resp = requests.head(url, timeout=10) - is_anon = False if resp.status_code == 403 else True - if not is_anon: - log.debug('Attempting to access private S3 bucket: %s', s3_path.bucket) - - # create file system and get URL of file - fs = s3fs.S3FileSystem(anon=is_anon, key=key, secret=secret, token=token) - with fs.open(s3_uri, 'rb') as f: - return f.url() +from .ASDFCutout import ASDFCutout +from .exceptions import InvalidInputError def get_center_pixel(gwcsobj: gwcs.wcs.WCS, ra: float, dec: float) -> tuple: """ - Get the center pixel from a Roman 2D science image. - - For an input RA, Dec sky coordinate, get the closest pixel location - on the input Roman image. + Get the closest pixel location on an input image for a given set of coordinates. Parameters ---------- gwcsobj : gwcs.wcs.WCS - The Roman GWCS object. + The GWCS object. ra : float - The input right ascension. + The right ascension of the input coordinates. dec : float - The input declination. 
- - Returns - ------- - tuple - The pixel position, FITS wcs object - """ - - # Convert the gwcs object to an astropy FITS WCS header - header = gwcsobj.to_fits_sip() - - # Update WCS header with some keywords that it's missing. - # Otherwise, it won't work with astropy.wcs tools (TODO: Figure out why. What are these keywords for?) - for k in ['cpdis1', 'cpdis2', 'det2im1', 'det2im2', 'sip']: - if k not in header: - header[k] = 'na' - - # New WCS object with updated header - wcs_updated = astropy.wcs.WCS(header) - - # Turn input RA, Dec into a SkyCoord object - coordinates = SkyCoord(ra, dec, unit='deg') - - # Map the coordinates to a pixel's location on the Roman 2d array (row, col) - row, col = gwcsobj.invert(coordinates) - - return (row, col), wcs_updated - - -def _get_cutout(data: asdf.tags.core.ndarray.NDArrayType, coords: Union[tuple, SkyCoord], - wcs: astropy.wcs.wcs.WCS = None, size: int = 20, outfile: str = "example_roman_cutout.fits", - write_file: bool = True, fill_value: Union[int, float] = np.nan, - gwcsobj: gwcs.wcs.WCS = None) -> astropy.nddata.Cutout2D: - """ - Get a Roman image cutout. - - Cut out a square section from the input image data array. The ``coords`` can either be a tuple of x, y - pixel coordinates or an astropy SkyCoord object, in which case, a wcs is required. Writes out a - new output file containing the image cutout of the specified ``size``. Default is 20 pixels. 
- - Parameters - ---------- - data : asdf.tags.core.ndarray.NDArrayType - the input Roman image data array - coords : Union[tuple, SkyCoord] - the input pixel or sky coordinates - wcs : astropy.wcs.wcs.WCS, Optional - the astropy FITS wcs object - size : int, optional - the image cutout pizel size, by default 20 - outfile : str, optional - the name of the output cutout file, by default "example_roman_cutout.fits" - write_file : bool, by default True - Flag to write the cutout to a file or not - fill_value: int | float, by default np.nan - The fill value for pixels outside the original image. - gwcsobj : gwcs.wcs.WCS, Optional - the original gwcs object for the full image, needed only when writing cutout as asdf file - - Returns - ------- - astropy.nddata.Cutout2D: - an image cutout object - - Raises - ------ - ValueError: - when a wcs is not present when coords is a SkyCoord object - RuntimeError: - when the requested cutout does not overlap with the original image - ValueError: - when no gwcs object is provided when writing to an asdf file - """ - - # check for correct inputs - if isinstance(coords, SkyCoord) and not wcs: - raise ValueError('wcs must be input if coords is a SkyCoord.') - - # create the cutout - try: - cutout = astropy.nddata.Cutout2D(data, position=coords, wcs=wcs, size=(size, size), mode='partial', - fill_value=fill_value) - except astropy.nddata.utils.NoOverlapError as e: - raise RuntimeError('Could not create 2d cutout. 
The requested cutout does not overlap with the ' - 'original image.') from e - - # check if the data is a quantity and get the array data - if isinstance(cutout.data, astropy.units.Quantity): - data = cutout.data.value - else: - data = cutout.data - - # write the cutout to the output file - if write_file: - # check the output file type - out = pathlib.Path(outfile) - write_as = out.suffix or '.fits' - outfile = outfile if out.suffix else str(out) + write_as - - # write out the file - if write_as == '.fits': - _write_fits(cutout, outfile) - elif write_as == '.asdf': - if not gwcsobj: - raise ValueError('The original gwcs object is needed when writing to asdf file.') - _write_asdf(cutout, gwcsobj, outfile) - - return cutout - - -def _write_fits(cutout: astropy.nddata.Cutout2D, outfile: str = "example_roman_cutout.fits"): - """ - Write cutout as FITS file. - - Parameters - ---------- - cutout : astropy.nddata.Cutout2D - the 2d cutout - outfile : str, optional - the name of the output cutout file, by default "example_roman_cutout.fits" - """ - # check if the data is a quantity and get the array data - if isinstance(cutout.data, astropy.units.Quantity): - data = cutout.data.value - else: - data = cutout.data - - astropy.io.fits.writeto(outfile, data=data, header=cutout.wcs.to_header(relax=True), overwrite=True) - - -def _slice_gwcs(gwcsobj: gwcs.wcs.WCS, slices: Tuple[slice, slice]) -> gwcs.wcs.WCS: - """ - Slice the original gwcs object. - - "Slices" the original gwcs object down to the cutout shape. This is a hack - until proper gwcs slicing is in place a la fits WCS slicing. The ``slices`` - keyword input is a tuple with the x, y cutout boundaries in the original image - array, e.g. ``cutout.slices_original``. 
Astropy Cutout2D slices are in the form - ((ymin, ymax, None), (xmin, xmax, None)) - - Parameters - ---------- - gwcsobj : gwcs.wcs.WCS - the original gwcs from the input image - slices : Tuple[slice, slice] - the cutout x, y slices as ((ymin, ymax), (xmin, xmax)) + The declination of the input coordinates. Returns ------- - gwcs.wcs.WCS - The sliced gwcs object + pixel_position + The pixel position of the input coordinates. + wcs_updated : `~astropy.wcs.WCS` + The approximated FITS WCS object. """ - tmp = copy.deepcopy(gwcsobj) - - # get the cutout array bounds and create a new shift transform to the cutout - # add the new transform to the gwcs - xmin, xmax = slices[1].start, slices[1].stop - ymin, ymax = slices[0].start, slices[0].stop - shape = (ymax - ymin, xmax - xmin) - offsets = models.Shift(xmin, name='cutout_offset1') & models.Shift(ymin, name='cutout_offset2') - tmp.insert_transform('detector', offsets, after=True) - - # modify the gwcs bounding box to the cutout shape - tmp.bounding_box = ((0, shape[0] - 1), (0, shape[1] - 1)) - tmp.pixel_shape = shape[::-1] - tmp.array_shape = shape - return tmp - - -def _write_asdf(cutout: astropy.nddata.Cutout2D, gwcsobj: gwcs.wcs.WCS, outfile: str = "example_roman_cutout.asdf"): - """ - Write cutout as ASDF file. 
- - Parameters - ---------- - cutout : astropy.nddata.Cutout2D - the 2d cutout - gwcsobj : gwcs.wcs.WCS - the original gwcs object for the full image - outfile : str, optional - the name of the output cutout file, by default "example_roman_cutout.asdf" + return ASDFCutout.get_center_pixel(gwcsobj, ra, dec) + + +@deprecated_renamed_argument('output_file', None, '1.0.0', warning_type=DeprecationWarning, + message='`output_file` is non-operational and will be removed in a future version.') +def asdf_cut(input_files: List[Union[str, Path, S3Path]], + ra: float, + dec: float, + cutout_size: int = 25, + output_file: Union[str, Path] = "example_roman_cutout.fits", + write_file: bool = True, + fill_value: Union[int, float] = np.nan, + output_dir: Union[str, Path] = '.', + output_format: str = '.asdf', + key: str = None, + secret: str = None, + token: str = None, + verbose: bool = False) -> astropy.nddata.Cutout2D: """ - # slice the origial gwcs to the cutout - sliced_gwcs = _slice_gwcs(gwcsobj, cutout.slices_original) - - # create the asdf tree - tree = {'roman': {'meta': {'wcs': sliced_gwcs}, 'data': cutout.data}} - af = asdf.AsdfFile(tree) - - # Write the data to a new file - af.write_to(outfile) + Takes one or more ASDF input files (`input_files`) and generates a cutout of designated size `cutout_size` + around the given coordinates (`ra`, `dec`). The cutout is written to a file or returned as an object. - -def asdf_cut(input_file: Union[str, pathlib.Path, S3Path], ra: float, dec: float, cutout_size: int = 25, - output_file: Union[str, pathlib.Path] = "example_roman_cutout.fits", - write_file: bool = True, fill_value: Union[int, float] = np.nan, key: str = None, - secret: str = None, token: str = None, verbose: bool = False) -> astropy.nddata.Cutout2D: - """ - Takes a single ASDF input file (`input_file`) and generates a cutout of designated size `cutout_size` - around the given coordinates (`coordinates`). + This function is maintained for backwards compatibility. 
For maximum flexibility, we recommend using the + ``ASDFCutout`` class directly. Parameters ---------- @@ -283,10 +74,17 @@ def asdf_cut(input_file: Union[str, pathlib.Path, S3Path], ra: float, dec: float pixel grid. output_file : str | Path Optional, default "example_roman_cutout.fits". The name of the output cutout file. + This parameter is deprecated and will be removed in a future version. write_file : bool Optional, default True. Flag to write the cutout to a file or not. fill_value: int | float Optional, default `np.nan`. The fill value for pixels outside the original image. + output_dir : str | Path + Optional, default ".". The directory to write the cutout file(s) to. + output_format : str + Optional, default ".asdf". The format of the output cutout file. Valid options are ".asdf" and + ".fits". Only used if `write_file` is True; if `write_file` is False, the cutouts are returned as + `astropy.nddata.Cutout2D` objects regardless of `output_format`. key : string Default None. Access key ID for S3 file system. Only applicable if `input_file` is a cloud resource. @@ -301,25 +99,26 @@ def asdf_cut(input_file: Union[str, pathlib.Path, S3Path], ra: float, dec: float Returns ------- - astropy.nddata.Cutout2D: - An image cutout object. + response : str | list + A list of cutout file paths if `write_file` is True, otherwise a list of cutout objects. 
""" - # Log messages based on verbosity - _handle_verbose(verbose) - - # if file comes from AWS cloud bucket, get HTTP URL to open with asdf - file = input_file - if (isinstance(input_file, str) and input_file.startswith('s3://')) or isinstance(input_file, S3Path): - file = _get_cloud_http(input_file, key, secret, token, verbose) - - # get the 2d image data - with asdf.open(file) as f: - data = f['roman']['data'] - gwcsobj = f['roman']['meta']['wcs'] - - # get the center pixel - pixel_coordinates, wcs = get_center_pixel(gwcsobj, ra, dec) + asdf_cutout = ASDFCutout(input_files, f'{ra} {dec}', cutout_size, fill_value, key=key, + secret=secret, token=token, verbose=verbose) + + if not write_file: # Returns as Cutout2D objects + return asdf_cutout.cutouts + + # Get output format in standard form + output_format = f'.{output_format}' if not output_format.startswith('.') else output_format + output_format = output_format.lower() + + if output_format == '.asdf': + return asdf_cutout.write_as_asdf(output_dir) + elif output_format == '.fits': + return asdf_cutout.write_as_fits(output_dir) + else: + # Error if output format not recognized + raise InvalidInputError(f'Output format {output_format} is not recognized. 
' + 'Valid options are ".asdf" and ".fits".') + - # create the 2d image cutout - return _get_cutout(data, pixel_coordinates, wcs, size=cutout_size, outfile=output_file, - write_file=write_file, fill_value=fill_value, gwcsobj=gwcsobj) diff --git a/astrocut/cutouts.py b/astrocut/cutouts.py index e31846a5..a1aed941 100644 --- a/astrocut/cutouts.py +++ b/astrocut/cutouts.py @@ -21,7 +21,7 @@ def fits_cut(input_files: List[Union[str, Path, S3Path]], coordinates: Union[Sky cutout_size: Union[int, np.ndarray, Quantity, List[int], Tuple[int]] = 25, correct_wcs: bool = False, extension: Optional[Union[int, List[int], Literal['all']]] = None, single_outfile: bool = True, cutout_prefix: str = 'cutout', output_dir: Union[str, Path] = '.', - memory_only: bool = False, limit_rounding_method: str = 'round', + memory_only: bool = False, fill_value: Union[int, float] = np.nan, limit_rounding_method: str = 'round', verbose=False) -> Union[str, List[str], List[HDUList]]: """ Takes one or more FITS files with the same WCS/pointing, makes the same cutout in each file, @@ -66,6 +66,8 @@ def fits_cut(input_files: List[Union[str, Path, S3Path]], coordinates: Union[Sky the cutout(s) are returned as a list of `~astropy.io.fit.HDUList` objects. If set to True cutout_prefix and output_dir are ignored, however single_outfile can still be used to set the number of returned `~astropy.io.fits.HDUList` objects. + fill_value : int | float + Value to fill the cutout with if the cutout is outside the image. limit_rounding_method : str Method to use for rounding the cutout limits. Options are 'round', 'ceil', and 'floor'. verbose : bool @@ -79,17 +81,14 @@ def fits_cut(input_files: List[Union[str, Path, S3Path]], coordinates: Union[Sky If memory_only is True, a list of `~astropy.io.fit.HDUList` objects is returned instead of file name(s). 
""" - return FITSCutout(input_files, - coordinates=coordinates, - cutout_size=cutout_size, - memory_only=memory_only, - output_dir=output_dir, - limit_rounding_method=limit_rounding_method, - output_format='.fits', - extension=extension, - single_outfile=single_outfile, - cutout_prefix=cutout_prefix, - verbose=verbose).cutout() + fits_cutout = FITSCutout(input_files, coordinates, cutout_size, fill_value, limit_rounding_method, + extension, single_outfile, verbose) + + if memory_only: + return fits_cutout.fits_cutouts + + cutout_paths = fits_cutout.write_as_fits(output_dir, cutout_prefix) + return cutout_paths[0] if len(cutout_paths) == 1 else cutout_paths def normalize_img(img_arr: np.ndarray, stretch: str = 'asinh', minmax_percent: Optional[List[int]] = None, @@ -159,8 +158,6 @@ def img_cut(input_files: List[Union[str, Path, S3Path]], coordinates: Union[SkyC If ``cutout_size`` has two elements, they should be in ``(ny, nx)`` order. Scalar numbers in ``cutout_size`` are assumed to be in units of pixels. `~astropy.units.Quantity` objects must be in pixel or angular units. - fill_value : int | float - Value to fill the cutout with if the cutout is outside the image. stretch : str Optional, default 'asinh'. The stretch to apply to the image array. Valid values are: asinh, sinh, sqrt, log, linear @@ -186,6 +183,8 @@ def img_cut(input_files: List[Union[str, Path, S3Path]], coordinates: Union[SkyC cutout filename. output_dir : str Defaul value '.'. The directory to save the cutout file(s) to. + fill_value : int | float + Value to fill the cutout with if the cutout is outside the image. limit_rounding_method : str Method to use for rounding the cutout limits. Options are 'round', 'ceil', and 'floor'. extension : int, list of ints, None, or 'all' @@ -201,18 +200,9 @@ def img_cut(input_files: List[Union[str, Path, S3Path]], coordinates: Union[SkyC the output filepaths. 
""" - return FITSCutout(input_files=input_files, - coordinates=coordinates, - cutout_size=cutout_size, - fill_value=fill_value, - output_dir=output_dir, - limit_rounding_method=limit_rounding_method, - stretch=stretch, - minmax_percent=minmax_percent, - minmax_value=minmax_value, - invert=invert, - colorize=colorize, - output_format=img_format, - cutout_prefix=cutout_prefix, - extension=extension, - verbose=verbose).cutout() + fits_cutout = FITSCutout(input_files, coordinates, cutout_size, fill_value, limit_rounding_method, + extension, verbose=verbose) + + cutout_paths = fits_cutout.write_as_img(stretch, minmax_percent, minmax_value, invert, colorize, img_format, + output_dir, cutout_prefix) + return cutout_paths diff --git a/astrocut/tests/test_ASDFCutout.py b/astrocut/tests/test_ASDFCutout.py new file mode 100644 index 00000000..feac7c2a --- /dev/null +++ b/astrocut/tests/test_ASDFCutout.py @@ -0,0 +1,339 @@ +from pathlib import Path +import numpy as np +import pytest + +import asdf +from astropy import coordinates as coord +from astropy import units as u +from astropy.coordinates import SkyCoord +from astropy.modeling import models +from astropy.nddata import Cutout2D +from astropy.io import fits +from gwcs import wcs, coordinate_frames +from PIL import Image + +from astrocut.ASDFCutout import ASDFCutout +from astrocut.asdf_cutouts import asdf_cut +from astrocut.exceptions import DataWarning, InvalidInputError + + +def make_wcs(xsize, ysize, ra=30., dec=45.): + """ Create a fake gwcs object """ + # todo - refine this to better reflect roman wcs + + # create transformations + # - shift coords so array center is at 0, 0 ; reference pixel + # - scale pixels to correct angular scale + # - project coords onto sky with TAN projection + # - transform center pixel to the input celestial coordinate + pixelshift = models.Shift(-xsize) & models.Shift(-ysize) + pixelscale = models.Scale(0.1 / 3600.) & models.Scale(0.1 / 3600.) 
# 0.1 arcsec/pixel + tangent_projection = models.Pix2Sky_TAN() + celestial_rotation = models.RotateNative2Celestial(ra, dec, 180.) + + # net transforms pixels to sky + det2sky = pixelshift | pixelscale | tangent_projection | celestial_rotation + + # define the wcs object + detector_frame = coordinate_frames.Frame2D(name='detector', axes_names=('x', 'y'), unit=(u.pix, u.pix)) + sky_frame = coordinate_frames.CelestialFrame(reference_frame=coord.ICRS(), name='world', unit=(u.deg, u.deg)) + return wcs.WCS([(detector_frame, det2sky), (sky_frame, None)]) + + +@pytest.fixture() +def makefake(): + """ Fixture factory to make a fake gwcs and dataset """ + + def _make_fake(nx, ny, ra, dec, zero=False, asint=False): + # create the wcs + wcsobj = make_wcs(nx/2, ny/2, ra=ra, dec=dec) + wcsobj.bounding_box = ((0, nx), (0, ny)) + + # create the data + if zero: + data = np.zeros([nx, ny]) + else: + size = nx * ny + data = np.arange(size).reshape(nx, ny) + + # make a quantity + data *= (u.electron / u.second) + + # make integer array + if asint: + data = data.astype(int) + + return data, wcsobj + + yield _make_fake + + +@pytest.fixture() +def fakedata(makefake): + """ Fixture to create fake data and wcs """ + # set up initial parameters + nx = 1000 + ny = 1000 + ra = 30. + dec = 45. 
+ + yield makefake(nx, ny, ra, dec) + + +@pytest.fixture() +def test_images(tmp_path, fakedata): + """ Fixture to create a fake dataset of 3 images """ + # get the fake data + data, wcsobj = fakedata + + # create meta + meta = {'wcs': wcsobj} + + # create and write the asdf file + tree = {'roman': {'data': data, 'meta': meta}} + af = asdf.AsdfFile(tree) + + path = tmp_path / 'roman' + path.mkdir(exist_ok=True) + + files = [] + for i in range(3): + filename = path / f'test_roman_{i}.asdf' + af.write_to(filename) + files.append(filename) + + return files + + +@pytest.fixture +def center_coord(): + """ Fixture to return a center coordinate """ + return SkyCoord('29.99901792 44.99930555', unit='deg') + + +@pytest.fixture +def cutout_size(): + """ Fixture to return a cutout size """ + return 10 + + +def test_asdf_cutout(test_images, center_coord, cutout_size): + cutout = ASDFCutout(test_images, center_coord, cutout_size) + cutouts = cutout.cutouts + # Should output a list of strings for multiple input files + assert isinstance(cutouts, list) + assert isinstance(cutouts[0], Cutout2D) + assert len(cutouts) == 3 + assert isinstance(cutout.asdf_cutouts, list) + assert isinstance(cutout.asdf_cutouts[0], asdf.AsdfFile) + assert isinstance(cutout.fits_cutouts, list) + assert isinstance(cutout.fits_cutouts[0], fits.HDUList) + + # Open output files + for i, cutout in enumerate(cutouts): + # Check shape of data + cutout_data = cutout.data + cutout_wcs = cutout.wcs + assert cutout_data.shape == (10, 10) + + # Check that data is equal between cutout and original image + with asdf.open(test_images[i]) as input_af: + assert np.all(cutout_data == input_af['roman']['data'].value[470:480, 471:481]) + + # Check WCS and that center coordinate matches input + s_coord = cutout_wcs.pixel_to_world(cutout_size / 2, cutout_size / 2) + assert cutout_wcs.pixel_shape == (10, 10) + assert np.isclose(s_coord.ra.deg, center_coord.ra.deg) + assert np.isclose(s_coord.dec.deg, center_coord.dec.deg) + + 
+def test_asdf_cutout_write_to_file(test_images, center_coord, cutout_size, tmpdir): + # Write cutouts to ASDF files on disk + cutout = ASDFCutout(test_images, center_coord, cutout_size) + asdf_files = cutout.write_as_asdf(output_dir=tmpdir) + assert len(asdf_files) == 3 + for i, asdf_file in enumerate(asdf_files): + with asdf.open(asdf_file) as af: + assert 'roman' in af.tree + assert 'data' in af.tree['roman'] + assert 'meta' in af.tree['roman'] + assert np.all(af.tree['roman']['data'] == cutout.cutouts[i].data) + assert af.tree['roman']['meta']['wcs'].pixel_shape == (10, 10) + + # Write cutouts to FITS files on disk + cutout = ASDFCutout(test_images, center_coord, cutout_size) + fits_files = cutout.write_as_fits(output_dir=tmpdir) + assert len(fits_files) == 3 + for i, fits_file in enumerate(fits_files): + with fits.open(fits_file) as hdul: + assert np.all(hdul[0].data == cutout.cutouts[i].data) + assert hdul[0].header['NAXIS1'] == 10 + assert hdul[0].header['NAXIS2'] == 10 + + +def test_asdf_cutout_partial(test_images, center_coord, cutout_size): + # Off the top + center_coord = SkyCoord('29.99901792 44.9861', unit='deg') + cutout = ASDFCutout(test_images[0], center_coord, cutout_size).cutouts[0] + assert cutout.data.shape == (10, 10) + assert np.isnan(cutout.data[:cutout_size//2, :]).all() + + # Off the bottom + center_coord = SkyCoord('29.99901792 45.01387', unit='deg') + cutout = ASDFCutout(test_images[0], center_coord, cutout_size).cutouts[0] + assert np.isnan(cutout.data[cutout_size//2:, :]).all() + + # Off the left, with integer fill value + center_coord = SkyCoord('29.98035835 44.99930555', unit='deg') + cutout = ASDFCutout(test_images[0], center_coord, cutout_size, fill_value=1).cutouts[0] + assert np.all(cutout.data[:, :cutout_size//2] == 1) + + # Off the right, with float fill value + center_coord = SkyCoord('30.01961 44.99930555', unit='deg') + cutout = ASDFCutout(test_images[0], center_coord, cutout_size, fill_value=1.5).cutouts[0] + assert 
np.all(cutout.data[:, cutout_size//2:] == 1.5) + + # Error if unexpected fill value + with pytest.raises(InvalidInputError, match='Fill value must be an integer or a float.'): + ASDFCutout(test_images[0], center_coord, cutout_size, fill_value='invalid') + + +def test_asdf_cutout_poles(cutout_size, makefake, tmp_path): + """ Test we can make cutouts around poles """ + # Make fake zero data around the pole + ra, dec = 315.0, 89.995 + data, gwcs = makefake(1000, 1000, ra, dec, zero=True) + + # Add some values (5x5 array) + data.value[245:250, 245:250] = 1 + + # Check central pixel is correct + ss = gwcs(500, 500) + assert ss == (ra, dec) + + # Set input cutout coord + center_coord = SkyCoord(284.702, 89.986, unit='deg') + + # create and write the asdf file + meta = {'wcs': gwcs} + tree = {'roman': {'data': data, 'meta': meta}} + af = asdf.AsdfFile(tree) + path = tmp_path / 'roman' + path.mkdir(exist_ok=True) + filename = path / 'test_roman_poles.asdf' + af.write_to(filename) + + # Get cutout + cutout = ASDFCutout(filename, center_coord, cutout_size).cutouts[0] + + # Check cutout contains all data + assert len(np.where(cutout.data == 1)[0]) == 25 + + +def test_asdf_cutout_not_in_footprint(test_images, center_coord, cutout_size): + # Throw error if cutout location is not in image footprint + with pytest.warns(DataWarning, match='Cutout footprint does not overlap'): + with pytest.raises(InvalidInputError, match='Cutout contains no data!'): + ASDFCutout(test_images[0], SkyCoord('0 0', unit='deg'), cutout_size) + + # Alter one of the test images to only contain zeros in cutout footprint + with asdf.open(test_images[0], mode='rw') as af: + af['roman']['data'][470:480, 471:481] = 0 + af.update() + + # Should warn about first image containing no data, but not fail + with pytest.warns(DataWarning, match='contains no data, skipping...'): + cutouts = ASDFCutout(test_images, center_coord, cutout_size).cutouts + assert len(cutouts) == 2 + + +def 
test_asdf_cutout_no_gwcs(test_images, center_coord, cutout_size): + # Remove WCS from test image + with asdf.open(test_images[0], mode='rw') as af: + del af['roman']['meta']['wcs'] + af.update() + + # Should warn about missing WCS for first image, but not fail + with pytest.warns(DataWarning, match='does not contain a GWCS object'): + cutouts = ASDFCutout(test_images, center_coord, cutout_size).cutouts + assert len(cutouts) == 2 + + +def test_asdf_cutout_invalid_params(test_images, center_coord, cutout_size, tmpdir): + # Invalid units for cutout size + cutout_size = 1 * u.m # meters are not valid + with pytest.raises(InvalidInputError, match='Cutout size unit meter is not supported.'): + ASDFCutout(test_images, center_coord, cutout_size) + + +def test_asdf_cutout_img_output(test_images, center_coord, cutout_size, tmpdir): + # Basic JPG image + jpg_files = ASDFCutout(test_images, center_coord, cutout_size).write_as_img(output_dir=tmpdir, + output_format='jpg') + assert len(jpg_files) == len(test_images) + with open(jpg_files[0], 'rb') as IMGFLE: + assert IMGFLE.read(3) == b'\xFF\xD8\xFF' # JPG + + # PNG (single input file, not as list) + png_files = ASDFCutout(test_images[0], center_coord, cutout_size).write_as_img(output_dir=tmpdir, + output_format='png') + with open(png_files[0], 'rb') as IMGFLE: + assert IMGFLE.read(8) == b'\x89\x50\x4E\x47\x0D\x0A\x1A\x0A' # PNG + assert len(png_files) == 1 + + # Save to memory only + img_cutouts = ASDFCutout(test_images[0], center_coord, cutout_size).get_image_cutouts() + assert len(img_cutouts) == 1 + assert isinstance(img_cutouts[0], Image.Image) + assert np.array(img_cutouts[0]).shape == (10, 10) + + # Color image + color_jpg = ASDFCutout(test_images, center_coord, cutout_size).write_as_img(output_dir=tmpdir, colorize=True) + img = Image.open(color_jpg) + assert img.mode == 'RGB' + + +def test_get_center_pixel(fakedata): + """ Test get_center_pixel function """ + # get the fake data + __, gwcs = fakedata + + 
pixel_coordinates, wcs = ASDFCutout.get_center_pixel(gwcs, 30, 45) + assert np.allclose(pixel_coordinates, (np.array(500), np.array(500))) + assert np.allclose(wcs.celestial.wcs.crval, np.array([30, 45])) + + +def test_asdf_cut(test_images, center_coord, cutout_size, tmpdir): + """ Test convenience function to create ASDF cutouts """ + def check_paths(cutout_paths, ext): + assert isinstance(cutout_paths, list) + assert isinstance(cutout_paths[0], str) + assert len(cutout_paths) == 3 + for i, path in enumerate(cutout_paths): + assert isinstance(path, str) + assert path.endswith(ext) + assert Path(path).exists() + assert str(tmpdir) in path + assert Path(test_images[i]).stem in path + assert center_coord.ra.to_string(unit='deg', decimal=True) in path + assert center_coord.dec.to_string(unit='deg', decimal=True) in path + assert '10-x-10' in path + + # Write files to disk as ASDF files + asdf_paths = asdf_cut(test_images, center_coord.ra.deg, center_coord.dec.deg, cutout_size, output_dir=tmpdir) + check_paths(asdf_paths, '.asdf') + + # Write files to disk as FITS files + fits_paths = asdf_cut(test_images, center_coord.ra.deg, center_coord.dec.deg, cutout_size, output_dir=tmpdir, + output_format='fits') + check_paths(fits_paths, '.fits') + + # Write cutouts to memory as Cutout2D objects + cutouts = asdf_cut(test_images, center_coord.ra.deg, center_coord.dec.deg, cutout_size, write_file=False) + assert isinstance(cutouts, list) + assert isinstance(cutouts[0], Cutout2D) + assert len(cutouts) == 3 + + # Error if output format is not supported + with pytest.raises(InvalidInputError, match='Output format .invalid is not recognized.'): + asdf_cut(test_images, center_coord.ra.deg, center_coord.dec.deg, cutout_size, output_format='invalid') diff --git a/astrocut/tests/test_FITSCutout.py b/astrocut/tests/test_FITSCutout.py index 145b44f3..ff25dbc1 100644 --- a/astrocut/tests/test_FITSCutout.py +++ b/astrocut/tests/test_FITSCutout.py @@ -1,3 +1,4 @@ +from pathlib import Path 
import pytest import numpy as np @@ -45,52 +46,49 @@ def cutout_size(): def test_fits_cutout_single_outfile(test_images, center_coord, cutout_size, tmpdir): # Create cutout with single output file - cutout = FITSCutout(test_images, center_coord, cutout_size, single_outfile=True, output_dir=tmpdir).cutout() + cutouts = FITSCutout(test_images, center_coord, cutout_size, single_outfile=True).fits_cutouts - # Should output a single string - assert isinstance(cutout, str) - - # Open output file - cutout_hdulist = fits.open(cutout) - assert len(cutout_hdulist) == len(test_images) + 1 # num imgs + primary header + # Should output a list of objects + assert isinstance(cutouts, list) + assert isinstance(cutouts[0], fits.HDUList) + assert len(cutouts) == 1 # Check shape of data - cut1 = cutout_hdulist[1].data - assert cut1.shape == (cutout_size, cutout_size) - assert cutout_hdulist[1].data.shape == cutout_hdulist[2].data.shape - assert cutout_hdulist[2].data.shape == cutout_hdulist[3].data.shape - assert cutout_hdulist[3].data.shape == cutout_hdulist[4].data.shape - assert cutout_hdulist[4].data.shape == cutout_hdulist[5].data.shape - assert cutout_hdulist[5].data.shape == cutout_hdulist[6].data.shape + cutout = cutouts[0] + assert len(cutout) == len(test_images) + 1 # num imgs + primary header + assert cutout[1].data.shape == (cutout_size, cutout_size) + assert cutout[1].data.shape == cutout[2].data.shape + assert cutout[2].data.shape == cutout[3].data.shape + assert cutout[3].data.shape == cutout[4].data.shape + assert cutout[4].data.shape == cutout[5].data.shape + assert cutout[5].data.shape == cutout[6].data.shape # Check that data is equal between cutout and original image for i, img in enumerate(test_images): with fits.open(img) as test_hdu: - assert np.all(cutout_hdulist[i + 1].data == test_hdu[0].data[19:29, 19:29]) + assert np.all(cutout[i + 1].data == test_hdu[0].data[19:29, 19:29]) # Check WCS and position of center - cut_wcs = wcs.WCS(cutout_hdulist[1].header) + 
cut_wcs = wcs.WCS(cutout[1].header) sra, sdec = cut_wcs.all_pix2world(cutout_size/2, cutout_size/2, 0) assert round(float(sra), 4) == round(center_coord.ra.deg, 4) assert round(float(sdec), 4) == round(center_coord.dec.deg, 4) - cutout_hdulist.close() - def test_fits_cutout_multiple_files(tmpdir, test_images, center_coord, cutout_size): # Output is multiple files - cutout_files = FITSCutout(test_images, center_coord, cutout_size, single_outfile=False, output_dir=tmpdir).cutout() + cutouts = FITSCutout(test_images, center_coord, cutout_size, single_outfile=False).fits_cutouts - # Output is a list with paths to each file - assert isinstance(cutout_files, list) - assert len(cutout_files) == len(test_images) + # Should output a list of objects + assert isinstance(cutouts, list) + assert isinstance(cutouts[0], fits.HDUList) + assert len(cutouts) == len(test_images) - for i, cutout in enumerate(cutout_files): - cutout_hdulist = fits.open(cutout) - cut1 = cutout_hdulist[1].data + for i, cutout in enumerate(cutouts): + cut1 = cutout[1].data # Check shape of data - assert len(cutout_hdulist) == 2 # primary header + 1 image + assert len(cutout) == 2 # primary header + 1 image assert cut1.shape == (cutout_size, cutout_size) # Check that data is equal between cutout and original image @@ -98,116 +96,122 @@ def test_fits_cutout_multiple_files(tmpdir, test_images, center_coord, cutout_si assert np.all(cut1 == test_hdu[0].data[19:29, 19:29]) # Check WCS and position of center - cut_wcs = wcs.WCS(cutout_hdulist[1].header) + cut_wcs = wcs.WCS(cutout[1].header) sra, sdec = cut_wcs.all_pix2world(cutout_size/2, cutout_size/2, 0) assert round(float(sra), 4) == round(center_coord.ra.deg, 4) assert round(float(sdec), 4) == round(center_coord.dec.deg, 4) - cutout_hdulist.close() - # Test case where output directory does not exist new_dir = path.join(tmpdir, 'cutout_files') # non-existing directory to write files to - cutout_files = FITSCutout(test_images, center_coord, cutout_size, - 
output_dir=new_dir, single_outfile=False).cutout() + cutouts = FITSCutout(test_images[0], center_coord, cutout_size, single_outfile=False) + paths = cutouts.write_as_fits(output_dir=new_dir) - assert isinstance(cutout_files, list) - assert len(cutout_files) == len(test_images) - assert new_dir in cutout_files[0] + assert isinstance(paths, list) + assert isinstance(paths[0], str) + assert new_dir in paths[0] assert path.exists(new_dir) # new directory should now exist def test_fits_cutout_memory_only(test_images, center_coord, cutout_size): # Memory only, single file nonexisting_dir = 'nonexisting' # non-existing directory to check that no files are written - cutout_list = FITSCutout(test_images, center_coord, cutout_size, output_dir=nonexisting_dir, - single_outfile=True, memory_only=True).cutout() - cutout_hdu = cutout_list[0] + cutout_list = FITSCutout(test_images, center_coord, cutout_size, single_outfile=True).fits_cutouts assert isinstance(cutout_list, list) assert len(cutout_list) == 1 - assert isinstance(cutout_hdu, fits.HDUList) + assert isinstance(cutout_list[0], fits.HDUList) assert not path.exists(nonexisting_dir) # no files should be written # Memory only, multiple files - cutout_list = FITSCutout(test_images, center_coord, cutout_size, output_dir=nonexisting_dir, - single_outfile=False, memory_only=True).cutout() + cutout_list = FITSCutout(test_images, center_coord, cutout_size, single_outfile=False).fits_cutouts assert isinstance(cutout_list, list) assert len(cutout_list) == len(test_images) assert isinstance(cutout_list[0], fits.HDUList) assert not path.exists(nonexisting_dir) # no files should be written -def test_fits_cutout_off_edge(tmpdir, test_images, cutout_size): +def test_fits_cutout_return_paths(test_images, center_coord, cutout_size, tmpdir): + # Return filepath for single output file + cutout_file = FITSCutout(test_images, center_coord, cutout_size).write_as_fits(output_dir=tmpdir, + cutout_prefix='prefix')[0] + assert 
isinstance(cutout_file, str) + assert path.exists(cutout_file) + assert str(tmpdir) in cutout_file + assert 'prefix' in cutout_file + assert center_coord.ra.to_string(unit='deg', decimal=True) in cutout_file + assert center_coord.dec.to_string(unit='deg', decimal=True) in cutout_file + assert '10-x-10' in cutout_file + + # Return list of filepaths for multiple output files + cutout_files = FITSCutout(test_images, center_coord, cutout_size, + single_outfile=False).write_as_fits(output_dir=tmpdir) + assert isinstance(cutout_files, list) + assert len(cutout_files) == len(test_images) + for i, cutout_file in enumerate(cutout_files): + assert path.exists(cutout_file) + assert str(tmpdir) in cutout_file + assert Path(test_images[i]).stem in cutout_file + + +def test_fits_cutout_off_edge(test_images, cutout_size): # Off the top center_coord = SkyCoord("150.1163213 2.2005731", unit='deg') - cutout_file = FITSCutout(test_images, center_coord, cutout_size, single_outfile=True, output_dir=tmpdir).cutout() - assert isinstance(cutout_file, str) + cutout = FITSCutout(test_images, center_coord, cutout_size, single_outfile=True).fits_cutouts[0] + assert isinstance(cutout, fits.HDUList) - cutout_hdulist = fits.open(cutout_file) - assert len(cutout_hdulist) == len(test_images) + 1 # num imgs + primary header + assert len(cutout) == len(test_images) + 1 # num imgs + primary header - cut1 = cutout_hdulist[1].data + cut1 = cutout[1].data assert cut1.shape == (cutout_size, cutout_size) assert np.isnan(cut1[:cutout_size//2, :]).all() - cutout_hdulist.close() - # Off the bottom center_coord = SkyCoord("150.1163213 2.2014", unit='deg') - cutout = FITSCutout(test_images[0], center_coord, cutout_size, memory_only=True).cutout() - assert np.isnan(cutout[0][1].data[cutout_size//2:, :]).all() + cutout = FITSCutout(test_images[0], center_coord, cutout_size).fits_cutouts[0] + assert np.isnan(cutout[1].data[cutout_size//2:, :]).all() # Off the left, with integer fill value center_coord = 
SkyCoord('150.11672 2.200973097', unit='deg') - cutout = FITSCutout(test_images[0], center_coord, cutout_size, memory_only=True, fill_value=1).cutout() - assert np.all(cutout[0][1].data[:, :cutout_size//2] == 1) + cutout = FITSCutout(test_images[0], center_coord, cutout_size, fill_value=1).fits_cutouts[0] + assert np.all(cutout[1].data[:, :cutout_size//2] == 1) # Off the right, with float fill value center_coord = SkyCoord('150.11588 2.200973097', unit='deg') - cutout = FITSCutout(test_images[0], center_coord, cutout_size, memory_only=True, fill_value=1.5).cutout() - assert np.all(cutout[0][1].data[:, cutout_size//2:] == 1.5) + cutout = FITSCutout(test_images[0], center_coord, cutout_size, fill_value=1.5).fits_cutouts[0] + assert np.all(cutout[1].data[:, cutout_size//2:] == 1.5) # Error if unexpected fill value with pytest.raises(InvalidInputError, match='Fill value must be an integer or a float.'): - FITSCutout(test_images[0], center_coord, cutout_size, memory_only=True, fill_value='invalid').cutout() + FITSCutout(test_images[0], center_coord, cutout_size, fill_value='invalid') -def test_fits_cutout_cloud(tmpdir): +def test_fits_cutout_cloud(): # Test single cloud image test_s3_uri = "s3://stpubdata/hst/public/j8pu/j8pu0y010/j8pu0y010_drc.fits" center_coord = SkyCoord("150.4275416667 2.42155", unit='deg') cutout_size = [10, 15] - cutout_file = FITSCutout(test_s3_uri, center_coord, cutout_size, output_dir=tmpdir).cutout() - assert isinstance(cutout_file, str) - assert "10-x-15" in cutout_file - - with fits.open(cutout_file) as cutout_hdulist: - assert cutout_hdulist[1].data.shape == (15, 10) + cutout = FITSCutout(test_s3_uri, center_coord, cutout_size).fits_cutouts[0] + assert cutout[1].data.shape == (15, 10) def test_fits_cutout_rounding(test_images, cutout_size): # Rounding normally center_coord = SkyCoord("150.1163117 2.200973097", unit='deg') - cutout_list = FITSCutout(test_images[0], center_coord, cutout_size, memory_only=True).cutout() - test_hdu = 
fits.open(test_images[0]) - assert np.all(cutout_list[0][1].data == test_hdu[0].data[19:29, 20:30]) + cutout = FITSCutout(test_images[0], center_coord, cutout_size).fits_cutouts[0] + with fits.open(test_images[0]) as test_hdu: + assert np.all(cutout[1].data == test_hdu[0].data[19:29, 20:30]) - # Rounding to ceiling - cutout_list = FITSCutout(test_images[0], center_coord, cutout_size, memory_only=True, - limit_rounding_method='ceil').cutout() - assert np.all(cutout_list[0][1].data == test_hdu[0].data[20:30, 20:30]) + # Rounding to ceiling + cutout = FITSCutout(test_images[0], center_coord, cutout_size, limit_rounding_method='ceil').fits_cutouts[0] + assert np.all(cutout[1].data == test_hdu[0].data[20:30, 20:30]) - # Rounding to floor - cutout_list = FITSCutout(test_images[0], center_coord, cutout_size, memory_only=True, - limit_rounding_method='floor').cutout() - assert np.all(cutout_list[0][1].data == test_hdu[0].data[19:29, 19:29]) + # Rounding to floor + cutout = FITSCutout(test_images[0], center_coord, cutout_size, limit_rounding_method='floor').fits_cutouts[0] + assert np.all(cutout[1].data == test_hdu[0].data[19:29, 19:29]) - # Case that the cutout rounds to zero - cutout_size = 0.57557495 - cutout_list = FITSCutout(test_images[0], center_coord, cutout_size, memory_only=True, - limit_rounding_method='round').cutout() - assert np.all(cutout_list[0][1].data == test_hdu[0].data[24:25, 24:25]) - - test_hdu.close() + # Case that the cutout rounds to zero + cutout_size = 0.57557495 + cutout = FITSCutout(test_images[0], center_coord, cutout_size, limit_rounding_method='round').fits_cutouts[0] + assert np.all(cutout[1].data == test_hdu[0].data[24:25, 24:25]) def test_fits_cutout_extension(test_images, center_coord, cutout_size): @@ -221,20 +225,20 @@ def test_fits_cutout_extension(test_images, center_coord, cutout_size): hdul.flush() # save changes # Cutout all extensions - cutout_list = FITSCutout(test_images[0], center_coord, cutout_size, memory_only=True, 
extension='all').cutout() + cutout_list = FITSCutout(test_images[0], center_coord, cutout_size, extension='all').fits_cutouts assert len(cutout_list[0]) == 4 # primary header + 3 images # Specify a single extension - cutout_list = FITSCutout(test_images[0], center_coord, cutout_size, memory_only=True, extension=2).cutout() + cutout_list = FITSCutout(test_images[0], center_coord, cutout_size, extension=2).fits_cutouts assert len(cutout_list[0]) == 2 # primary header + 1 image # Specify a list of extensions - cutout_list = FITSCutout(test_images[0], center_coord, cutout_size, memory_only=True, extension=[0, 1]).cutout() + cutout_list = FITSCutout(test_images[0], center_coord, cutout_size, extension=[0, 1]).fits_cutouts assert len(cutout_list[0]) == 3 # primary header + 2 images # Warning if a non-existing extension is specified with pytest.warns(DataWarning, match=r'extension\(s\) 3 will be skipped.'): - cutout_list = FITSCutout(test_images[0], center_coord, cutout_size, memory_only=True, extension=[1, 3]).cutout() + cutout_list = FITSCutout(test_images[0], center_coord, cutout_size, extension=[1, 3]).fits_cutouts assert len(cutout_list[0]) == 2 # primary header + 1 image # Remove image data from one of the input files @@ -246,85 +250,84 @@ def test_fits_cutout_extension(test_images, center_coord, cutout_size): hdul.append(table_hdu) hdul.flush() - with pytest.raises(InvalidQueryError, match='Cutout contains no data!'): - FITSCutout(test_images[1], center_coord, cutout_size, memory_only=True).cutout() + with pytest.warns(DataWarning, match='No image extensions with data found.'): + with pytest.raises(InvalidInputError, match='Cutout contains no data!'): + FITSCutout(test_images[1], center_coord, cutout_size) def test_fits_cutout_not_in_footprint(test_images, cutout_size): # Test when the requested cutout is not on the image center_coord = SkyCoord("140.1163213 2.2005731", unit='deg') with pytest.raises(InvalidQueryError, match='Cutout location is not in image 
footprint!'): - FITSCutout(test_images, center_coord, cutout_size, single_outfile=True).cutout() + FITSCutout(test_images, center_coord, cutout_size, single_outfile=True) center_coord = SkyCoord("15.1163213 2.2005731", unit='deg') with pytest.raises(InvalidQueryError, match='Cutout location is not in image footprint!'): - FITSCutout(test_images, center_coord, cutout_size, single_outfile=True).cutout() + FITSCutout(test_images, center_coord, cutout_size, single_outfile=True) def test_fits_cutout_no_data(tmpdir, test_images, cutout_size): # Test behavior when some input files contain zeros in cutout footprint # Putting zeros into 2 images for img in test_images[:2]: - hdu = fits.open(img, mode="update") - hdu[0].data[:20, :] = 0 - hdu.flush() - hdu.close() + with fits.open(img, mode="update") as hdu: + hdu[0].data[:20, :] = 0 + hdu.flush() # Single outfile should include empty files as extensions center_coord = SkyCoord("150.1163213 2.2007", unit='deg') - cutout_file = FITSCutout(test_images, center_coord, cutout_size, single_outfile=True, output_dir=tmpdir).cutout() - cutout_hdulist = fits.open(cutout_file) - assert len(cutout_hdulist) == len(test_images) + 1 # num imgs + primary header - assert (cutout_hdulist[1].data == 0).all() - assert (cutout_hdulist[2].data == 0).all() - assert ~(cutout_hdulist[3].data == 0).any() - assert ~(cutout_hdulist[4].data == 0).any() - assert ~(cutout_hdulist[5].data == 0).any() - assert ~(cutout_hdulist[6].data == 0).any() + with pytest.warns(DataWarning, match='contains no data, skipping...'): + cutout = FITSCutout(test_images, center_coord, cutout_size, single_outfile=True).fits_cutouts[0] + assert len(cutout) == len(test_images) - 1 # 6 images - 2 empty + 1 primary header + assert ~(cutout[1].data == 0).any() + assert ~(cutout[2].data == 0).any() + assert ~(cutout[3].data == 0).any() + assert ~(cutout[4].data == 0).any() # Empty files should not be written to their own file - cutout_files = FITSCutout(test_images, center_coord, 
cutout_size, single_outfile=False, output_dir=tmpdir).cutout() + with pytest.warns(DataWarning, match='contains no data, skipping...'): + cutout_files = FITSCutout(test_images, center_coord, cutout_size, + single_outfile=False).write_as_fits(output_dir=tmpdir) assert isinstance(cutout_files, list) assert len(cutout_files) == len(test_images) - 2 # Test when all input files contain only zeros in cutout footprint # Putting zeros into the rest of the images for img in test_images[2:]: - hdu = fits.open(img, mode="update") - hdu[0].data[:20, :] = 0 - hdu.flush() - hdu.close() + with fits.open(img, mode="update") as hdu: + hdu[0].data[:20, :] = 0 + hdu.flush() - with pytest.raises(InvalidQueryError) as e: - cutout_file = FITSCutout(test_images, center_coord, cutout_size, single_outfile=True, - output_dir=tmpdir).cutout() - assert 'Cutout contains no data! (Check image footprint.)' in e + with pytest.warns(DataWarning, match='contains no data, skipping...'): + with pytest.raises(InvalidInputError, match='Cutout contains no data!'): + FITSCutout(test_images, center_coord, cutout_size, single_outfile=True) def test_fits_cutout_bad_sip(tmpdir, caplog, test_image_bad_sip): # Test single image and also conflicting sip keywords center_coord = SkyCoord("150.1163213 2.2007", unit='deg') cutout_size = [10, 15] - cutout_file = FITSCutout(test_image_bad_sip, center_coord, cutout_size, output_dir=tmpdir).cutout() + cutout_file = FITSCutout(test_image_bad_sip, center_coord, cutout_size).write_as_fits(output_dir=tmpdir)[0] assert isinstance(cutout_file, str) assert "10-x-15" in cutout_file - cutout_hdulist = fits.open(cutout_file) - assert cutout_hdulist[1].data.shape == (15, 10) + with fits.open(cutout_file) as cutout_hdulist: + assert cutout_hdulist[1].data.shape == (15, 10) center_coord = SkyCoord("150.1159 2.2006", unit='deg') cutout_size = [10, 15]*u.pixel - cutout_file = FITSCutout(test_image_bad_sip, center_coord, cutout_size, output_dir=tmpdir).cutout() + cutout_file = 
FITSCutout(test_image_bad_sip, center_coord, cutout_size).write_as_fits(output_dir=tmpdir)[0] assert isinstance(cutout_file, str) assert "10.0pix-x-15.0pix" in cutout_file - cutout_hdulist = fits.open(cutout_file) - assert cutout_hdulist[1].data.shape == (15, 10) + with fits.open(cutout_file) as cutout_hdulist: + assert cutout_hdulist[1].data.shape == (15, 10) cutout_size = [1, 2]*u.arcsec - cutout_file = FITSCutout(test_image_bad_sip, center_coord, cutout_size, output_dir=tmpdir, verbose=True).cutout() + cutout_file = FITSCutout(test_image_bad_sip, center_coord, cutout_size, + verbose=True).write_as_fits(output_dir=tmpdir)[0] assert isinstance(cutout_file, str) assert "1.0arcsec-x-2.0arcsec" in cutout_file - cutout_hdulist = fits.open(cutout_file) - assert cutout_hdulist[1].data.shape == (33, 17) + with fits.open(cutout_file) as cutout_hdulist: + assert cutout_hdulist[1].data.shape == (33, 17) captured = caplog.text assert "Original image shape: (50, 50)" in captured assert "Image cutout shape: (33, 17)" in captured @@ -333,69 +336,63 @@ def test_fits_cutout_bad_sip(tmpdir, caplog, test_image_bad_sip): center_coord = "150.1159 2.2006" cutout_size = [10, 15, 20] with pytest.warns(InputWarning, match='Too many dimensions in cutout size, only the first two will be used.'): - cutout_file = FITSCutout(test_image_bad_sip, center_coord, cutout_size, output_dir=tmpdir).cutout() + cutout_file = FITSCutout(test_image_bad_sip, center_coord, cutout_size).write_as_fits(output_dir=tmpdir)[0] assert isinstance(cutout_file, str) assert "10-x-15" in cutout_file assert "x-20" not in cutout_file -def test_fits_cutout_invalid_params(test_images, center_coord, cutout_size): - # Warning when image options are given - with pytest.warns(InputWarning, match='are not supported for FITS output and will be ignored.'): - FITSCutout(test_images, center_coord, cutout_size, stretch='asinh').cutout() - +def test_fits_cutout_invalid_params(tmpdir, test_images, center_coord, cutout_size): # Invalid 
limit rounding method with pytest.raises(InvalidInputError, match='Limit rounding method invalid is not recognized.'): - FITSCutout(test_images, center_coord, cutout_size, limit_rounding_method='invalid').cutout() + FITSCutout(test_images, center_coord, cutout_size, limit_rounding_method='invalid') # Invalid units for cutout size cutout_size = 1 * u.m # meters are not valid with pytest.raises(InvalidInputError, match='Cutout size unit meter is not supported.'): - FITSCutout(test_images, center_coord, cutout_size).cutout() + FITSCutout(test_images, center_coord, cutout_size) def test_fits_cutout_img_output(tmpdir, test_images, caplog, center_coord, cutout_size): # Basic jpg image - jpg_files = FITSCutout(test_images, center_coord, cutout_size, output_dir=tmpdir, output_format='jpg').cutout() - + jpg_files = FITSCutout(test_images, center_coord, cutout_size).write_as_img(output_format='jpg', output_dir=tmpdir) assert len(jpg_files) == len(test_images) with open(jpg_files[0], 'rb') as IMGFLE: assert IMGFLE.read(3) == b'\xFF\xD8\xFF' # JPG # Png (single input file, not as list) - img_files = FITSCutout(test_images[0], center_coord, cutout_size, output_format='png', output_dir=tmpdir).cutout() + img_files = FITSCutout(test_images[0], center_coord, cutout_size).write_as_img(output_format='png', + output_dir=tmpdir) with open(img_files[0], 'rb') as IMGFLE: assert IMGFLE.read(8) == b'\x89\x50\x4E\x47\x0D\x0A\x1A\x0A' # PNG assert len(img_files) == 1 - # string coordinates and verbose + # String coordinates and verbose center_coord = "150.1163213 2.200973097" - jpg_files = FITSCutout(test_images, center_coord, cutout_size, output_format='jpg', - output_dir=path.join(tmpdir, 'image_path'), verbose=True).cutout() + jpg_files = FITSCutout(test_images, center_coord, cutout_size, verbose=True).write_as_img(output_format='jpg', + output_dir=tmpdir) captured = caplog.text assert len(findall('Original image shape', captured)) == 6 - assert 'Cutout fits file(s)' in captured assert 
'Total time' in captured def test_fits_cutout_img_color(tmpdir, test_images, center_coord, cutout_size): # Color image - color_jpg = FITSCutout(test_images[:3], center_coord, cutout_size, output_format='jpg', colorize=True, - output_dir=tmpdir).cutout() + color_jpg = FITSCutout(test_images[:3], center_coord, cutout_size).write_as_img(output_format='jpg', colorize=True, + output_dir=tmpdir) img = Image.open(color_jpg) assert img.mode == 'RGB' def test_fits_cutout_img_memory_only(test_images, center_coord, cutout_size): # Save black and white image to memory - imgs = FITSCutout(test_images[0], center_coord, cutout_size, output_format='png', memory_only=True).cutout() + imgs = FITSCutout(test_images[0], center_coord, cutout_size).image_cutouts assert isinstance(imgs, list) assert len(imgs) == 1 assert isinstance(imgs[0], Image.Image) # Save color image to memory - color_imgs = FITSCutout(test_images[:3], center_coord, cutout_size, output_format='jpg', colorize=True, - memory_only=True).cutout() + color_imgs = FITSCutout(test_images[:3], center_coord, cutout_size).get_image_cutouts(colorize=True) assert isinstance(color_imgs, list) assert len(color_imgs) == 1 assert isinstance(color_imgs[0], Image.Image) @@ -405,50 +402,43 @@ def test_fits_cutout_img_memory_only(test_images, center_coord, cutout_size): def test_fits_cutout_img_errors(tmpdir, test_images, center_coord, cutout_size): # Error when too few input images with pytest.raises(InvalidInputError): - FITSCutout(test_images[0], center_coord, cutout_size, output_format='jpg', colorize=True, - output_dir=tmpdir).cutout() + FITSCutout(test_images[0], center_coord, cutout_size).get_image_cutouts(colorize=True) # Warning when too many input images with pytest.warns(InputWarning, match='Too many inputs for a color cutout, only the first three will be used.'): - color_jpg = FITSCutout(test_images, center_coord, cutout_size, output_format='jpg', colorize=True, - output_dir=tmpdir).cutout() + color_jpg = 
FITSCutout(test_images, center_coord, cutout_size).write_as_img(colorize=True, output_dir=tmpdir) img = Image.open(color_jpg) assert img.mode == 'RGB' # Warning when saving image to unsupported image formats with pytest.warns(DataWarning, match='Cutout could not be saved in .blp format'): - FITSCutout(test_images[0], center_coord, cutout_size, output_format='blp', output_dir=tmpdir).cutout() + FITSCutout(test_images[0], center_coord, cutout_size).write_as_img(output_format='blp', output_dir=tmpdir) with pytest.warns(DataWarning, match='Cutout could not be saved in .mpg format'): - FITSCutout(test_images, center_coord, cutout_size, output_format='mpg', colorize=True, - output_dir=tmpdir).cutout() + FITSCutout(test_images[:3], center_coord, cutout_size).write_as_img(output_format='mpg', output_dir=tmpdir, + colorize=True) # Invalid stretch error with pytest.raises(InvalidInputError, match='Stretch invalid is not recognized.'): - FITSCutout(test_images[0], center_coord, cutout_size, stretch='invalid', output_format='png', - output_dir=tmpdir).cutout() + FITSCutout(test_images[0], center_coord, cutout_size).write_as_img(stretch='invalid', output_format='png', + output_dir=tmpdir) # Invalid output format with pytest.raises(InvalidInputError, match='Output format .invalid is not supported'): - FITSCutout(test_images[0], center_coord, cutout_size, output_format='invalid', output_dir=tmpdir).cutout() + FITSCutout(test_images[0], center_coord, cutout_size).write_as_img(output_format='invalid', output_dir=tmpdir) # Change first input file to be all zeros - hdu = fits.open(test_images[0], mode='update') - hdu[0].data[:, :] = 0 - hdu.flush() - hdu.close() + with fits.open(test_images[0], mode='update') as hdu: + hdu[0].data[:, :] = 0 + hdu.flush() # Warning when outputting non-color images - with pytest.warns(DataWarning, match='contains no data and will not be written.'): - FITSCutout(test_images[0], center_coord, cutout_size, output_format='png', output_dir=tmpdir).cutout() 
+ with pytest.warns(DataWarning, match='contains no data, skipping...'): + with pytest.raises(InvalidInputError, match='Cutout contains no data'): + FITSCutout(test_images[0], center_coord, cutout_size).write_as_img(output_format='png', output_dir=tmpdir) # Error when outputting color image - with pytest.raises(InvalidInputError): - FITSCutout(test_images[:3], center_coord, cutout_size, - colorize=True, output_format='png', output_dir=tmpdir).cutout() - - -def test_fits_cutout_asdf_output(test_images, center_coord, cutout_size): - # Should warn if output format is ASDF (not yet implemented) - with pytest.warns(InputWarning, match='ASDF output is not yet implemented for FITS files.'): - FITSCutout(test_images[0], center_coord, cutout_size, output_format='asdf').cutout() + with pytest.warns(DataWarning, match='contains no data, skipping...'): + with pytest.raises(InvalidInputError): + FITSCutout(test_images[:3], center_coord, cutout_size).write_as_img(colorize=True, output_format='png', + output_dir=tmpdir) diff --git a/astrocut/tests/test_asdf_cut.py b/astrocut/tests/test_asdf_cut.py deleted file mode 100644 index 5a4b70d6..00000000 --- a/astrocut/tests/test_asdf_cut.py +++ /dev/null @@ -1,351 +0,0 @@ - -import pathlib -from unittest.mock import MagicMock, patch -import numpy as np -import pytest - -import asdf -from astropy.modeling import models -from astropy import coordinates as coord -from astropy import units as u -from astropy.io import fits -from astropy.wcs import WCS -from astropy.wcs.utils import pixel_to_skycoord -from gwcs import wcs -from gwcs import coordinate_frames as cf -from s3path import S3Path -from astrocut.asdf_cutouts import get_center_pixel, asdf_cut, _get_cutout, _slice_gwcs, _get_cloud_http - - -def make_wcs(xsize, ysize, ra=30., dec=45.): - """ create a fake gwcs object """ - # todo - refine this to better reflect roman wcs - - # create transformations - # - shift coords so array center is at 0, 0 ; reference pixel - # - scale pixels to 
correct angular scale - # - project coords onto sky with TAN projection - # - transform center pixel to the input celestial coordinate - pixelshift = models.Shift(-xsize) & models.Shift(-ysize) - pixelscale = models.Scale(0.1 / 3600.) & models.Scale(0.1 / 3600.) # 0.1 arcsec/pixel - tangent_projection = models.Pix2Sky_TAN() - celestial_rotation = models.RotateNative2Celestial(ra, dec, 180.) - - # net transforms pixels to sky - det2sky = pixelshift | pixelscale | tangent_projection | celestial_rotation - - # define the wcs object - detector_frame = cf.Frame2D(name="detector", axes_names=("x", "y"), unit=(u.pix, u.pix)) - sky_frame = cf.CelestialFrame(reference_frame=coord.ICRS(), name='world', unit=(u.deg, u.deg)) - return wcs.WCS([(detector_frame, det2sky), (sky_frame, None)]) - - -@pytest.fixture() -def makefake(): - """ fixture factory to make a fake gwcs and dataset """ - - def _make_fake(nx, ny, ra, dec, zero=False, asint=False): - # create the wcs - wcsobj = make_wcs(nx/2, ny/2, ra=ra, dec=dec) - wcsobj.bounding_box = ((0, nx), (0, ny)) - - # create the data - if zero: - data = np.zeros([nx, ny]) - else: - size = nx * ny - data = np.arange(size).reshape(nx, ny) - - # make a quantity - data *= (u.electron / u.second) - - # make integer array - if asint: - data = data.astype(int) - - return data, wcsobj - - yield _make_fake - - -@pytest.fixture() -def fakedata(makefake): - """ fixture to create fake data and wcs """ - # set up initial parameters - nx = 1000 - ny = 1000 - ra = 30. - dec = 45. 
- - yield makefake(nx, ny, ra, dec) - - -@pytest.fixture() -def make_file(tmp_path, fakedata): - """ fixture to create a fake dataset """ - # get the fake data - data, wcsobj = fakedata - - # create meta - meta = {'wcs': wcsobj} - - # create and write the asdf file - tree = {'roman': {'data': data, 'meta': meta}} - af = asdf.AsdfFile(tree) - - path = tmp_path / "roman" - path.mkdir(exist_ok=True) - filename = path / "test_roman.asdf" - af.write_to(filename) - - yield filename - - -@pytest.fixture() -def output(tmp_path): - """ fixture to create the output path """ - def _output_file(ext='fits'): - # create output fits path - out = tmp_path / "roman" - out.mkdir(exist_ok=True, parents=True) - output_file = out / f"test_output_cutout.{ext}" if ext else out / "test_output_cutout" - return output_file - yield _output_file - - -def test_get_center_pixel(fakedata): - """ test we can get the correct center pixel """ - # get the fake data - __, gwcs = fakedata - - pixel_coordinates, wcs = get_center_pixel(gwcs, 30., 45.) 
- assert np.allclose(pixel_coordinates, (np.array(500.), np.array(500.))) - assert np.allclose(wcs.celestial.wcs.crval, np.array([30., 45.])) - - -@pytest.mark.parametrize('quantity', [True, False], ids=['quantity', 'array']) -def test_get_cutout(output, fakedata, quantity): - """ test we can create a cutout """ - output_file = output('fits') - - # get the input wcs - data, gwcs = fakedata - skycoord = gwcs(25, 25, with_units=True) - wcs = WCS(gwcs.to_fits_sip()) - - # convert quanity data back to array - if not quantity: - data = data.value - - # create cutout - cutout = _get_cutout(data, skycoord, wcs, size=10, outfile=output_file) - - assert_same_coord(5, 10, cutout, wcs) - - # test output - with fits.open(output_file) as hdulist: - data = hdulist[0].data - assert data.shape == (10, 10) - assert data[5, 5] == 25025 - - -def test_asdf_cutout(make_file, output): - """ test we can make a cutout """ - output_file = output('fits') - # make cutout - ra, dec = (29.99901792, 44.99930555) - asdf_cut(make_file, ra, dec, cutout_size=10, output_file=output_file) - - # test output - with fits.open(output_file) as hdulist: - data = hdulist[0].data - assert data.shape == (10, 10) - assert data[5, 5] == 475476 - - -@pytest.mark.parametrize('suffix', ['fits', 'asdf', None]) -def test_write_file(make_file, suffix, output): - """ test we can write an different file types """ - output_file = output(suffix) - - # make cutout - ra, dec = (29.99901792, 44.99930555) - asdf_cut(make_file, ra, dec, cutout_size=10, output_file=output_file) - - # if no suffix provided, check that the default output is fits - if not suffix: - output_file = output_file.with_suffix('.fits') - - assert pathlib.Path(output_file).exists() - - -def test_fail_write_asdf(fakedata, output): - """ test we fail to write an asdf if no gwcs given """ - with pytest.raises(ValueError, match='The original gwcs object is needed when writing to asdf file.'): - output_file = output('asdf') - data, gwcs = fakedata - skycoord = 
gwcs(25, 25, with_units=True) - wcs = WCS(gwcs.to_fits_sip()) - _get_cutout(data, skycoord, wcs, size=10, outfile=output_file) - - -def test_cutout_nofile(make_file, output): - """ test we can make a cutout with no file output """ - output_file = output() - # make cutout - ra, dec = (29.99901792, 44.99930555) - cutout = asdf_cut(make_file, ra, dec, cutout_size=10, output_file=output_file, write_file=False) - - assert not pathlib.Path(output_file).exists() - assert cutout.shape == (10, 10) - - -def test_cutout_poles(makefake): - """ test we can make cutouts around poles """ - # make fake zero data around the pole - ra, dec = 315.0, 89.995 - data, gwcs = makefake(1000, 1000, ra, dec, zero=True) - - # add some values (5x5 array) - data.value[245:250, 245:250] = 1 - - # check central pixel is correct - ss = gwcs(500, 500) - assert ss == (ra, dec) - - # set input cutout coord - cc = coord.SkyCoord(284.702, 89.986, unit=u.degree) - wcs = WCS(gwcs.to_fits_sip()) - - # get cutout - cutout = _get_cutout(data, cc, wcs, size=50, write_file=False) - assert_same_coord(5, 10, cutout, wcs) - - # check cutout contains all data - assert len(np.where(cutout.data.value == 1)[0]) == 25 - - -def test_fail_cutout_outside(fakedata): - """ test we fail when cutout completely outside range """ - data, gwcs = fakedata - wcs = WCS(gwcs.to_fits_sip()) - cc = coord.SkyCoord(200.0, 50.0, unit=u.degree) - - with pytest.raises(RuntimeError, match='Could not create 2d cutout. 
The requested ' - 'cutout does not overlap with the original image'): - _get_cutout(data, cc, wcs, size=50, write_file=False) - - -def assert_same_coord(x, y, cutout, wcs): - """ assert we get the same sky coordinate from cutout and original wcs """ - cutout_coord = pixel_to_skycoord(x, y, cutout.wcs) - ox, oy = cutout.to_original_position((x, y)) - orig_coord = pixel_to_skycoord(ox, oy, wcs) - assert cutout_coord == orig_coord - - -@pytest.mark.parametrize('asint, fill', [(False, None), (True, -9999)], ids=['fillfloat', 'fillint']) -def test_partial_cutout(makefake, asint, fill): - """ test we get a partial cutout with nans or fill value """ - ra, dec = 30.0, 45.0 - data, gwcs = makefake(100, 100, ra, dec, asint=asint) - - wcs = WCS(gwcs.to_fits_sip()) - cc = coord.SkyCoord(29.999, 44.998, unit=u.degree) - cutout = _get_cutout(data, cc, wcs, size=50, write_file=False, fill_value=fill) - assert cutout.shape == (50, 50) - if asint: - assert -9999 in cutout.data - else: - assert np.isnan(cutout.data).any() - - -def test_bad_fill(makefake): - """ test error is raised on bad fill value """ - ra, dec = 30.0, 45.0 - data, gwcs = makefake(100, 100, ra, dec, asint=True) - wcs = WCS(gwcs.to_fits_sip()) - cc = coord.SkyCoord(29.999, 44.998, unit=u.degree) - with pytest.raises(ValueError, match='fill_value is inconsistent with the data type of the input array'): - _get_cutout(data, cc, wcs, size=50, write_file=False) - - -def test_cutout_raedge(makefake): - """ test we can make cutouts around ra=0 """ - # make fake zero data around the ra edge - ra, dec = 0.0, 10.0 - data, gg = makefake(2000, 2000, ra, dec, zero=True) - - # check central pixel is correct - ss = gg(1001, 1001) - assert pytest.approx(ss, abs=1e-3) == (ra, dec) - - # set input cutout coord - cc = coord.SkyCoord(0.001, 9.999, unit=u.degree) - wcs = WCS(gg.to_fits_sip()) - - # get cutout - cutout = _get_cutout(data, cc, wcs, size=100, write_file=False) - assert_same_coord(5, 10, cutout, wcs) - - # assert the RA 
cutout bounds are > 359 and < 0 - bounds = gg(*cutout.bbox_original, with_units=True) - assert bounds[0].ra.value > 359 - assert bounds[1].ra.value < 0.1 - - -def test_slice_gwcs(fakedata): - """ test we can slice a gwcs object """ - data, gwcsobj = fakedata - skycoord = gwcsobj(250, 250) - wcs = WCS(gwcsobj.to_fits_sip()) - - cutout = _get_cutout(data, skycoord, wcs, size=50, write_file=False) - - sliced = _slice_gwcs(gwcsobj, cutout.slices_original) - - # check coords between slice and original gwcs - assert cutout.center_cutout == (24.5, 24.5) - assert sliced.array_shape == (50, 50) - assert sliced(*cutout.input_position_cutout) == gwcsobj(*cutout.input_position_original) - assert gwcsobj(*cutout.center_original) == sliced(*cutout.center_cutout) - - # assert same sky footprint between slice and original - # gwcs footprint/bounding_box expects ((x0, x1), (y0, y1)) but cutout.bbox is in ((y0, y1), (x0, x1)) - assert (gwcsobj.footprint(bounding_box=tuple(reversed(cutout.bbox_original))) == sliced.footprint()).all() - - -@patch('requests.head') -@patch('s3fs.S3FileSystem') -def test_get_cloud_http(mock_s3fs, mock_requests): - """ test we can get HTTP URI of cloud resource """ - # mock HTTP response - mock_resp = MagicMock() - mock_resp.status_code = 200 # public bucket - mock_requests.return_value = mock_resp - - # mock s3 file system operations - HTTP_URI = "http_test" - mock_fs = mock_s3fs.return_value - mock_file = MagicMock() - mock_file.url.return_value = HTTP_URI - mock_fs.open.return_value.__enter__.return_value = mock_file - - # test function with string input - s3_uri = "s3://test_bucket/test_file.asdf" - http_uri = _get_cloud_http(s3_uri) - assert http_uri == HTTP_URI - mock_s3fs.assert_called_with(anon=True, key=None, secret=None, token=None) - mock_fs.open.assert_called_once_with(s3_uri, 'rb') - mock_file.url.assert_called_once() - - # test function with S3Path input - s3_uri_path = S3Path("/test_bucket/test_file_2.asdf") - http_uri_path = 
_get_cloud_http(s3_uri_path) - assert http_uri_path == HTTP_URI - mock_fs.open.assert_called_with(s3_uri_path, 'rb') - - # test function with private bucket - mock_resp.status_code = 403 - http_uri = _get_cloud_http(s3_uri, key="access") - mock_s3fs.assert_called_with(anon=False, key="access", secret=None, token=None) diff --git a/astrocut/tests/test_cutouts.py b/astrocut/tests/test_cutouts.py index 077e1ca0..f571f525 100644 --- a/astrocut/tests/test_cutouts.py +++ b/astrocut/tests/test_cutouts.py @@ -13,7 +13,7 @@ from .utils_for_test import create_test_imgs from .. import cutouts -from ..exceptions import InputWarning, InvalidInputError, InvalidQueryError +from ..exceptions import DataWarning, InputWarning, InvalidInputError, InvalidQueryError @pytest.mark.parametrize('ffi_type', ['SPOC', 'TICA']) @@ -118,7 +118,6 @@ def test_fits_cut(tmpdir, caplog, ffi_type): # Test when cutout is in some images not others - # Putting zeros into 2 images for img in test_images[:2]: hdu = fits.open(img, mode="update") @@ -126,21 +125,19 @@ def test_fits_cut(tmpdir, caplog, ffi_type): hdu.flush() hdu.close() - center_coord = SkyCoord("150.1163213 2.2007", unit='deg') - cutout_file = cutouts.fits_cut(test_images, center_coord, cutout_size, single_outfile=True, output_dir=tmpdir) + with pytest.warns(DataWarning, match='contains no data, skipping...'): + cutout_file = cutouts.fits_cut(test_images, center_coord, cutout_size, single_outfile=True, output_dir=tmpdir) cutout_hdulist = fits.open(cutout_file) - assert len(cutout_hdulist) == len(test_images) + 1 # num imgs + primary header - assert (cutout_hdulist[1].data == 0).all() - assert (cutout_hdulist[2].data == 0).all() + assert len(cutout_hdulist) == len(test_images) - 1 # 6 images - 2 empty + 1 primary header + assert ~(cutout_hdulist[1].data == 0).any() + assert ~(cutout_hdulist[2].data == 0).any() assert ~(cutout_hdulist[3].data == 0).any() assert ~(cutout_hdulist[4].data == 0).any() - assert ~(cutout_hdulist[5].data == 0).any() 
- assert ~(cutout_hdulist[6].data == 0).any() - - cutout_files = cutouts.fits_cut(test_images, center_coord, cutout_size, single_outfile=False, output_dir=tmpdir) + with pytest.warns(DataWarning, match='contains no data'): + cutout_files = cutouts.fits_cut(test_images, center_coord, cutout_size, single_outfile=False, output_dir=tmpdir) assert isinstance(cutout_files, list) assert len(cutout_files) == len(test_images) - 2 @@ -151,10 +148,10 @@ def test_fits_cut(tmpdir, caplog, ffi_type): hdu.flush() hdu.close() - with pytest.raises(Exception) as e: - cutout_file = cutouts.fits_cut(test_images, center_coord, cutout_size, single_outfile=True, output_dir=tmpdir) - assert e.type is InvalidQueryError - assert "Cutout contains no data! (Check image footprint.)" in str(e.value) + with pytest.warns(DataWarning, match='contains no data, skipping...'): + with pytest.raises(InvalidInputError, match='Cutout contains no data!'): + cutout_file = cutouts.fits_cut(test_images, center_coord, cutout_size, single_outfile=True, + output_dir=tmpdir) # test single image and also conflicting sip keywords test_image = create_test_imgs(ffi_type, 50, 1, dir_name=tmpdir, @@ -315,7 +312,7 @@ def test_img_cut(tmpdir, caplog, ffi_type): output_dir=path.join(tmpdir, "image_path"), verbose=True) captured = caplog.text assert len(findall("Original image shape", captured)) == 6 - assert "Cutout fits file(s)" in captured + assert "Cutout filepaths:" in captured assert "Total time" in captured # test color image where one of the images is all zeros @@ -324,6 +321,7 @@ def test_img_cut(tmpdir, caplog, ffi_type): hdu.flush() hdu.close() - with pytest.raises(InvalidInputError): - cutouts.img_cut(test_images[:3], center_coord, cutout_size, - colorize=True, img_format='png', output_dir=tmpdir) + with pytest.warns(DataWarning, match='contains no data, skipping...'): + with pytest.raises(InvalidInputError): + cutouts.img_cut(test_images[:3], center_coord, cutout_size, + colorize=True, img_format='png', 
output_dir=tmpdir) diff --git a/astrocut/utils/utils.py b/astrocut/utils/utils.py index ec9a506e..5bb0d015 100644 --- a/astrocut/utils/utils.py +++ b/astrocut/utils/utils.py @@ -53,7 +53,7 @@ def parse_size_input(cutout_size): cutout_size = cutout_size[:2] ny, nx = cutout_size - if ny == 0 or nx == 0: + if ny <= 0 or nx <= 0: raise InvalidQueryError('Cutout size dimensions must be greater than zero. ' f'Provided size: ({cutout_size[0]}, {cutout_size[1]})') diff --git a/setup.cfg b/setup.cfg index f1131b7d..b24b06d3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -25,6 +25,7 @@ install_requires = roman_datamodels>=0.17.0 # for roman file support requests>=2.32.3 # for making HTTP requests spherical_geometry>=1.3.0 + gwcs>=0.21.0 scipy Pillow