Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions gwcs/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,13 @@ def world_axis_units(self):
return tuple(unit.to_string(format="vounit") for unit in self.output_frame.unit)

def _remove_quantity_output(self, result, frame):
if self.forward_transform.uses_quantity:
if frame.naxes == 1:
result = [result]
if frame.naxes == 1:
result = [result]

result = tuple(
r.to_value(unit) if isinstance(r, u.Quantity) else r
for r, unit in zip(result, frame.unit, strict=False)
)
result = tuple(
r.to_value(unit) if isinstance(r, u.Quantity) else r
for r, unit in zip(result, frame.unit, strict=False)
)

# If we only have one output axes, we shouldn't return a tuple.
if self.output_frame.naxes == 1 and isinstance(result, tuple):
Expand All @@ -94,7 +93,6 @@ def pixel_to_world_values(self, *pixel_arrays):
is the vertical coordinate.
"""
result = self(*pixel_arrays)

return self._remove_quantity_output(result, self.output_frame)

def array_index_to_world_values(self, *index_arrays):
Expand Down
4 changes: 3 additions & 1 deletion gwcs/coordinate_frames/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ def add_units(self, arrays: u.Quantity | np.ndarray | float) -> tuple[u.Quantity
"""
Add units to the arrays
"""
if self.naxes == 1 and np.isscalar(arrays):
return u.Quantity(arrays, self.unit[0])
return tuple(
u.Quantity(array, unit=unit)
for array, unit in zip(arrays, self.unit, strict=True)
Expand All @@ -139,7 +141,7 @@ def remove_units(
"""
Remove units from the input arrays
"""
if self.naxes == 1:
if self.naxes == 1 and (np.isscalar(arrays) or isinstance(arrays, u.Quantity)):
arrays = (arrays,)

return tuple(
Expand Down
10 changes: 9 additions & 1 deletion gwcs/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def gwcs_spec_cel_time_4d():
wcslin = models.Mapping((1, 0)) | (offx & offy) | aff
tan = models.Pix2Sky_TAN(name="tangent_projection")
n2c = models.RotateNative2Celestial(*crval, 180, name="sky_rotation")
cel_model = wcslin | tan | n2c
cel_model = wcslin | tan | n2c | models.Mapping((1, 0))
icrs = cf.CelestialFrame(
reference_frame=coord.ICRS(), name="sky", axes_order=(2, 1)
)
Expand Down Expand Up @@ -734,3 +734,11 @@ def fits_wcs_imaging_simple(params):
w.wcs.lonpole = 180
w.wcs.set()
return gw, w


def gwcs_2d_spatial_shift_reverse():
"""
A simple one step spatial WCS with forward from sky to detector.
"""
pipe = [(ICRC_SKY_FRAME, MODEL_2D_SHIFT), (DETECTOR_2D_FRAME, None)]
return wcs.WCS(pipe)
5 changes: 5 additions & 0 deletions gwcs/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,8 @@ def gwcs_romanisim():
def fits_wcs_imaging_simple(request):
params = request.param
return examples.fits_wcs_imaging_simple(params)


@pytest.fixture
def gwcs_2d_spatial_shift_reverse():
return examples.gwcs_2d_spatial_shift_reverse()
31 changes: 16 additions & 15 deletions gwcs/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
import pytest
from astropy import coordinates as coord
from astropy import time
from astropy.tests.helper import assert_quantity_allclose
from astropy.wcs.wcsapi import HighLevelWCSWrapper
from astropy.wcs.wcsapi.high_level_api import values_to_high_level_objects
# from gwcs.utils import values_to_high_level_objects
from numpy.testing import assert_allclose, assert_array_equal

import gwcs
Expand Down Expand Up @@ -309,7 +311,13 @@ def test_high_level_wrapper(wcsobj, request):
wc1 = hlvl.pixel_to_world(*pixel_input)
wc2 = wcsobj(*pixel_input)
results = wcsobj._remove_units_input(wc2, wcsobj.output_frame)
wc2 = values_to_high_level_objects(*results, low_level_wcs=wcsobj)

wc2 = values_to_high_level_objects(
*results,
low_level_wcs=wcsobj,
object_classes=wcsobj.world_axis_object_classes,
object_components=wcsobj.world_axis_object_components
)
if len(wc2) == 1:
wc2 = wc2[0]
assert type(wc1) is type(wc2)
Expand All @@ -325,19 +333,11 @@ def test_high_level_wrapper(wcsobj, request):
wc1 = (wc1,)

pix_out1 = hlvl.world_to_pixel(*wc1)
pix_out2 = wcsobj.invert(*wc1)

if not isinstance(pix_out2, list | tuple):
pix_out2 = (pix_out2,)

if wcsobj.forward_transform.uses_quantity:
pix_out2 = tuple(
p.to_value(unit)
for p, unit in zip(pix_out2, wcsobj.input_frame.unit, strict=False)
)

np.testing.assert_allclose(pix_out1, pixel_input)
np.testing.assert_allclose(pix_out2, pixel_input)
result = wcsobj.invert(*wc1)
inpq = [pix * un for pix, un in zip(pixel_input, wcsobj.input_frame.unit)]
assert_quantity_allclose(result, inpq)


def test_stokes_wrapper(gwcs_stokes_lookup):
Expand Down Expand Up @@ -407,7 +407,8 @@ def test_pixel_bounds(wcsobj):

wcsobj.bounding_box = ((-0.5, 2039.5), (-0.5, 1019.5))
assert_array_equal(wcsobj.pixel_bounds, wcsobj.bounding_box)

# Reset the bounding box or this will affect other tests
wcsobj.bounding_box = None

@wcs_objs
def test_axis_correlation_matrix(wcsobj):
Expand Down Expand Up @@ -598,8 +599,8 @@ def test_coordinate_frame_api():
pixel = wcs.world_to_pixel(world)
assert isinstance(pixel, float)

pixel2 = wcs.invert(world)
assert u.allclose(pixel2, 0 * u.pix)
result = wcs.invert(world)
assert_quantity_allclose(result, 0*u.pix)


def test_world_axis_object_components_units(gwcs_3d_identity_units):
Expand Down
241 changes: 241 additions & 0 deletions gwcs/tests/test_api_consistent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
Test the API is consistent with units and quantities and follows the rules below.

WCS functions which considered for this work are part of the legacy API:
wcs(x, y)
wcs.invert(ra, dec)
wcs.forward_transform(x,y), wcs.backward_transform() and wcs.get_transform(f1, f2)
wcs.numerical_inverse(ra, dec) - does not support units

Rules:


1. Neither transforms nor inputs support units -> the output is clearly numerical
for all functions above
2. Transforms support units but inputs do not -> return quantities assuming the
units of the coordinate frame
- This should work for the wcs methods (wcs(x,y) and wcs.invert
- The methods using transforms should follow modeling rules and will require units
on the input and raise an exception if not
3. Both transforms and inputs support units -> return quantities
- Wcs methods return quantities
- Transforms work and return quantities
4. Transforms do not support units but inputs are quantities -> raise an error

"""
import numbers
import numpy as np
from numpy.testing import assert_array_equal, assert_allclose

from astropy import units as u
from astropy.tests.helper import assert_quantity_allclose

import pytest
from .conftest import gwcs_with_pipeline_celestial, gwcs_2d_spatial_shift_reverse

x = 1
y = 2
xq = [1, 1] * u.pix
yq = 2 * u.pix


def is_numerical(args):
if isinstance(args, numbers.Number):
return True
#return all([isinstance(arg, numbers.Number) or type(arg) == np.ndarray for arg in args])
return all([isinstance(arg, numbers.Number) or arg is np.ndarray for arg in args])


def is_quantity(args):
return all([isinstance(arg, u.Quantity) for arg in args])


@pytest.fixture
def wcsobj(request):
return request.getfixturevalue(request.param)

wno_unit_1d = ["gwcs_1d_freq", "gwcs_1d_spectral",]

wno_unit_nd = ["gwcs_2d_shift_scale", "gwcs_3d_spatial_wave", "gwcs_2d_spatial_shift",
"gwcs_2d_spatial_reordered", "gwcs_3d_spatial_wave", "gwcs_simple_imaging",
"gwcs_3spectral_orders", "gwcs_3d_galactic_spectral", "gwcs_spec_cel_time_4d",
"gwcs_romanisim", ]

# "gwcs_7d_complex_mapping" errors in astropy - fix
# "gwcs_2d_quantity_shift" errors when inputs are quantities.
# Need to confirm if Qs are HLO

w_unit_1d = ["gwcs_stokes_lookup", "gwcs_1d_freq_quantity"]

w_unit_nd = ["gwcs_2d_shift_scale_quantity", "gwcs_3d_identity_units",
"gwcs_3d_identity_units", "gwcs_4d_identity_units", "gwcs_simple_imaging_units",
"gwcs_with_pipeline_celestial", ]

w_transform_test = ["gwcs_1d_freq_quantity", "gwcs_2d_quantity_shift"]

wcs_no_unit_1d = pytest.mark.parametrize(("wcsobj"), wno_unit_1d, indirect=True)
wcs_no_unit_nd = pytest.mark.parametrize(("wcsobj"), wno_unit_nd, indirect=True)
wcs_with_unit_1d = pytest.mark.parametrize(("wcsobj"), w_unit_1d, indirect=True)
wcs_with_unit_nd = pytest.mark.parametrize(("wcsobj"), w_unit_nd, indirect=True)


@wcs_no_unit_1d
def test_no_units_1d(wcsobj):
""" Transforms do not support units."""
assert not wcsobj.forward_transform.uses_quantity

# the case of a scalar input
x = 1
bbox = wcsobj.bounding_box
if bbox is not None:
x = np.mean(bbox.bounding_box())

result_num = wcsobj(x)
assert np.isscalar(result_num)

assert_allclose(wcsobj.invert(result_num), x)

xq = x * wcsobj.input_frame.unit[0]
result = wcsobj(xq)
assert_quantity_allclose(result, result_num * wcsobj.output_frame.unit[0])


@wcs_no_unit_nd
def test_no_units_nd(wcsobj):
assert not wcsobj.forward_transform.uses_quantity

n_inputs = wcsobj.input_frame.naxes

inp = [1] * n_inputs
bbox = wcsobj.bounding_box
if bbox is not None:
bb = bbox.bounding_box()
inp = [np.mean(interval) for interval in bb]
# Inputs are numbers
result = wcsobj(*inp)
assert is_numerical(result)
if np.isscalar(result):
result = [result]
inp_new = wcsobj.invert(*result)
_ = [assert_allclose(i, j) for i, j in zip(inp_new, inp, strict=True)]

# Inputs are quantities; return quantities (except for pixels?)
inpq = [coo * un for coo, un in zip(inp, wcsobj.input_frame.unit, strict=True)]
result = wcsobj(*inpq)
assert is_quantity(result)
inp_new = wcsobj.invert(*result)
_ = [assert_allclose(i, j) for i, j in zip(inp_new, inpq, strict=True)]

sky = wcsobj.pixel_to_world(*inp)
if not np.iterable(sky):
sky=(sky,)
inv_sky = wcsobj.invert(*sky)
assert_quantity_allclose(inv_sky, inpq)


@wcs_with_unit_1d
def test_with_units_1d(wcsobj):
""" Transform do not support units."""
assert wcsobj.forward_transform.uses_quantity

# the case of a scalar input
x = 1 * wcsobj.input_frame.unit[0]

result = wcsobj(x)
assert isinstance(result, u.Quantity)
assert_allclose(wcsobj.invert(result), x)

x = 1
result = wcsobj(x)
assert np.isscalar(result)
assert_allclose(wcsobj.invert(result), x)


@wcs_with_unit_nd
def test_transform_with_units(wcsobj):
""" Transforms support units."""
assert wcsobj.forward_transform.uses_quantity

n_inputs = wcsobj.input_frame.naxes
xx = [x] * n_inputs

# input is numerical; return numbers
result_num = wcsobj(*xx)
assert is_numerical(result_num)

inp = wcsobj.invert(*result_num)
assert is_numerical(inp)

# input is quantities; return quantities
xxq = [1 * u.pix] * n_inputs
result = wcsobj(*xxq)
assert all([type(res)==u.Quantity for res in result])
assert_allclose([r.value for r in result], result_num)

sky = wcsobj.pixel_to_world(*xxq)
if not np.iterable(sky):
sky=(sky,)
inv_sky = wcsobj.invert(*sky)
assert_quantity_allclose(inv_sky, xxq)


@wcs_no_unit_1d
def test_add_units(wcsobj):
if wcsobj.input_frame.naxes == 1:
assert wcsobj._add_units_input((1,), wcsobj.input_frame) == 1 * wcsobj.input_frame.unit[0]
assert_allclose(
wcsobj._add_units_input(([1, 1],), wcsobj.input_frame),
([1, 1] * wcsobj.input_frame.unit[0],))
elif wcsobj.input_frame.naxes == 2:
assert_quantity_allclose(
wcsobj._add_units_input((1, 1), wcsobj.input_frame),
(1*u.pix, 1*u.pix))
assert_quantity_allclose(
wcsobj._add_units_input(([1, 1], [1, 1]), wcsobj.input_frame),
([1, 1]*u.pix, [1, 1]*u.pix))


@wcs_with_unit_1d
def test_remove_units(wcsobj):
if wcsobj.input_frame.naxes == 1:
unit = wcsobj.input_frame.unit[0]
assert wcsobj._remove_units_input(1 * unit, wcsobj.input_frame) == (1,)
assert_allclose(
wcsobj._remove_units_input(([1, 1] * unit,), wcsobj.input_frame),
([1, 1],))
elif wcsobj.input_frame.naxes == 2:
assert_quantity_allclose(
wcsobj._remove_units_input((1*u.pix, 1*u.pix), wcsobj.input_frame),
(1, 1)
)
assert_quantity_allclose(
wcsobj._remove_units_input(([1, 1]*u.pix, [1, 1]*u.pix), wcsobj.input_frame),
([1, 1], [1, 1]))


def test_transform_multistage_wcs(gwcs_with_pipeline_celestial):
"""Tests that the input and output types match for intermediate frames/transforms."""
wcsobj = gwcs_with_pipeline_celestial
frames = wcsobj.available_frames
result = wcsobj.transform(frames[0], frames[-1], 1*u.pix, 1*u.pix)
assert is_quantity(result)
assert_allclose([r.value for r in result], wcsobj(1, 1))
final_result = wcsobj.transform(frames[0], frames[-1], 1*u.pix, 1*u.pix)
assert is_quantity(final_result)
assert_allclose([r.value for r in final_result], wcsobj(1, 1))
interm_result = wcsobj.transform(frames[0], frames[1], 1*u.pix, 1*u.pix)
assert is_quantity(interm_result)
tr = wcsobj.get_transform(frames[0], frames[1])
assert_quantity_allclose(interm_result, tr(1*u.pix, 1*u.pix))
ninterm_result = wcsobj.transform(frames[0], frames[1], 1, 1)
assert_allclose([r.value for r in interm_result], ninterm_result)


def test_reverse_wcs_direction(gwcs_2d_spatial_shift_reverse):
"""Test that input quantities are converted to the units of the input frame."""
wcsobj = gwcs_2d_spatial_shift_reverse
assert_quantity_allclose(
wcsobj(1*u.arcsec, 2*u.arcsec),
wcsobj(1*u.arcsec.to(u.deg)*u.deg, 2*u.arcsec.to(u.deg)*u.deg)
)
Loading
Loading