Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions changes/696.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Clean up internal code flow so that the entire GWCS Native API follows common code
paths. This enables the Native API to properly handle various forms of inputs in
a consistent manner so that evaluation and inverse have the same input handling
behavior. This means HLO and quantities should function properly as they are passed
into these portions of the Native API.

Also allows for quantities to be passed into the "low-level APE 14 API" without
raising errors.
2 changes: 0 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,6 @@
# Enable nitpicky mode - which ensures that all references in the docs resolve.
nitpicky = True
nitpick_ignore = [
("py:class", "gwcs.api.GWCSAPIMixin"),
("py:class", "gwcs.wcs._pipeline.Pipeline"),
("py:obj", "astropy.modeling.projections.projcodes"),
("py:attr", "gwcs.WCS.bounding_box"),
("py:meth", "gwcs.WCS.footprint"),
Expand Down
4 changes: 3 additions & 1 deletion docs/gwcs/points_to_wcs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,10 @@ easily work in both pixel and sky space, and transform between frames.
The GWCS object, which by default when called executes for forward transformation,
can be used to convert coordinates from pixel to world.

.. doctest-requires:: numpy>=2.0

>>> gwcs_obj(36.235,642.215) # doctest: +FLOAT_CMP
(246.72158004206716, 43.46075091731673)
(np.float64(246.72158004206716), np.float64(43.46075091731673))

Or using the common WCS API

Expand Down
58 changes: 54 additions & 4 deletions gwcs/coordinate_frames/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,57 @@ def remove_units(
for array in self.add_units(arrays)
)

def to_high_level_coordinates(self, *values):
def is_high_level(self, *args) -> bool:
"""
Return `True` if the input coordinates are already high level objects
described by this frame.

This is used by the low level WCS API in Astropy to determine whether
to call ``to_high_level_coordinates`` or not.
"""

if (world_axis_object_classes := self.world_axis_object_classes) is None or len(
args
) != len(world_axis_object_classes):
return False

type_match = []
for arg, world_axis_object_class in zip(
args, world_axis_object_classes.values(), strict=True
):
if isinstance(class_object := world_axis_object_class.class_object, str):
type_match.append(
type(arg).__name__ == class_object
and class_object != u.Quantity.__name__
)
else:
type_match.append(
isinstance(arg, class_object) and class_object is not u.Quantity
)

if all(type_match):
return True

if any(type_match):
types = [
(
type(arg).__name__,
c.class_object
if isinstance(c.class_object, str)
else c.class_object.__name__,
)
for arg, c in zip(args, world_axis_object_classes.values(), strict=True)
]
msg = (
"Invalid types were passed, got "
f"({', '.join(t[0] for t in types)}), but expected "
f"({', '.join(t[1] for t in types)})."
)
raise TypeError(msg)

return False

def to_high_level_coordinates(self, *values, correct_1d=True):
"""
Convert "values" to high level coordinate objects described by this frame.

Expand All @@ -301,11 +351,11 @@ def to_high_level_coordinates(self, *values):
raise TypeError(msg)

high_level = values_to_high_level_objects(*values, low_level_wcs=self)
if len(high_level) == 1:
if correct_1d and len(high_level) == 1:
high_level = high_level[0]
return high_level

def from_high_level_coordinates(self, *high_level_coords):
def from_high_level_coordinates(self, *high_level_coords, correct_1d=True):
"""
Convert high level coordinate objects to "values" as described by this frame.

Expand All @@ -324,7 +374,7 @@ def from_high_level_coordinates(self, *high_level_coords):
``naxis`` number of coordinates as scalars or arrays.
"""
values = high_level_objects_to_values(*high_level_coords, low_level_wcs=self)
if len(values) == 1:
if correct_1d and self.naxes == 1:
values = values[0]
return values

Expand Down
15 changes: 10 additions & 5 deletions gwcs/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,10 +326,15 @@ def test_high_level_wrapper(wcsobj, request):
wc1 = (wc1,)

pix_out1 = hlvl.world_to_pixel(*wc1)
pix_out2 = wcsobj.invert(*wc1)

pix_out2 = wcsobj.input_frame.remove_units(pix_out2)

if not isinstance(pix_out2, list | tuple):
pix_out2 = (pix_out2,)

np.testing.assert_allclose(pix_out1, pixel_input)
with pytest.raises(TypeError) as e:
_ = wcsobj.invert(*wc1)
assert "High Level objects are not supported with the native" in str(e)
np.testing.assert_allclose(pix_out2, pixel_input)


def test_stokes_wrapper(gwcs_stokes_lookup):
Expand Down Expand Up @@ -592,8 +597,8 @@ def test_coordinate_frame_api():
pixel = wcs.world_to_pixel(world)
assert isinstance(pixel, float)

with pytest.raises(TypeError):
_ = wcs.invert(world)
pixel2 = wcs.invert(world)
assert u.allclose(pixel2, 0 * u.pix)


def test_world_axis_object_components_units(gwcs_3d_identity_units):
Expand Down
10 changes: 2 additions & 8 deletions gwcs/tests/test_api_consistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,7 @@ def test_no_units_nd(wcsobj):
sky = wcsobj.pixel_to_world(*inp)
if not np.iterable(sky):
sky = (sky,)
with pytest.raises(
TypeError, match=r"High Level objects are not supported with the native"
):
wcsobj.invert(*sky)
assert u.allclose(inpq, wcsobj.invert(*sky))


@wcs_with_unit_1d
Expand Down Expand Up @@ -191,10 +188,7 @@ def test_transform_with_units(wcsobj):
sky = wcsobj.pixel_to_world(*xxq)
if not np.iterable(sky):
sky = (sky,)
with pytest.raises(
TypeError, match=r"High Level objects are not supported with the native"
):
wcsobj.invert(*sky)
assert u.allclose(xxq, wcsobj.invert(*sky))


@wcs_no_unit_1d
Expand Down
3 changes: 3 additions & 0 deletions gwcs/tests/test_coordinate_systems.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,9 @@ def add_units(self, arrays):
def remove_units(self, arrays):
return arrays

def is_high_level(self, *inputs):
return False

def to_high_level_coordinates(self, *inputs):
return inputs

Expand Down
15 changes: 6 additions & 9 deletions gwcs/tests/test_wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1788,10 +1788,10 @@ def test_quantities_in_pipeline_backward(gwcs_with_pipeline_celestial):
20 * u.arcsec + 1 * u.deg,
15 * u.deg + 2 * u.deg,
]
with pytest.raises(
TypeError, match=r"High Level objects are not supported with the native"
):
iwcs.invert(*input_world)
pixel = iwcs.invert(*input_world)

assert all(isinstance(p, u.Quantity) for p in pixel)
assert u.allclose(pixel, [1, 1] * u.pix)

intermediate_world = iwcs.transform(
"output",
Expand Down Expand Up @@ -1939,12 +1939,9 @@ def test_parameterless_transform():
assert gwcs(1, 1) == (1, 1)
assert gwcs(1 * u.pix, 1 * u.pix) == (1 * u.pix, 1 * u.pix)

# No units introduced by the inverse transform
assert gwcs.invert(1, 1) == (1, 1)
# Strictly speaking it's correct that this fails Because
# for this setup the HLO are Quantities
with pytest.raises(TypeError) as e:
_ = gwcs.invert(1 * u.pix, 1 * u.pix)
assert "High Level objects are not supported with the native" in str(e)
assert gwcs.invert(1 * u.pix, 1 * u.pix) == (1 * u.pix, 1 * u.pix)


def test_fitswcs_imaging(fits_wcs_imaging_simple):
Expand Down
28 changes: 2 additions & 26 deletions gwcs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,34 +479,10 @@ def create_projection_transform(projcode):
return projklass(**projparams)


# ToDo: Should this be deprecated?
def is_high_level(*args, low_level_wcs):
"""
Determine if args matches the high level classes as defined by
``low_level_wcs``.
"""
if low_level_wcs.world_axis_object_classes is None or len(args) != len(
low_level_wcs.world_axis_object_classes
):
return False

type_match = [
(type(arg), waoc[0])
for arg, waoc in zip(
args, low_level_wcs.world_axis_object_classes.values(), strict=False
)
]

types_are_high_level = [argt is t for argt, t in type_match]

if all(types_are_high_level):
return True

if any(types_are_high_level):
msg = (
"Invalid types were passed, got "
f"({', '.join(tm[0].__name__ for tm in type_match)}) expected "
f"({', '.join(tm[1].__name__ for tm in type_match)})."
)
raise TypeError(msg)

return False
return low_level_wcs.output_frame.is_high_level(*args)
3 changes: 3 additions & 0 deletions gwcs/wcs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from ._exception import GwcsBoundingBoxWarning, NoConvergence
from ._pipeline import DirectionalPipeline, Pipeline
from ._step import Step
from ._wcs import WCS

__all__ = [
"WCS",
"DirectionalPipeline",
"GwcsBoundingBoxWarning",
"NoConvergence",
"Pipeline",
"Step",
]
Loading
Loading