|
5 | 5 | from typing import TYPE_CHECKING |
6 | 6 |
|
7 | 7 | import h5py |
| 8 | +import matplotlib as mpl |
8 | 9 | import numpy as np |
9 | 10 | import scipy |
10 | 11 | from astropy.io import fits |
11 | 12 |
|
12 | 13 | from regularizepsf.exceptions import InvalidCoordinateError |
13 | 14 | from regularizepsf.util import IndexedCube |
| 15 | +from regularizepsf.visualize import KERNEL_IMSHOW_ARGS_DEFAULT, visualize_grid |
14 | 16 |
|
15 | 17 | if TYPE_CHECKING: |
16 | 18 | import pathlib |
@@ -138,14 +140,25 @@ def slice_padded_image(coordinate: tuple[int, int]) -> tuple[slice, slice]: |
138 | 140 | 2 * self.psf_shape[1] : image.shape[1] + 2 * self.psf_shape[1], |
139 | 141 | ] |
140 | 142 |
|
141 | | - def visualize(self) -> None: |
142 | | - """Visualize the PSFTransform. |
143 | | -
|
144 | | - Returns |
145 | | - ------- |
146 | | - None |
147 | | -
|
148 | | - """ |
| 143 | + def visualize(self, |
| 144 | + fig: mpl.figure.Figure | None = None, |
| 145 | + fig_scale: int = 1, |
| 146 | + all_patches: bool = False, imshow_args: dict | None = None) -> None: # noqa: ANN002, ANN003 |
| 147 | + """Visualize the transfer kernels.""" |
| 148 | + imshow_args = KERNEL_IMSHOW_ARGS_DEFAULT if imshow_args is None else imshow_args |
| 149 | + |
| 150 | + arr = np.abs(np.fft.fftshift(np.fft.ifft2(self._transfer_kernel.values))) |
| 151 | + extent = np.max(np.abs(arr)) |
| 152 | + if 'vmin' not in imshow_args: |
| 153 | + imshow_args['vmin'] = -extent |
| 154 | + if 'vmax' not in imshow_args: |
| 155 | + imshow_args['vmax'] = extent |
| 156 | + |
| 157 | + return visualize_grid( |
| 158 | + IndexedCube(self._transfer_kernel.coordinates, arr), |
| 159 | + all_patches=all_patches, fig=fig, |
| 160 | + fig_scale=fig_scale, colorbar_label="Transfer kernel amplitude", |
| 161 | + imshow_args=imshow_args) |
149 | 162 |
|
150 | 163 | def save(self, path: pathlib.Path) -> None: |
151 | 164 | """Save a PSFTransform to a file. Supports h5 and FITS. |
@@ -173,6 +186,7 @@ def save(self, path: pathlib.Path) -> None: |
173 | 186 | name="transfer_imag", quantize_level=32)]).writeto(path) |
174 | 187 | else: |
175 | 188 | raise NotImplementedError(f"Unsupported file type {path.suffix}. Change to .h5 or .fits.") |
| 189 | + |
176 | 190 | @classmethod |
177 | 191 | def load(cls, path: pathlib.Path) -> ArrayPSFTransform: |
178 | 192 | """Load a PSFTransform object. Supports h5 and FITS. |
|
0 commit comments