Skip to content

Commit 8fa8ab2

Browse files
authored
Merge pull request #40 from GNiendorf/struct_of_arrays
Move to SoA Format for TracePy
2 parents ed86a90 + 3f28c5f commit 8fa8ab2

13 files changed

+542
-512
lines changed

tests/test_hyperbolic.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323

2424
def test_rms_hyperbolic():
2525
geo = [back_lens, lens, stop]
26-
ray_group = tp.ray_plane(geo, [0., 0., 0.], 1.1, d=[0.,0.,1.], nrays=100)
26+
ray_group = tp.ray_plane(geo, [0., 0., 0.], 1.1, d=[0., 0., 1.], nrays=100)
2727
rms = tp.spot_rms(geo, ray_group)
28-
assert rms == 0.
29-
28+
assert rms == 0.

tests/test_optimizer.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,17 @@
55
'name': 'lens',
66
'action': 'refraction',
77
'P': 2.,
8-
'kappa': -.004,
9-
'c': -.5,
8+
'kappa': -0.004,
9+
'c': -0.5,
1010
'Diam': 2.2
1111
}
1212

1313
back_lens = {
1414
'name': 'back_lens',
1515
'action': 'refraction',
1616
'P': 1.,
17-
'c': .5,
18-
'kappa': -.004,
17+
'c': 0.5,
18+
'kappa': -0.004,
1919
'N': 1.5,
2020
'Diam': 2.2
2121
}
@@ -60,6 +60,4 @@ def test_optimizer():
6060
ray_group_3 = tp.ray_plane(geo_opt, [0., 0., 0.], 1.1, d=[0., 0., 1.], nrays=1000)
6161
tp.rayaberration(geo_1, ray_group_3)
6262
except Exception as e:
63-
assert False, f"rayaberration raised an exception: {e}"
64-
65-
63+
assert False, f"rayaberration raised an exception: {e}"

tests/test_parabolic.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,18 @@
44
'action': 'reflection',
55
'P': 1.5,
66
'kappa': 0.,
7-
'c': -.5,
7+
'c': -0.5,
88
'Diam': 2.2,
99
}
1010

1111
stop = {
1212
'action': 'stop',
13-
'P': .5,
14-
'Diam': .2
13+
'P': 0.5,
14+
'Diam': 0.2
1515
}
1616

1717
def test_rms_parabolic():
1818
geo = [mirror, stop]
19-
ray_group = tp.ray_plane(geo, [0., 0., -1.5], 1.1, d=[0.,0.,1.], nrays=100)
19+
ray_group = tp.ray_plane(geo, [0., 0., -1.5], 1.1, d=[0., 0., 1.], nrays=100)
2020
rms = tp.spot_rms(geo, ray_group)
21-
assert rms == 0.
21+
assert rms == 0.

tests/test_plotting.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -28,19 +28,19 @@
2828
'action': 'stop',
2929
'P': np.array([0., -4.6, 2.]),
3030
'D': np.array([0., np.pi/2., 0.]),
31-
'Diam': .8
31+
'Diam': 0.8
3232
}
3333

3434
lensy = -4.
35-
thickness = .1
35+
thickness = 0.1
3636

3737
lens1 = {
3838
'action': 'refraction',
3939
'P': np.array([0., lensy, 2.]),
4040
'D': np.array([0., 3*np.pi/2., 0.]),
4141
'Diam': 0.8,
4242
'kappa': 0.,
43-
'glass': 'ZF4 cdgm', # Modified with glass
43+
'glass': 'ZF4 cdgm',
4444
'c': 0.3
4545
}
4646

@@ -55,7 +55,7 @@
5555

5656
def test_plotting_functions():
5757
geo = [selector, mirror, small_mirror, lens1, lens2, stop]
58-
ray_group = tp.ray_plane(geo, [0., 0., 0.], 1.8, d=[0., 0., 1.], nrays=50, wvl=.55)
58+
ray_group = tp.ray_plane(geo, [0., 0., 0.], 1.8, d=[0., 0., 1.], nrays=50, wvl=0.55)
5959

6060
# Test plotyz
6161
try:
@@ -79,5 +79,4 @@ def test_plotting_functions():
7979
try:
8080
tp.plotobject(geo, ray_group)
8181
except Exception as e:
82-
assert False, f"plotobject raised an exception: {e}"
83-
82+
assert False, f"plotobject raised an exception: {e}"

tracepy/__init__.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,9 @@
77
optics regime. It features lens optimization from Scipy.
88
TracePy is currently in active development and any collaborators
99
would be welcome.
10-
1110
"""
1211

13-
from .ray import ray
12+
from .ray import RayGroup
1413
from .geometry import geometry
1514
from .optimize import optimize
1615
from .geoplot import plotxz, plotyz, plot2d

tracepy/geometry.py

+76-32
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ class geometry:
3333
If kappa is None then the surface is planar.
3434
kappa < 0 -> hyperboloid
3535
kappa = 0 -> paraboloid
36-
0 < kappa < 1 -> hemelipsoid of revolution about major axis
36+
0 < kappa < 1 -> helioid of revolution about major axis
3737
kappa = 1 -> hemisphere
38-
kappa > 1 -> hemelipsoid of revolution about minor axis
38+
kappa > 1 -> helioid of revolution about minor axis
3939
c (optional): float/int
4040
Vertex curvature of the surface.
4141
If c is 0 then the surface is planar.
@@ -82,27 +82,42 @@ def check_params(self) -> None:
8282
Note that this does not affect the calculation since c is 0.
8383
8484
"""
85-
8685
if self.c != 0:
8786
if self.kappa is None:
8887
raise Exception("Specify a kappa for this conic.")
89-
elif self.kappa>0:
88+
elif self.kappa > 0:
9089
print("Warning: Specified c value is not used when kappa>0")
91-
self.c = np.sqrt(1/(self.kappa*pow(self.Diam/2.,2)))
90+
self.c = np.sqrt(1 / (self.kappa * pow(self.Diam / 2., 2)))
9291
elif self.c == 0 and self.kappa is None:
93-
#Used for planes, does not affect calculations.
92+
# Used for planes, does not affect calculations.
9493
self.kappa = 1.
9594

9695
def get_surface(self, point: np.ndarray) -> Tuple[float, List[float]]:
97-
""" Returns the function and derivitive of a surface for a point. """
96+
""" Returns the function and derivative of a surface for a point. """
9897
return self.conics(point)
9998

10099
def get_surface_plot(self, points: np.ndarray) -> np.ndarray:
101100
""" Returns the function value for an array of points. """
102101
return self.conics_plot(points)
103102

103+
def get_surface_vector(self, points: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
104+
"""Returns function values and derivative vectors for an array of points.
105+
106+
Parameters
107+
----------
108+
points : np.ndarray of shape (N,3)
109+
Array of points.
110+
Returns
111+
-------
112+
func : np.ndarray of shape (N,)
113+
Function values.
114+
deriv : np.ndarray of shape (N,3)
115+
Derivative vectors.
116+
"""
117+
return self.conics_vector(points)
118+
104119
def conics(self, point: np.ndarray) -> Tuple[float, List[float]]:
105-
"""Returns function value and derivitive list for conics and sphere surfaces.
120+
"""Returns function value and derivative list for conics and sphere surfaces.
106121
107122
Note
108123
----
@@ -122,25 +137,58 @@ def conics(self, point: np.ndarray) -> Tuple[float, List[float]]:
122137
function : float
123138
Function value of the surface. A value of 0 corresponds to the ray intersecting
124139
the surface.
125-
derivitive : list of length 3
126-
Function derivitive of the surface at the point given. Used to determine which
127-
direction the ray needs to travel in and the step size to intersect the surface.
128-
140+
derivative : list of length 3
141+
Derivative of the surface at the point given.
129142
"""
130-
131-
X,Y,Z = point
132-
rho = np.sqrt(pow(X,2) + pow(Y, 2))
133-
if rho > self.Diam/2. or rho < self.diam/2.:
143+
X, Y, Z = point
144+
rho = np.sqrt(pow(X, 2) + pow(Y, 2))
145+
if rho > self.Diam / 2. or rho < self.diam / 2.:
134146
raise NotOnSurfaceError()
135-
# Ensure kappa is not None before using it in calculations
136147
if self.kappa is None:
137148
raise ValueError("kappa must not be None for conic calculations")
138-
#Conic equation.
139-
function = Z - self.c*pow(rho, 2)/(1 + pow((1-self.kappa*pow(self.c, 2)*pow(rho,2)), 0.5))
140-
#See Spencer, Murty section on rotational surfaces for definition of E.
141-
E = self.c / pow((1-self.kappa*pow(self.c, 2)*pow(rho,2)), 0.5)
142-
derivitive = [-X*E, -Y*E, 1.]
143-
return function, derivitive
149+
function = Z - self.c * pow(rho, 2) / (1 + pow((1 - self.kappa * pow(self.c, 2) * pow(rho, 2)), 0.5))
150+
E = self.c / pow((1 - self.kappa * pow(self.c, 2) * pow(rho, 2)), 0.5)
151+
derivative = [-X * E, -Y * E, 1.]
152+
return function, derivative
153+
154+
def conics_vector(self, points: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
155+
"""Returns function values and derivative vectors for conics and sphere surfaces.
156+
157+
Vectorized version for an array of points.
158+
159+
Parameters
160+
----------
161+
points : np.ndarray of shape (N,3)
162+
Array of points (X, Y, Z).
163+
164+
Returns
165+
-------
166+
func : np.ndarray of shape (N,)
167+
Function values for each point.
168+
deriv : np.ndarray of shape (N,3)
169+
Derivative vectors for each point.
170+
"""
171+
X = points[:, 0]
172+
Y = points[:, 1]
173+
Z = points[:, 2]
174+
rho = np.sqrt(X**2 + Y**2)
175+
# Initialize outputs
176+
func = np.full(points.shape[0], np.nan)
177+
deriv = np.full((points.shape[0], 3), np.nan)
178+
# Determine valid indices based on aperture
179+
valid = (rho <= self.Diam / 2.) & (rho >= self.diam / 2.)
180+
if self.kappa is None:
181+
raise ValueError("kappa must not be None for conic calculations")
182+
# Calculate for valid points
183+
rho_valid = rho[valid]
184+
sqrt_term = np.sqrt(1 - self.kappa * self.c**2 * rho_valid**2)
185+
denom = 1 + sqrt_term
186+
func[valid] = Z[valid] - self.c * rho_valid**2 / denom
187+
E = self.c / sqrt_term
188+
deriv[valid, 0] = -X[valid] * E
189+
deriv[valid, 1] = -Y[valid] * E
190+
deriv[valid, 2] = 1.
191+
return func, deriv
144192

145193
def conics_plot(self, point: np.ndarray) -> np.ndarray:
146194
"""Returns Z values for an array of points for plotting conics.
@@ -154,18 +202,14 @@ def conics_plot(self, point: np.ndarray) -> np.ndarray:
154202
-------
155203
function : 1d np.array
156204
np.array of Z values for each X,Y pair.
157-
158205
"""
159-
160-
X, Y = point[:,0], point[:,1]
161-
rho = np.sqrt(pow(X,2) + pow(Y,2))
162-
#Initialize Z value array
206+
X, Y = point[:, 0], point[:, 1]
207+
rho = np.sqrt(pow(X, 2) + pow(Y, 2))
163208
function = np.zeros(len(point))
164-
nan_idx = (rho > self.Diam/2.) + (rho < self.diam/2.)
165-
rho = np.sqrt(pow(X[~nan_idx],2) + pow(Y[~nan_idx], 2))
209+
nan_idx = (rho > self.Diam / 2.) + (rho < self.diam / 2.)
210+
rho = np.sqrt(pow(X[~nan_idx], 2) + pow(Y[~nan_idx], 2))
166211
function[nan_idx] = np.nan
167-
# Ensure kappa is not None before using it in calculations
168212
if self.kappa is None:
169213
raise ValueError("kappa must not be None for conic plot calculations")
170-
function[~nan_idx] = self.c*pow(rho, 2)/(1 + pow((1-self.kappa*pow(self.c, 2)*pow(rho,2)), 0.5))
214+
function[~nan_idx] = self.c * pow(rho, 2) / (1 + pow((1 - self.kappa * pow(self.c, 2) * pow(rho, 2)), 0.5))
171215
return function

0 commit comments

Comments
 (0)