Skip to content

add SurfaceViewer to vizualize surfaces #59

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 18 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions optiland/optic.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
from optiland.distribution import create_distribution
from optiland.geometries import Plane, StandardGeometry
from optiland.materials import IdealMaterial
from optiland.visualization import OpticViewer, OpticViewer3D, LensInfoViewer
from optiland.visualization import (
SurfaceViewer,
OpticViewer,
OpticViewer3D,
LensInfoViewer)
from optiland.pickup import PickupManager
from optiland.solves import SolveManager

Expand Down Expand Up @@ -293,6 +297,28 @@
if surface.aperture is not None:
surface.aperture.scale(scale_factor)

def draw_surface(self,
surface_index,
projection='2d',
num_points=256,
figsize=(7, 5.5),
title=None):
"""
Visualize a surface.

Args:
surface_index (int): Index of the surface to be visualized.
projection (str): The type of projection to use for visualization.
Can be '2d' or '3d'.
num_points (int): The number of points to sample along each axis
for the visualization.
figsize (tuple): The size of the figure in inches.
Defaults to (7, 5.5).
title (str): Title.
"""
viewer = SurfaceViewer(self)
viewer.view(surface_index, projection, num_points, figsize, title)

Check warning on line 320 in optiland/optic.py

View check run for this annotation

Codecov / codecov/patch

optiland/optic.py#L319-L320

Added lines #L319 - L320 were not covered by tests

def draw(self, fields='all', wavelengths='primary', num_rays=3,
distribution='line_y', figsize=(10, 4), xlim=None, ylim=None,
title=None, reference=None):
Expand Down Expand Up @@ -406,17 +432,18 @@
surface.set_semi_aperture(r_max=ya[k]+yb[k])
self.update_normalization(surface)

def update_normalization(self, surface)->None:
def update_normalization(self, surface) -> None:
"""
Update the normalization radius of non-spherical surfaces.
"""
if surface.surface_type in ['even_asphere', 'odd_asphere', 'polynomial', 'chebyshev']:
surface.geometry.norm_x = surface.semi_aperture
surface.geometry.norm_y = surface.semi_aperture
if surface.surface_type in ['even_asphere', 'odd_asphere',
'polynomial', 'chebyshev']:
surface.geometry.norm_x = surface.semi_aperture*1.1
surface.geometry.norm_y = surface.semi_aperture*1.1
if surface.surface_type == 'zernike':
surface.geometry.norm_radius = surface.semi_aperture*1.1

def update(self)->None:
def update(self) -> None:
"""
Update the surfaces based on the pickup operations.
"""
Expand Down
2 changes: 1 addition & 1 deletion optiland/visualization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# flake8: noqa

from .visualization import OpticViewer, OpticViewer3D, LensInfoViewer
from .visualization import SurfaceViewer, OpticViewer, OpticViewer3D, LensInfoViewer
150 changes: 150 additions & 0 deletions optiland/visualization/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,156 @@
plt.rcParams.update({'font.size': 12, 'font.family': 'cambria'})


class SurfaceViewer:
"""
A class used to visualize surfaces.

Args:
optic: The optical system to be visualized.
surface_index: Index of the surface to be visualized.
"""

def __init__(self, optic):
self.optic = optic

def view(self,
surface_index: int,
projection: str = '2d',
num_points: int = 256,
figsize: tuple = (7, 5.5),
title: str = None):
"""
Visualize the surface.

Args:
surface_index (int): Index of the surface to be visualized.
projection (str): The type of projection to use for visualization.
Can be '2d' or '3d'.
num_points (int): The number of points to sample along each axis
for the visualization.
figsize (tuple): The size of the figure in inches.
Defaults to (7, 5.5).
title (str): Title.

Raises:
ValueError: If the projection is not '2d' or '3d'.
"""
# Update optics and compute surface sag
self.optic.update_paraxial()
surface = self.optic.surface_group.surfaces[surface_index]
semi_aperture = surface.semi_aperture
x, y = np.meshgrid(
np.linspace(-semi_aperture, semi_aperture, num_points),
np.linspace(-semi_aperture, semi_aperture, num_points),)
z = surface.geometry.sag(x, y)
z[np.sqrt(x**2+y**2) > semi_aperture] = np.nan

# Plot in 2D
if projection == '2d':
self._plot_2d(z, figsize=figsize, title=title,
surface_type=surface.surface_type,
surface_index=surface_index,
semi_aperture=semi_aperture)
# Plot in 3D
elif projection == '3d':
self._plot_3d(x, y, z, figsize=figsize, title=title,
surface_type=surface.surface_type,
surface_index=surface_index,
semi_aperture=semi_aperture)
else:
raise ValueError('Projection must be "2d" or "3d".')

Check warning on line 80 in optiland/visualization/visualization.py

View check run for this annotation

Codecov / codecov/patch

optiland/visualization/visualization.py#L80

Added line #L80 was not covered by tests

def _plot_2d(self,
z: np.ndarray,
figsize: tuple = (7, 5.5),
title: str = None,
**kwargs):
"""
Plot a 2D representation of the given data.

Args:
z (numpy.ndarray): The data to be plotted.
figsize (tuple, optional): The size of the figure
(default is (7, 5.5)).
title (str): Title.
"""
_, ax = plt.subplots(figsize=figsize)

if 'semi_aperture' in kwargs:
semi_aperture = kwargs['semi_aperture']
extent = [-semi_aperture, semi_aperture, -semi_aperture, semi_aperture]
ax.set_xlabel('X [mm]')
ax.set_ylabel('Y [mm]')
else:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like semi_aperture is always passed as an argument, so the "else" condition will never occur. If that's the case, then you can remove the if statement entirely.

extent = [-1, 1, -1, 1]
ax.set_xlabel('Normalized X')
ax.set_ylabel('Normalized Y')

Check warning on line 106 in optiland/visualization/visualization.py

View check run for this annotation

Codecov / codecov/patch

optiland/visualization/visualization.py#L104-L106

Added lines #L104 - L106 were not covered by tests
im = ax.imshow(np.flipud(z), extent=extent)

if title is not None:
ax.set_title(title)

Check warning on line 110 in optiland/visualization/visualization.py

View check run for this annotation

Codecov / codecov/patch

optiland/visualization/visualization.py#L110

Added line #L110 was not covered by tests
else:
ax.set_title(
f'Surface {kwargs.get("surface_index", None)} '
f'deviation to plane\n'
f'{kwargs.get("surface_type", None).capitalize()} surface'
)

cbar = plt.colorbar(im)
cbar.ax.get_yaxis().labelpad = 15
cbar.ax.set_ylabel("Deviation to plane [mm]", rotation=270)
plt.grid(alpha=0.25)
plt.show()

def _plot_3d(self,
x: np.ndarray,
y: np.ndarray,
z: np.ndarray,
figsize: tuple = (7, 5.5),
title: str = None,
**kwargs):
"""
Plot a 3D surface plot of the given data.

Args:
x (numpy.ndarray): Array of x-coordinates.
y (numpy.ndarray): Array of y-coordinates.
z (numpy.ndarray): Array of z-coordinates.
figsize (tuple, optional): Size of the figure (width, height).
Default is (7, 5.5).
title (str): Title.
"""
fig, ax = plt.subplots(subplot_kw={"projection": "3d"},
figsize=figsize)

surf = ax.plot_surface(x, y, z,
rstride=1, cstride=1,
cmap='viridis', linewidth=0,
antialiased=False)

if 'semi_aperture' in kwargs:
ax.set_xlabel('X [mm]')
ax.set_ylabel('Y [mm]')
else:
ax.set_xlabel('Normalized X')
ax.set_ylabel('Normalized Y')

Check warning on line 155 in optiland/visualization/visualization.py

View check run for this annotation

Codecov / codecov/patch

optiland/visualization/visualization.py#L154-L155

Added lines #L154 - L155 were not covered by tests
ax.set_zlabel("Deviation to plane [mm]")

if title is not None:
ax.set_title(title)

Check warning on line 159 in optiland/visualization/visualization.py

View check run for this annotation

Codecov / codecov/patch

optiland/visualization/visualization.py#L159

Added line #L159 was not covered by tests
else:
ax.set_title(
f'Surface {kwargs.get("surface_index", None)} '
f'deviation to plane\n'
f'{kwargs.get("surface_type", None).capitalize()} surface'
)
fig.colorbar(surf, ax=ax, shrink=0.5, aspect=10,
pad=0.15)

fig.tight_layout()
plt.show()


class OpticViewer:
"""
A class used to visualize optical systems.
Expand Down
54 changes: 53 additions & 1 deletion tests/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@
from unittest.mock import patch
import pytest
import numpy as np
from optiland.visualization import OpticViewer, OpticViewer3D, LensInfoViewer
from optiland.visualization import (
SurfaceViewer,
OpticViewer,
OpticViewer3D,
LensInfoViewer
)
from optiland.samples.objectives import (
TessarLens,
ReverseTelephoto
Expand Down Expand Up @@ -50,6 +55,53 @@ def k(self, wavelength):
return -42


class TestSurfaceViewer:
def test_init(self):
lens = TessarLens()
viewer = SurfaceViewer(lens)
assert viewer.optic == lens

@patch('matplotlib.pyplot.show')
def test_view(self, mock_show):
lens = ReverseTelephoto()
viewer = SurfaceViewer(lens)
viewer.view(surface_index=1)
mock_show.assert_called_once()
plt.close()

@patch('matplotlib.pyplot.show')
def test_view_2d(self, mock_show):
lens = ReverseTelephoto()
viewer = SurfaceViewer(lens)
viewer.view(surface_index=1, projection='2d')
mock_show.assert_called_once()
plt.close()

@patch('matplotlib.pyplot.show')
def test_view_3d(self, mock_show):
lens = ReverseTelephoto()
viewer = SurfaceViewer(lens)
viewer.view(surface_index=1, projection='3d')
mock_show.assert_called_once()
plt.close()

@patch('matplotlib.pyplot.show')
def test_view_bonded_lens(self, mock_show):
lens = TessarLens()
viewer = SurfaceViewer(lens)
viewer.view(surface_index=1)
mock_show.assert_called_once()
plt.close()

@patch('matplotlib.pyplot.show')
def test_view_reflective_lens(self, mock_show):
lens = HubbleTelescope()
viewer = SurfaceViewer(lens)
viewer.view(surface_index=1)
mock_show.assert_called_once()
plt.close()


class TestOpticViewer:
def test_init(self):
lens = TessarLens()
Expand Down