Skip to content

Commit d4d0e16

Browse files
committed
intersect: special case for zero rays
1 parent 145f55e commit d4d0e16

File tree

3 files changed

+17
-9
lines changed

3 files changed

+17
-9
lines changed

src/torchlensmaker/core/intersect.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ def intersect(
5050
assert P.shape[1] == V.shape[1]
5151
assert P.shape[1] in {2, 3}
5252

53+
# Special case for zero rays
54+
if P.shape[0] == 0:
55+
return torch.zeros_like(P), torch.zeros_like(V), torch.full((P.shape[0],), False)
56+
5357
# Convert rays to surface local frame
5458
Ps = transform.inverse_points(P)
5559
Vs = transform.inverse_vectors(V)

src/torchlensmaker/core/sag_functions.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,13 @@ def __init__(self, coefficients: Tensor, normalize: bool = False):
451451
self.p = torch.arange(self.coefficients.shape[0])
452452
self.q = torch.arange(self.coefficients.shape[1])
453453

454+
def parameters(self) -> dict[str, nn.Parameter]:
455+
return (
456+
{"coefficients": self.coefficients}
457+
if isinstance(self.coefficients, nn.Parameter)
458+
else {}
459+
)
460+
454461
def unnorm(self, tau: Tensor) -> Tensor:
455462
if self.normalize:
456463
taup = torch.pow(tau, self.p)
@@ -459,13 +466,10 @@ def unnorm(self, tau: Tensor) -> Tensor:
459466
return self.coefficients * tau / denom
460467
else:
461468
return self.coefficients
462-
463-
def parameters(self) -> dict[str, nn.Parameter]:
464-
return (
465-
{"coefficients": self.coefficients}
466-
if isinstance(self.coefficients, nn.Parameter)
467-
else {}
468-
)
469+
470+
def bounds(self, tau: Tensor) -> Tensor:
471+
# TODO
472+
return torch.tensor([-1., 1., 1.], dtype=self.coefficients.dtype)
469473

470474
def G(self, y: Tensor, z: Tensor, tau: Tensor) -> Tensor:
471475
C = bbroad(self.unnorm(tau), y.dim())

src/torchlensmaker/core/surfaces.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -571,8 +571,8 @@ def __init__(
571571

572572

573573
# TODO
574-
class Conic(SagSurface):
575-
...
574+
class Conic(SagSurface): ...
575+
576576

577577
class Asphere(SagSurface):
578578
"""

0 commit comments

Comments
 (0)