@@ -237,30 +237,18 @@ def forward(self, data: OpticalData) -> tuple[OpticalData, Tensor]:
237237 # Zero rays special case
238238 # (needs to happen after self.collision_surface is called to enable rendering of it)
239239 if data .P .numel () == 0 :
240- return data .replace (
241- material = self .material , dfk = new_dfk , ifk = new_ifk
242- ), torch .full ((data .P .shape [0 ],), True )
243-
244- # Compute indices of refraction (scalars for non dispersive materials,
245- # tensors for dispersive materials)
246- if (
247- data .rays_wavelength is None
248- ): # TODO this should be empty but not None when number of rays is zero
249- if not isinstance (data .material , NonDispersiveMaterial ) or not isinstance (
250- self .material , NonDispersiveMaterial
251- ):
252- raise RuntimeError (
253- f"Cannot compute refraction with dispersive material "
254- f"because optical data has no wavelength variable "
255- f"(got materials { data .material } and { self .material } )"
256- )
257-
258- n1 = torch .as_tensor (data .material .n )
259- n2 = torch .as_tensor (self .material .n )
240+ return data .replace (dfk = new_dfk , ifk = new_ifk ), torch .full (
241+ (data .P .shape [0 ],), True
242+ )
260243
261- else :
262- n1 = data .material .refractive_index (data .rays_wavelength )
263- n2 = self .material .refractive_index (data .rays_wavelength )
244+ # Compute indices of refraction
245+ n1 = data .rays_index
246+ n2 = self .material .refractive_index (data .rays_wavelength )
247+ assert n1 .shape == n2 .shape == (data .P .shape [0 ],), (
248+ n1 .shape ,
249+ n2 .shape ,
250+ data .P .shape ,
251+ )
264252
265253 # Snell's law happens here
266254 # Compute refraction on the full frame rays (including non-colliding
@@ -282,6 +270,7 @@ def forward(self, data: OpticalData) -> tuple[OpticalData, Tensor]:
282270 new_rays_wavelength = filter_optional_tensor (
283271 data .rays_wavelength , both_valid
284272 )
273+ new_rays_index = filter_optional_tensor (n2 , both_valid )
285274 else :
286275 # keep tir rays
287276 new_P = collision_points [valid_collision ]
@@ -291,14 +280,15 @@ def forward(self, data: OpticalData) -> tuple[OpticalData, Tensor]:
291280 new_rays_wavelength = filter_optional_tensor (
292281 data .rays_wavelength , valid_collision
293282 )
283+ new_rays_index = filter_optional_tensor (n2 , valid_collision )
294284
295285 return data .replace (
296286 P = new_P ,
297287 V = new_V ,
298288 rays_base = new_rays_base ,
299289 rays_object = new_rays_object ,
300290 rays_wavelength = new_rays_wavelength ,
301- material = self . material ,
291+ rays_index = new_rays_index ,
302292 dfk = new_dfk ,
303293 ifk = new_ifk ,
304294 ), valid_refraction
@@ -328,6 +318,9 @@ def forward(self, data: OpticalData) -> OpticalData:
328318 rays_wavelength = filter_optional_tensor (
329319 data .rays_wavelength , valid_collision
330320 ),
321+ rays_index = filter_optional_tensor (
322+ data .rays_index , valid_collision
323+ ),
331324 dfk = new_dfk ,
332325 ifk = new_ifk , # correct but useless cause Aperture is only circular plane currently
333326 )
0 commit comments