Skip to content

Commit 0cdfc9a

Browse files
committed
optical data: replace material by index
1 parent d4b7c98 commit 0cdfc9a

File tree

3 files changed

+30
-35
lines changed

3 files changed

+30
-35
lines changed

src/torchlensmaker/elements/light_sources.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,25 +92,25 @@ def forward(self, inputs: OpticalData) -> OpticalData:
9292
rays_object=cat_optional(inputs.rays_object, rays_object),
9393
var_base=var_base,
9494
var_object=var_object,
95-
material=self.material,
9695
)
9796

9897
# now sample wavelength
99-
100-
if non_chromatic.rays_wavelength is not None:
101-
raise RuntimeError("Rays already have wavelength data")
102-
10398
if "wavelength" not in non_chromatic.sampling:
10499
raise RuntimeError("Missing 'wavelength' key in sampling configuration")
105100

106101
chromatic_space = non_chromatic.sampling["wavelength"].sample1d(
107102
self.wavelength_lower, self.wavelength_upper, non_chromatic.dtype
108103
)
109104

110-
return cartesian_wavelength(non_chromatic, chromatic_space).replace(
105+
chromatic = cartesian_wavelength(non_chromatic, chromatic_space).replace(
111106
var_wavelength=chromatic_space
112107
)
113108

109+
# index of refraction
110+
rays_index = self.material.refractive_index(chromatic.rays_wavelength)
111+
112+
return chromatic.replace(rays_index=rays_index)
113+
114114

115115
class RaySource(LightSourceBase):
116116
"""

src/torchlensmaker/elements/optical_surfaces.py

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -237,30 +237,18 @@ def forward(self, data: OpticalData) -> tuple[OpticalData, Tensor]:
237237
# Zero rays special case
238238
# (needs to happen after self.collision_surface is called to enable rendering of it)
239239
if data.P.numel() == 0:
240-
return data.replace(
241-
material=self.material, dfk=new_dfk, ifk=new_ifk
242-
), torch.full((data.P.shape[0],), True)
243-
244-
# Compute indices of refraction (scalars for non dispersive materials,
245-
# tensors for dispersive materials)
246-
if (
247-
data.rays_wavelength is None
248-
): # TODO this should be empty but not None when number of rays is zero
249-
if not isinstance(data.material, NonDispersiveMaterial) or not isinstance(
250-
self.material, NonDispersiveMaterial
251-
):
252-
raise RuntimeError(
253-
f"Cannot compute refraction with dispersive material "
254-
f"because optical data has no wavelength variable "
255-
f"(got materials {data.material} and {self.material})"
256-
)
257-
258-
n1 = torch.as_tensor(data.material.n)
259-
n2 = torch.as_tensor(self.material.n)
240+
return data.replace(dfk=new_dfk, ifk=new_ifk), torch.full(
241+
(data.P.shape[0],), True
242+
)
260243

261-
else:
262-
n1 = data.material.refractive_index(data.rays_wavelength)
263-
n2 = self.material.refractive_index(data.rays_wavelength)
244+
# Compute indices of refraction
245+
n1 = data.rays_index
246+
n2 = self.material.refractive_index(data.rays_wavelength)
247+
assert n1.shape == n2.shape == (data.P.shape[0],), (
248+
n1.shape,
249+
n2.shape,
250+
data.P.shape,
251+
)
264252

265253
# Snell's law happens here
266254
# Compute refraction on the full frame rays (including non-colliding
@@ -282,6 +270,7 @@ def forward(self, data: OpticalData) -> tuple[OpticalData, Tensor]:
282270
new_rays_wavelength = filter_optional_tensor(
283271
data.rays_wavelength, both_valid
284272
)
273+
new_rays_index = filter_optional_tensor(n2, both_valid)
285274
else:
286275
# keep tir rays
287276
new_P = collision_points[valid_collision]
@@ -291,14 +280,15 @@ def forward(self, data: OpticalData) -> tuple[OpticalData, Tensor]:
291280
new_rays_wavelength = filter_optional_tensor(
292281
data.rays_wavelength, valid_collision
293282
)
283+
new_rays_index = filter_optional_tensor(n2, valid_collision)
294284

295285
return data.replace(
296286
P=new_P,
297287
V=new_V,
298288
rays_base=new_rays_base,
299289
rays_object=new_rays_object,
300290
rays_wavelength=new_rays_wavelength,
301-
material=self.material,
291+
rays_index=new_rays_index,
302292
dfk=new_dfk,
303293
ifk=new_ifk,
304294
), valid_refraction
@@ -328,6 +318,9 @@ def forward(self, data: OpticalData) -> OpticalData:
328318
rays_wavelength=filter_optional_tensor(
329319
data.rays_wavelength, valid_collision
330320
),
321+
rays_index=filter_optional_tensor(
322+
data.rays_index, valid_collision
323+
),
331324
dfk=new_dfk,
332325
ifk=new_ifk, # correct but useless cause Aperture is only circular plane currently
333326
)

src/torchlensmaker/optical_data.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,12 @@ class OpticalData:
5050
# Light rays in parametric form: P + tV
5151
P: Float[torch.Tensor, "N D"]
5252
V: Float[torch.Tensor, "N D"]
53+
54+
# Light rays wavelength in nm
5355
rays_wavelength: Float[torch.Tensor, " N"]
56+
57+
# Light rays index of refraction
58+
rays_index: Float[torch.Tensor, " N"]
5459

5560
# Rays variables
5661
# Tensors of shape (N, 2|3) or None
@@ -65,9 +70,6 @@ class OpticalData:
6570
var_object: Optional[torch.Tensor]
6671
var_wavelength: Optional[torch.Tensor]
6772

68-
# Material model for this batch of rays
69-
material: MaterialModel
70-
7173
# Loss accumulator
7274
# Tensor of dim 0
7375
loss: torch.Tensor
@@ -124,13 +126,13 @@ def default_input(
124126
ifk=ifk,
125127
P=torch.empty((0, dim), dtype=dtype),
126128
V=torch.empty((0, dim), dtype=dtype),
129+
rays_wavelength=torch.empty((0,), dtype=dtype),
130+
rays_index=torch.empty((0,), dtype=dtype),
127131
rays_base=None,
128132
rays_object=None,
129133
rays_image=None,
130-
rays_wavelength=None,
131134
var_base=None,
132135
var_object=None,
133136
var_wavelength=None,
134-
material=get_material_model("vacuum"),
135137
loss=torch.tensor(0.0, dtype=dtype),
136138
)

0 commit comments

Comments
 (0)