@@ -228,19 +228,6 @@ def __init__(
228228 def bcyl (self ) -> Tensor :
229229 raise NotImplementedError
230230
231- def contains (self , points : Tensor , tol : Optional [float ] = None ) -> Tensor :
232- if tol is None :
233- tol = {torch .float32 : 1e-4 , torch .float64 : 1e-7 }[self .dtype ]
234-
235- # TODO if removing diameterbandsurface, check bounds/bcyl before calling F
236- # because points might be out of g() domain
237-
238- dim = points .shape [1 ]
239-
240- F = self .F if dim == 3 else self .f
241-
242- return torch .abs (F (points )) < tol
243-
244231 def rmse (self , points : Tensor ) -> float :
245232 N = sum (points .shape [:- 1 ])
246233 return torch .sqrt (torch .sum (self .Fd (points ) ** 2 ) / N ).item ()
@@ -337,54 +324,6 @@ def F_grad(self, points: Tensor) -> Tensor:
337324 raise NotImplementedError
338325
339326
340- class DiameterBandSurfaceSq (ImplicitSurface ):
341- "Distance to edge points"
342-
343- def __init__ (self , Ax : Tensor , Ar : Tensor , dtype : torch .dtype = torch .float64 ):
344- super ().__init__ (dtype = dtype )
345- self .Ax = Ax
346- self .Ar = Ar
347-
348- def f (self , points : Tensor ) -> Tensor :
349- assert points .shape [- 1 ] == 2
350- X , R = points .unbind (- 1 )
351- Ax , Ar = self .Ax , self .Ar
352- return torch .sqrt ((X - Ax ) ** 2 + (torch .abs (R ) - Ar ) ** 2 )
353-
354- def f_grad (self , points : Tensor ) -> Tensor :
355- assert points .shape [- 1 ] == 2
356- X , R = points .unbind (- 1 )
357- Ax , Ar = self .Ax , self .Ar
358- sq = self .f (points )
359- return torch .stack (
360- ((X - Ax ) / sq , torch .sign (R ) * (torch .abs (R ) - Ar ) / sq ), dim = - 1
361- )
362-
363- def F (self , points : Tensor ) -> Tensor :
364- assert points .shape [- 1 ] == 3
365- X , Y , Z = points .unbind (- 1 )
366- R2 = Y ** 2 + Z ** 2
367- Ax , Ar = self .Ax , self .Ar
368- return torch .sqrt ((X - Ax ) ** 2 + (torch .sqrt (R2 ) - Ar ) ** 2 )
369-
370- def F_grad (self , points : Tensor ) -> Tensor :
371- assert points .shape [- 1 ] == 3
372- X , Y , Z = points .unbind (- 1 )
373- R2 = Y ** 2 + Z ** 2
374- Ax , Ar = self .Ax , self .Ar
375- sq = self .F (points )
376- sqr2 = torch .sqrt (R2 )
377- quot = (sqr2 - Ar ) / (sqr2 * sq )
378- return torch .stack (
379- (
380- (X - Ax ) / sq ,
381- Y * quot ,
382- Z * quot ,
383- ),
384- dim = - 1 ,
385- )
386-
387-
388327class SagSurface (ImplicitSurface ):
389328 """
390329 Axially symmetric implicit surface defined by a sag function.
@@ -416,16 +355,10 @@ def __init__(
416355 def mask_function (self , points : Tensor ) -> Tensor :
417356 return within_radius (self .diameter / 2 , points )
418357
419- def fallback_surface (self ) -> DiameterBandSurfaceSq :
420- return DiameterBandSurfaceSq (
421- Ax = self .extent_x (),
422- Ar = torch .as_tensor (self .diameter / 2 , dtype = self .dtype ),
423- dtype = self .dtype ,
424- )
425-
426358 def parameters (self ) -> dict [str , nn .Parameter ]:
427359 return self .sag_function .parameters ()
428360
361+ # TODO remove?
429362 def bounding_radius (self ) -> float :
430363 """
431364 Any point on the surface has a distance to the center that is less
@@ -438,47 +371,28 @@ def tau(self) -> Tensor:
438371 return torch .as_tensor (self .diameter / 2 , dtype = self .dtype )
439372
440373 def f (self , points : Tensor ) -> Tensor :
374+ "points are assumed to be within the bcyl domain"
441375 assert points .shape [- 1 ] == 2
442376 x , r = points .unbind (- 1 )
443- sag_f = self .sag_function .g (r , self .tau ()) - x
444- mask = self .mask_function (points )
445- fallback = self .fallback_surface ()
446- return torch .where (mask , sag_f , fallback .f (points ))
377+ return self .sag_function .g (r , self .tau ()) - x
447378
448379 def f_grad (self , points : Tensor ) -> Tensor :
449380 assert points .shape [- 1 ] == 2
450381 x , r = points .unbind (- 1 )
451- sag_f_grad = torch .stack (
382+ return torch .stack (
452383 (- torch .ones_like (x ), self .sag_function .g_grad (r , self .tau ())), dim = - 1
453384 )
454- mask = self .mask_function (points )
455- fallback = self .fallback_surface ()
456- return torch .where (
457- mask .unsqueeze (- 1 ).expand (* mask .size (), 2 ),
458- sag_f_grad ,
459- fallback .f_grad (points ),
460- )
461385
462386 def F (self , points : Tensor ) -> Tensor :
463387 assert points .shape [- 1 ] == 3
464388 x , y , z = points .unbind (- 1 )
465- sag_F = self .sag_function .G (y , z , self .tau ()) - x
466- mask = self .mask_function (points )
467- fallback = self .fallback_surface ()
468- return torch .where (mask , sag_F , fallback .F (points ))
389+ return self .sag_function .G (y , z , self .tau ()) - x
469390
470391 def F_grad (self , points : Tensor ) -> Tensor :
471392 assert points .shape [- 1 ] == 3
472393 x , y , z = points .unbind (- 1 )
473394 grad_y , grad_z = self .sag_function .G_grad (y , z , self .tau ())
474- sag_F_grad = torch .stack ((- torch .ones_like (x ), grad_y , grad_z ), dim = - 1 )
475- mask = self .mask_function (points )
476- fallback = self .fallback_surface ()
477- return torch .where (
478- mask .unsqueeze (- 1 ).expand (* mask .size (), 3 ),
479- sag_F_grad ,
480- fallback .F_grad (points ),
481- )
395+ return torch .stack ((- torch .ones_like (x ), grad_y , grad_z ), dim = - 1 )
482396
483397 def extent_x (self ) -> Tensor :
484398 return torch .max (torch .abs (self .sag_function .bounds (self .tau ())))
@@ -497,6 +411,34 @@ def bcyl(self) -> Tensor:
497411 dim = 0 ,
498412 )
499413
414+ def contains (self , points : Tensor , tol : Optional [float ] = None ) -> Tensor :
415+ if tol is None :
416+ tol = {torch .float32 : 1e-4 , torch .float64 : 1e-7 }[self .dtype ]
417+
418+ N , dim = points .shape
419+
420+ # Check points are within the diameter
421+ r2 = points [:, 1 ] if dim == 2 else points [:, 1 ] ** 2 + points [:, 2 ] ** 2
422+ within_diameter = r2 <= self .diameter ** 2
423+
424+ tau = self .tau ()
425+ zeros1d = torch .zeros_like (points [:, 1 ])
426+ zeros2d = torch .zeros_like (r2 )
427+
428+ # If within diameter, check the sag equation x = g(r)
429+ if dim == 2 :
430+ safe_input = torch .where (within_diameter , torch .sqrt (r2 ), zeros2d )
431+ sagG = self .sag_function .g (safe_input , tau )
432+ G = torch .where (within_diameter , sagG , zeros2d )
433+ else :
434+ safe_input_y = torch .where (within_diameter , points [:, 1 ], zeros1d )
435+ safe_input_z = torch .where (within_diameter , points [:, 2 ], zeros1d )
436+ sagG = self .sag_function .G (safe_input_y , safe_input_z , tau )
437+ G = torch .where (within_diameter , sagG , zeros2d )
438+
439+ within_tol = torch .abs (G - points [:, 0 ]) < tol
440+ return torch .logical_and (within_diameter , within_tol )
441+
500442 def samples2D_full (self , N , epsilon ):
501443 start = - (1 - epsilon ) * self .diameter / 2
502444 end = (1 - epsilon ) * self .diameter / 2
0 commit comments