Skip to content

Commit 8572c19

Browse files
authored
Improve cube plotter (#736)
2 parents 1147ca9 + 4328d91 commit 8572c19

1 file changed

Lines changed: 36 additions & 11 deletions

File tree

scopesim/utils.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -781,18 +781,42 @@ def _cube_spec_plotter(
781781
cube_wcs = WCS(cube_hdu)
782782

783783
swcs = cube_wcs.spectral if cube_wcs.has_spectral else cube_wcs.sub([3])
784-
with u.set_enabled_equivalencies(u.spectral()):
785-
wave = swcs.pixel_to_world(np.arange(swcs.pixel_shape[0]))
784+
px = np.arange(swcs.pixel_shape[0]) * u.pixel
785+
786+
def _px2wave(pixel):
787+
with u.set_enabled_equivalencies(u.spectral()):
788+
wave = swcs.pixel_to_world(pixel)
789+
return (wave << u.um).value
790+
791+
def _wave2px(wave):
792+
if not len(wave):
793+
# Catch empty list which matplotib passes here for whatever reason.
794+
return wave
795+
with u.set_enabled_equivalencies(u.spectral()):
796+
wave = (wave << u.um).to_value(u.m)
797+
shape = wave.shape
798+
wave = np.atleast_1d(wave.squeeze())
799+
pix = np.array(swcs.all_world2pix(wave, 0)).reshape(shape)
800+
return pix
801+
802+
drawstyle = "default"
803+
y_label = r"$F_\lambda$"
804+
if (bunit := cube_hdu.header.get("BUNIT")) is not None:
786805
try:
787-
wave <<= u.um
788-
except u.UnitConversionError:
789-
# TODO: perhaps deal with pixel and/or mm separately and let
790-
# anything else fail on purpose??
791-
pass # catch and ignore dimensionless or pixel coordinates
806+
flux_unit = u.Unit(bunit)
807+
if not unit_includes_per_physical_type(flux_unit, "length"):
808+
# Binned flux
809+
drawstyle = "steps-mid"
810+
y_label = r"$F$"
811+
except ValueError:
812+
pass # Catch missing unit, default to default
792813

793-
axes.plot(wave, cube_hdu.data.sum(axis=(1, 2)))
794-
axes.set_xlabel(wave.unit)
795-
axes.set_ylabel(_get_bunit_label(cube_hdu.header))
814+
axes.plot(px, cube_hdu.data.sum(axis=(1, 2)), drawstyle=drawstyle)
815+
axes.set_xlabel(px.unit)
816+
wax = axes.secondary_xaxis(location="top", functions=(_px2wave, _wave2px))
817+
wax.set_xlabel(fr"$\lambda$ [{u.um}]")
818+
axes.set_ylabel(f"{y_label} [{_get_bunit_label(cube_hdu.header)}]")
819+
axes.grid()
796820

797821

798822
def _get_bunit_label(header: fits.Header) -> u.Unit | str:
@@ -829,7 +853,8 @@ def cube_plotter(
829853
cube_hdu: fits.ImageHDU
830854
) -> tuple[mpl.figure.Figure, tuple[mpl.axes.Axes, mpl.axes.Axes]]:
831855
"""Plot cube in separate plots for spatial and spectral parts."""
832-
fig, (ax_img, ax_spec) = figure_factory(2, height_ratios=(2, 1))
856+
fig, (ax_img, ax_spec) = figure_factory(2, height_ratios=(2, 1),
857+
layout="tight")
833858
cube_wcs = WCS(cube_hdu)
834859

835860
_cube_image_plotter(fig, ax_img, cube_hdu, cube_wcs)

0 commit comments

Comments
 (0)