Skip to content

Commit f8af45a

Browse files
committed
decouple spot diagram plot and optimizer rms calc
1 parent 0f46129 commit f8af45a

File tree

1 file changed

+29
-19
lines changed

1 file changed

+29
-19
lines changed

tracepy/optplot.py

+29-19
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,15 @@
88

99
from typing import List, Dict, Tuple
1010

11-
# TODO: Decouple plotting the spot diagram from the RMS calculation used by optimizer.
12-
# This involves breaking up the three functions below to decouple the two tasks.
1311
def _gen_object_points(surface: geometry,
1412
surface_idx: int,
15-
rays: List[ray]) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
13+
rays: List[ray]) -> np.ndarray:
1614
"""Transform intersection points into a surfaces' reference frame.
1715
1816
Parameters
1917
----------
2018
surface : geometry object
21-
Surface whos reference frame the points will be transformed into.
19+
Surface whose reference frame the points will be transformed into.
2220
surface_idx : int
2321
Integer corresponding to where the surface is in the propagation
2422
order. For example, 0 means the surface is the first surface rays
@@ -28,8 +26,6 @@ def _gen_object_points(surface: geometry,
2826
2927
Returns
3028
-------
31-
X, Y : np.array of len(rays)
32-
X, Y pair in the surface's reference frame.
3329
points_obj: 2d np.array
3430
X, Y pair points in 2d array for easy rms calculation.
3531
@@ -39,16 +35,29 @@ def _gen_object_points(surface: geometry,
3935
if points.size == 0:
4036
#No rays survived
4137
raise TraceError()
42-
#Get X,Y points in obj. reference frame.
38+
39+
# Get X,Y points in obj. reference frame.
4340
points_obj = transform_points(surface.R, surface, points)
44-
#Round arrays to upper bound on accuracy.
45-
points_obj = np.around(points_obj, 14)
46-
if points_obj.ndim == 2:
47-
X, Y = points_obj[:,0], points_obj[:,1]
48-
elif points_obj.ndim == 1:
49-
X, Y = points_obj[0], points_obj[1]
50-
points_obj = np.array([points_obj])
51-
return X, Y, points_obj
41+
42+
# Round arrays to upper bound on accuracy.
43+
return np.around(points_obj, 14)
44+
45+
def calculate_rms(points: np.ndarray) -> float:
46+
"""Calculates the RMS of the given points.
47+
48+
Parameters
49+
----------
50+
points : np.ndarray
51+
Array of points to calculate RMS for.
52+
53+
Returns
54+
-------
55+
float
56+
The calculated RMS value.
57+
58+
"""
59+
60+
return np.std(points[:, [0, 1]] - points[:, [0, 1]].mean(axis=0))
5261

5362
def spot_rms(geo_params: List[Dict], rays: List[ray]) -> float:
5463
"""Calculates the RMS of the spot diagram points.
@@ -68,8 +77,8 @@ def spot_rms(geo_params: List[Dict], rays: List[ray]) -> float:
6877
"""
6978

7079
stop = geometry(geo_params[-1])
71-
_, _, points_obj = _gen_object_points(stop, -1, rays)
72-
return np.std(points_obj[:, [0, 1]] - points_obj[:, [0, 1]].mean(axis=0))
80+
points_obj = _gen_object_points(stop, -1, rays)
81+
return calculate_rms(points_obj)
7382

7483
def spotdiagram(geo_params: List[Dict],
7584
rays: List[ray],
@@ -88,8 +97,9 @@ def spotdiagram(geo_params: List[Dict],
8897
"""
8998

9099
stop = geometry(geo_params[-1])
91-
X, Y, points_obj = _gen_object_points(stop, -1, rays)
92-
rms = np.std(points_obj[:, [0, 1]] - points_obj[:, [0, 1]].mean(axis=0))
100+
points_obj = _gen_object_points(stop, -1, rays)
101+
X, Y = points_obj[:, 0], points_obj[:, 1]
102+
rms = calculate_rms(points_obj)
93103

94104
plt.subplot(1, 1, 1, aspect='equal')
95105
plt.locator_params(axis='x', nbins=8)

0 commit comments

Comments
 (0)