Skip to content

Commit 47973e0

Browse files
committed
remove DiameterBandSurface
1 parent fa8330b commit 47973e0

File tree

5 files changed

+150
-115
lines changed

5 files changed

+150
-115
lines changed

src/torchlensmaker/core/collision_detection.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,8 @@ def __call__(
208208
history: bool = False,
209209
) -> CollisionDetectionResult:
210210
"""
211-
tmin and tmax must be such that all points within P+tminV and P+tmaxV are within the valid domain of the implicit surface F function
211+
tmin and tmax must be such that all points within P+tminV and P+tmaxV
212+
are within the valid domain of the implicit surface F function
212213
213214
TODO check we are safe for any ordering of tmin, tmax, i.e. if tmax < tmin
214215
"""
@@ -223,9 +224,7 @@ def __call__(
223224
assert tmin.dim() == tmax.dim() == 1
224225

225226
# Tensor dimensions
226-
N = P.shape[0] # Number of rays
227-
D = P.shape[1] # Rays dimension (2 or 3)
228-
B = self.B
227+
(N, _), B = P.shape, self.B
229228

230229
# Initialize solutions t
231230
t_sample = torch.linspace(0., 1., B, dtype=P.dtype)
@@ -240,15 +239,19 @@ def __call__(
240239
history_coarse = torch.zeros((B, N, self.num_iterA), dtype=surface.dtype)
241240
history_fine = torch.zeros((N, self.num_iterB), dtype=surface.dtype)
242241

243-
br = surface.bounding_radius()
242+
max_delta = torch.abs(tmax - tmin) / B
244243

245244
with torch.no_grad():
246245
# Iteration tensor t
247246
t = init_t
248247

249248
# Coarse phase (multiple beams)
250249
for ia in range(self.num_iterA):
251-
t = t - self.algoA.delta(surface, P, V, t, max_delta=br / B)
250+
t = torch.clamp(
251+
t - self.algoA.delta(surface, P, V, t, max_delta),
252+
min=tmin,
253+
max=tmax,
254+
)
252255
if history:
253256
history_coarse[:, :, ia] = t.clone()
254257

@@ -281,14 +284,20 @@ def __call__(
281284

282285
# Fine phase (single beam)
283286
for ib in range(self.num_iterB):
284-
t = t - self.algoB.delta(
285-
surface, P, V, t, max_delta=br / (B * self.num_iterA)
287+
t = torch.clamp(
288+
t - self.algoB.delta(surface, P, V, t, max_delta),
289+
min=tmin,
290+
max=tmax,
286291
)
287292
if history:
288293
history_fine[:, ib] = t
289294

290295
# Differentiable phase: one iteration for backwards pass
291-
t = t - self.algoC.delta(surface, P, V, t, max_delta=br)
296+
t = torch.clamp(
297+
t - self.algoC.delta(surface, P, V, t, max_delta=max_delta),
298+
min=tmin,
299+
max=tmax,
300+
)
292301

293302
return CollisionDetectionResult(
294303
t[0, :],

src/torchlensmaker/core/surfaces.py

Lines changed: 34 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -228,19 +228,6 @@ def __init__(
228228
def bcyl(self) -> Tensor:
229229
raise NotImplementedError
230230

231-
def contains(self, points: Tensor, tol: Optional[float] = None) -> Tensor:
232-
if tol is None:
233-
tol = {torch.float32: 1e-4, torch.float64: 1e-7}[self.dtype]
234-
235-
# TODO if removing diameterbandsurface, check bounds/bcyl before calling F
236-
# because points might be out of g() domain
237-
238-
dim = points.shape[1]
239-
240-
F = self.F if dim == 3 else self.f
241-
242-
return torch.abs(F(points)) < tol
243-
244231
def rmse(self, points: Tensor) -> float:
245232
N = sum(points.shape[:-1])
246233
return torch.sqrt(torch.sum(self.Fd(points) ** 2) / N).item()
@@ -337,54 +324,6 @@ def F_grad(self, points: Tensor) -> Tensor:
337324
raise NotImplementedError
338325

339326

340-
class DiameterBandSurfaceSq(ImplicitSurface):
341-
"Distance to edge points"
342-
343-
def __init__(self, Ax: Tensor, Ar: Tensor, dtype: torch.dtype = torch.float64):
344-
super().__init__(dtype=dtype)
345-
self.Ax = Ax
346-
self.Ar = Ar
347-
348-
def f(self, points: Tensor) -> Tensor:
349-
assert points.shape[-1] == 2
350-
X, R = points.unbind(-1)
351-
Ax, Ar = self.Ax, self.Ar
352-
return torch.sqrt((X - Ax) ** 2 + (torch.abs(R) - Ar) ** 2)
353-
354-
def f_grad(self, points: Tensor) -> Tensor:
355-
assert points.shape[-1] == 2
356-
X, R = points.unbind(-1)
357-
Ax, Ar = self.Ax, self.Ar
358-
sq = self.f(points)
359-
return torch.stack(
360-
((X - Ax) / sq, torch.sign(R) * (torch.abs(R) - Ar) / sq), dim=-1
361-
)
362-
363-
def F(self, points: Tensor) -> Tensor:
364-
assert points.shape[-1] == 3
365-
X, Y, Z = points.unbind(-1)
366-
R2 = Y**2 + Z**2
367-
Ax, Ar = self.Ax, self.Ar
368-
return torch.sqrt((X - Ax) ** 2 + (torch.sqrt(R2) - Ar) ** 2)
369-
370-
def F_grad(self, points: Tensor) -> Tensor:
371-
assert points.shape[-1] == 3
372-
X, Y, Z = points.unbind(-1)
373-
R2 = Y**2 + Z**2
374-
Ax, Ar = self.Ax, self.Ar
375-
sq = self.F(points)
376-
sqr2 = torch.sqrt(R2)
377-
quot = (sqr2 - Ar) / (sqr2 * sq)
378-
return torch.stack(
379-
(
380-
(X - Ax) / sq,
381-
Y * quot,
382-
Z * quot,
383-
),
384-
dim=-1,
385-
)
386-
387-
388327
class SagSurface(ImplicitSurface):
389328
"""
390329
Axially symmetric implicit surface defined by a sag function.
@@ -416,16 +355,10 @@ def __init__(
416355
def mask_function(self, points: Tensor) -> Tensor:
417356
return within_radius(self.diameter / 2, points)
418357

419-
def fallback_surface(self) -> DiameterBandSurfaceSq:
420-
return DiameterBandSurfaceSq(
421-
Ax=self.extent_x(),
422-
Ar=torch.as_tensor(self.diameter / 2, dtype=self.dtype),
423-
dtype=self.dtype,
424-
)
425-
426358
def parameters(self) -> dict[str, nn.Parameter]:
427359
return self.sag_function.parameters()
428360

361+
# TODO remove?
429362
def bounding_radius(self) -> float:
430363
"""
431364
Any point on the surface has a distance to the center that is less
@@ -438,47 +371,28 @@ def tau(self) -> Tensor:
438371
return torch.as_tensor(self.diameter / 2, dtype=self.dtype)
439372

440373
def f(self, points: Tensor) -> Tensor:
374+
"points are assumed to be within the bcyl domain"
441375
assert points.shape[-1] == 2
442376
x, r = points.unbind(-1)
443-
sag_f = self.sag_function.g(r, self.tau()) - x
444-
mask = self.mask_function(points)
445-
fallback = self.fallback_surface()
446-
return torch.where(mask, sag_f, fallback.f(points))
377+
return self.sag_function.g(r, self.tau()) - x
447378

448379
def f_grad(self, points: Tensor) -> Tensor:
449380
assert points.shape[-1] == 2
450381
x, r = points.unbind(-1)
451-
sag_f_grad = torch.stack(
382+
return torch.stack(
452383
(-torch.ones_like(x), self.sag_function.g_grad(r, self.tau())), dim=-1
453384
)
454-
mask = self.mask_function(points)
455-
fallback = self.fallback_surface()
456-
return torch.where(
457-
mask.unsqueeze(-1).expand(*mask.size(), 2),
458-
sag_f_grad,
459-
fallback.f_grad(points),
460-
)
461385

462386
def F(self, points: Tensor) -> Tensor:
463387
assert points.shape[-1] == 3
464388
x, y, z = points.unbind(-1)
465-
sag_F = self.sag_function.G(y, z, self.tau()) - x
466-
mask = self.mask_function(points)
467-
fallback = self.fallback_surface()
468-
return torch.where(mask, sag_F, fallback.F(points))
389+
return self.sag_function.G(y, z, self.tau()) - x
469390

470391
def F_grad(self, points: Tensor) -> Tensor:
471392
assert points.shape[-1] == 3
472393
x, y, z = points.unbind(-1)
473394
grad_y, grad_z = self.sag_function.G_grad(y, z, self.tau())
474-
sag_F_grad = torch.stack((-torch.ones_like(x), grad_y, grad_z), dim=-1)
475-
mask = self.mask_function(points)
476-
fallback = self.fallback_surface()
477-
return torch.where(
478-
mask.unsqueeze(-1).expand(*mask.size(), 3),
479-
sag_F_grad,
480-
fallback.F_grad(points),
481-
)
395+
return torch.stack((-torch.ones_like(x), grad_y, grad_z), dim=-1)
482396

483397
def extent_x(self) -> Tensor:
484398
return torch.max(torch.abs(self.sag_function.bounds(self.tau())))
@@ -497,6 +411,34 @@ def bcyl(self) -> Tensor:
497411
dim=0,
498412
)
499413

414+
def contains(self, points: Tensor, tol: Optional[float] = None) -> Tensor:
415+
if tol is None:
416+
tol = {torch.float32: 1e-4, torch.float64: 1e-7}[self.dtype]
417+
418+
N, dim = points.shape
419+
420+
# Check points are within the diameter
421+
r2 = points[:, 1] if dim == 2 else points[:, 1] ** 2 + points[:, 2] ** 2
422+
within_diameter = r2 <= self.diameter**2
423+
424+
tau = self.tau()
425+
zeros1d = torch.zeros_like(points[:, 1])
426+
zeros2d = torch.zeros_like(r2)
427+
428+
# If within diameter, check the sag equation x = g(r)
429+
if dim == 2:
430+
safe_input = torch.where(within_diameter, torch.sqrt(r2), zeros2d)
431+
sagG = self.sag_function.g(safe_input, tau)
432+
G = torch.where(within_diameter, sagG, zeros2d)
433+
else:
434+
safe_input_y = torch.where(within_diameter, points[:, 1], zeros1d)
435+
safe_input_z = torch.where(within_diameter, points[:, 2], zeros1d)
436+
sagG = self.sag_function.G(safe_input_y, safe_input_z, tau)
437+
G = torch.where(within_diameter, sagG, zeros2d)
438+
439+
within_tol = torch.abs(G - points[:, 0]) < tol
440+
return torch.logical_and(within_diameter, within_tol)
441+
500442
def samples2D_full(self, N, epsilon):
501443
start = -(1 - epsilon) * self.diameter / 2
502444
end = (1 - epsilon) * self.diameter / 2

src/torchlensmaker/optimize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def optimize(
112112

113113
if i % show_every == 0 or i == num_iter - 1:
114114
iter_str = f"[{i + 1:>3}/{num_iter}]"
115-
L_str = f"L= {loss.item():>6.3f} | grad norm= {torch.linalg.norm(grad)}"
115+
L_str = f"L= {loss.item():>6.5f} | grad norm= {torch.linalg.norm(grad)}"
116116
print(f"{iter_str} {L_str}")
117117

118118
return OptimizationRecord(num_iter, parameters_record, loss_record, optics)

test_notebooks/XYPolynomial.ipynb

Lines changed: 79 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@
8585
"\n",
8686
"optics = tlm.Sequential(\n",
8787
" tlm.PointSourceAtInfinity(6.0),\n",
88-
" tlm.Gap(0.5),\n",
88+
" tlm.Gap(5),\n",
8989
" *elements\n",
9090
")\n",
9191
"\n",
@@ -109,6 +109,7 @@
109109
"import matplotlib.pyplot as plt\n",
110110
"import json\n",
111111
"\n",
112+
"\n",
112113
"C = tlm.parameter(torch.tensor([[0.0, 0.0], [0.0, 0.0]], dtype=torch.float64))\n",
113114
"#surface = tlm.SagSurface(10, tlm.XYPolynomial(C, normalize=True))\n",
114115
"\n",
@@ -117,9 +118,69 @@
117118
"# nan in grad with XYPolynomial\n",
118119
"# nan in grad with Parabola\n",
119120
"\n",
120-
"# surface = tlm.Sphere(10, R=tlm.parameter(15))\n",
121+
"#surface = tlm.SphereR(10, R=tlm.parameter(15))\n",
121122
"surface = tlm.Parabola(10, tlm.parameter(0.02))\n",
122123
"\n",
124+
"x = tlm.parameter(50.)\n",
125+
"\n",
126+
"optics = tlm.Sequential(\n",
127+
" tlm.PointSourceAtInfinity(8.0),\n",
128+
" tlm.Gap(10),\n",
129+
" tlm.RefractiveSurface(surface, material=\"water-nd\"),\n",
130+
" tlm.Gap(x),\n",
131+
" tlm.FocalPoint()\n",
132+
")\n",
133+
"\n",
134+
"tlm.show3d(optics, sampling={\"base\": 100})\n",
135+
"\n",
136+
"### get gradient\n",
137+
"sampling = {\"base\": 100}\n",
138+
"default_input = tlm.default_input(sampling, dim=3, dtype=torch.float64)\n",
139+
"outputs = optics(default_input)\n",
140+
"loss = outputs.loss\n",
141+
"loss.backward()\n",
142+
"\n",
143+
"params = list(surface.parameters().values())\n",
144+
"\n",
145+
"print(surface.to_dict(3))\n",
146+
"for param in params:\n",
147+
" print(param, param.grad)\n",
148+
" print()\n",
149+
"\n",
150+
"print(x, x.grad)\n",
151+
"\n",
152+
"\n",
153+
"tlm.show3d(optics, sampling={\"base\": 100})\n"
154+
]
155+
},
156+
{
157+
"cell_type": "code",
158+
"execution_count": null,
159+
"id": "b49cfbb7-81f8-42c9-a53a-446d8725ae0b",
160+
"metadata": {},
161+
"outputs": [],
162+
"source": [
163+
"## XYPolynomial optimization\n",
164+
"\n",
165+
"import torchlensmaker as tlm\n",
166+
"import torch\n",
167+
"import torch.nn as nn\n",
168+
"import torch.optim as optim\n",
169+
"import numpy as np\n",
170+
"import matplotlib.pyplot as plt\n",
171+
"import json\n",
172+
"\n",
173+
"\n",
174+
"C = tlm.parameter(torch.zeros((13,13), dtype=torch.float64))\n",
175+
"surface = tlm.SagSurface(10, tlm.XYPolynomial(C, normalize=True))\n",
176+
"\n",
177+
"fixed_mask = torch.zeros_like(C, dtype=torch.bool)\n",
178+
"fixed_mask[0, 0] = True # Freeze position (0,0)\n",
179+
"\n",
180+
"C.register_hook(lambda grad: grad.masked_fill(fixed_mask, 0.))\n",
181+
"\n",
182+
"surface = tlm.Sphere(10, R=tlm.parameter(15))\n",
183+
"\n",
123184
"optics = tlm.Sequential(\n",
124185
" tlm.PointSourceAtInfinity(8.0),\n",
125186
" tlm.Gap(10),\n",
@@ -132,19 +193,30 @@
132193
"\n",
133194
"tlm.optimize(\n",
134195
" optics,\n",
135-
" optimizer = optim.Adam(optics.parameters(), lr=.1),\n",
136-
" sampling = {\"base\": 100},\n",
196+
" optimizer = tlm.optim.Adam(optics.parameters(), lr=2e-3),\n",
197+
" sampling = {\"base\": 10},\n",
137198
" dim = 3,\n",
138-
" num_iter = 100\n",
199+
" num_iter = 50\n",
139200
").plot()\n",
140201
"\n",
141-
"tlm.show3d(optics, sampling={\"base\": 100})\n"
202+
"tlm.optimize(\n",
203+
" optics,\n",
204+
" optimizer = tlm.optim.Adam(optics.parameters(), lr=1e-4),\n",
205+
" sampling = {\"base\": 10},\n",
206+
" dim = 3,\n",
207+
" num_iter = 500\n",
208+
").plot()\n",
209+
"\n",
210+
"print(surface.parameters())\n",
211+
"\n",
212+
"tlm.show3d(optics, sampling={\"base\": 100})\n",
213+
"\n"
142214
]
143215
},
144216
{
145217
"cell_type": "code",
146218
"execution_count": null,
147-
"id": "b49cfbb7-81f8-42c9-a53a-446d8725ae0b",
219+
"id": "cd2555b0-878c-445d-b58e-9bee05bb112c",
148220
"metadata": {},
149221
"outputs": [],
150222
"source": []

0 commit comments

Comments
 (0)