Skip to content
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

- Adjust and fix tests. Replace ``logging`` with `warnings``. [#659]

- Update the legacy API. [#660]

0.26.1 (2025-11-19)
-------------------
- Fix an indexing bug in ``spectroscopy.SellmeierZemax`` where the output ``n`` for array-type wavelength
Expand Down
14 changes: 6 additions & 8 deletions gwcs/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,13 @@ def world_axis_units(self):
return tuple(unit.to_string(format="vounit") for unit in self.output_frame.unit)

def _remove_quantity_output(self, result, frame):
if self.forward_transform.uses_quantity:
if frame.naxes == 1:
result = [result]
if frame.naxes == 1:
result = [result]

result = tuple(
r.to_value(unit) if isinstance(r, u.Quantity) else r
for r, unit in zip(result, frame.unit, strict=False)
)
result = tuple(
r.to_value(unit) if isinstance(r, u.Quantity) else r
for r, unit in zip(result, frame.unit, strict=False)
)

# If we only have one output axes, we shouldn't return a tuple.
if self.output_frame.naxes == 1 and isinstance(result, tuple):
Expand All @@ -94,7 +93,6 @@ def pixel_to_world_values(self, *pixel_arrays):
is the vertical coordinate.
"""
result = self(*pixel_arrays)

return self._remove_quantity_output(result, self.output_frame)

def array_index_to_world_values(self, *index_arrays):
Expand Down
4 changes: 3 additions & 1 deletion gwcs/coordinate_frames/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ def add_units(self, arrays: u.Quantity | np.ndarray | float) -> tuple[u.Quantity
"""
Add units to the arrays
"""
if self.naxes == 1 and np.isscalar(arrays):
return u.Quantity(arrays, self.unit[0])
return tuple(
u.Quantity(array, unit=unit)
for array, unit in zip(arrays, self.unit, strict=True)
Expand All @@ -139,7 +141,7 @@ def remove_units(
"""
Remove units from the input arrays
"""
if self.naxes == 1:
if self.naxes == 1 and (np.isscalar(arrays) or isinstance(arrays, u.Quantity)):
arrays = (arrays,)

return tuple(
Expand Down
26 changes: 25 additions & 1 deletion gwcs/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def gwcs_spec_cel_time_4d():
wcslin = models.Mapping((1, 0)) | (offx & offy) | aff
tan = models.Pix2Sky_TAN(name="tangent_projection")
n2c = models.RotateNative2Celestial(*crval, 180, name="sky_rotation")
cel_model = wcslin | tan | n2c
cel_model = wcslin | tan | n2c | models.Mapping((1, 0))
icrs = cf.CelestialFrame(
reference_frame=coord.ICRS(), name="sky", axes_order=(2, 1)
)
Expand Down Expand Up @@ -734,3 +734,27 @@ def fits_wcs_imaging_simple(params):
w.wcs.lonpole = 180
w.wcs.set()
return gw, w


def gwcs_2d_spatial_shift_reverse():
"""
A simple one step spatial WCS with forward from sky to detector.
"""
pipe = [(ICRC_SKY_FRAME, MODEL_2D_SHIFT), (DETECTOR_2D_FRAME, None)]
return wcs.WCS(pipe)


def gwcs_multi_stage():
"""
A 3-step pipeline where the intermediate step is 1D and the final is 2D.
"""
tr1 = models.Shift(10)
tr2 = models.Mapping((0, 0)) | models.Scale(-2) & models.Scale(-1)
det = cf.CoordinateFrame(
name="detector", naxes=1, unit=("pix",), axes_type="SPATIAL", axes_order=(0,)
)
interm = cf.CoordinateFrame(
name="interm", naxes=1, unit=("m",), axes_type="SPATIAL", axes_order=(0,)
)
cel = cf.CelestialFrame(name="sky", axes_names=("ra", "dec"))
return wcs.WCS([(det, tr1), (interm, tr2), (cel, None)])
10 changes: 10 additions & 0 deletions gwcs/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,13 @@ def gwcs_romanisim():
def fits_wcs_imaging_simple(request):
params = request.param
return examples.fits_wcs_imaging_simple(params)


@pytest.fixture
def gwcs_2d_spatial_shift_reverse():
return examples.gwcs_2d_spatial_shift_reverse()


@pytest.fixture
def gwcs_multi_stage():
return examples.gwcs_multi_stage()
22 changes: 8 additions & 14 deletions gwcs/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ def test_high_level_wrapper(wcsobj, request):
wc1 = hlvl.pixel_to_world(*pixel_input)
wc2 = wcsobj(*pixel_input)
results = wcsobj._remove_units_input(wc2, wcsobj.output_frame)

wc2 = values_to_high_level_objects(*results, low_level_wcs=wcsobj)
if len(wc2) == 1:
wc2 = wc2[0]
Expand All @@ -325,19 +326,10 @@ def test_high_level_wrapper(wcsobj, request):
wc1 = (wc1,)

pix_out1 = hlvl.world_to_pixel(*wc1)
pix_out2 = wcsobj.invert(*wc1)

if not isinstance(pix_out2, list | tuple):
pix_out2 = (pix_out2,)

if wcsobj.forward_transform.uses_quantity:
pix_out2 = tuple(
p.to_value(unit)
for p, unit in zip(pix_out2, wcsobj.input_frame.unit, strict=False)
)

np.testing.assert_allclose(pix_out1, pixel_input)
np.testing.assert_allclose(pix_out2, pixel_input)
with pytest.raises(TypeError) as e:
_ = wcsobj.invert(*wc1)
assert "High Level objects are not supported with the native" in str(e)


def test_stokes_wrapper(gwcs_stokes_lookup):
Expand Down Expand Up @@ -407,6 +399,8 @@ def test_pixel_bounds(wcsobj):

wcsobj.bounding_box = ((-0.5, 2039.5), (-0.5, 1019.5))
assert_array_equal(wcsobj.pixel_bounds, wcsobj.bounding_box)
# Reset the bounding box or this will affect other tests
wcsobj.bounding_box = None


@wcs_objs
Expand Down Expand Up @@ -598,8 +592,8 @@ def test_coordinate_frame_api():
pixel = wcs.world_to_pixel(world)
assert isinstance(pixel, float)

pixel2 = wcs.invert(world)
assert u.allclose(pixel2, 0 * u.pix)
with pytest.raises(TypeError):
_ = wcs.invert(world)


def test_world_axis_object_components_units(gwcs_3d_identity_units):
Expand Down
Loading
Loading