@@ -781,18 +781,42 @@ def _cube_spec_plotter(
781781 cube_wcs = WCS (cube_hdu )
782782
783783 swcs = cube_wcs .spectral if cube_wcs .has_spectral else cube_wcs .sub ([3 ])
784- with u .set_enabled_equivalencies (u .spectral ()):
785- wave = swcs .pixel_to_world (np .arange (swcs .pixel_shape [0 ]))
784+ px = np .arange (swcs .pixel_shape [0 ]) * u .pixel
785+
786+ def _px2wave (pixel ):
787+ with u .set_enabled_equivalencies (u .spectral ()):
788+ wave = swcs .pixel_to_world (pixel )
789+ return (wave << u .um ).value
790+
791+ def _wave2px (wave ):
792+ if not len (wave ):
793+ # Catch empty list which matplotib passes here for whatever reason.
794+ return wave
795+ with u .set_enabled_equivalencies (u .spectral ()):
796+ wave = (wave << u .um ).to_value (u .m )
797+ shape = wave .shape
798+ wave = np .atleast_1d (wave .squeeze ())
799+ pix = np .array (swcs .all_world2pix (wave , 0 )).reshape (shape )
800+ return pix
801+
802+ drawstyle = "default"
803+ y_label = r"$F_\lambda$"
804+ if (bunit := cube_hdu .header .get ("BUNIT" )) is not None :
786805 try :
787- wave <<= u .um
788- except u .UnitConversionError :
789- # TODO: perhaps deal with pixel and/or mm separately and let
790- # anything else fail on purpose??
791- pass # catch and ignore dimensionless or pixel coordinates
806+ flux_unit = u .Unit (bunit )
807+ if not unit_includes_per_physical_type (flux_unit , "length" ):
808+ # Binned flux
809+ drawstyle = "steps-mid"
810+ y_label = r"$F$"
811+ except ValueError :
812+ pass # Catch missing unit, default to default
792813
793- axes .plot (wave , cube_hdu .data .sum (axis = (1 , 2 )))
794- axes .set_xlabel (wave .unit )
795- axes .set_ylabel (_get_bunit_label (cube_hdu .header ))
814+ axes .plot (px , cube_hdu .data .sum (axis = (1 , 2 )), drawstyle = drawstyle )
815+ axes .set_xlabel (px .unit )
816+ wax = axes .secondary_xaxis (location = "top" , functions = (_px2wave , _wave2px ))
817+ wax .set_xlabel (fr"$\lambda$ [{ u .um } ]" )
818+ axes .set_ylabel (f"{ y_label } [{ _get_bunit_label (cube_hdu .header )} ]" )
819+ axes .grid ()
796820
797821
798822def _get_bunit_label (header : fits .Header ) -> u .Unit | str :
@@ -829,7 +853,8 @@ def cube_plotter(
829853 cube_hdu : fits .ImageHDU
830854) -> tuple [mpl .figure .Figure , tuple [mpl .axes .Axes , mpl .axes .Axes ]]:
831855 """Plot cube in separate plots for spatial and spectral parts."""
832- fig , (ax_img , ax_spec ) = figure_factory (2 , height_ratios = (2 , 1 ))
856+ fig , (ax_img , ax_spec ) = figure_factory (2 , height_ratios = (2 , 1 ),
857+ layout = "tight" )
833858 cube_wcs = WCS (cube_hdu )
834859
835860 _cube_image_plotter (fig , ax_img , cube_hdu , cube_wcs )
0 commit comments