@@ -60,17 +60,17 @@ struct _svd_config<double> {
6060 static constexpr int JACOBI_ITERATIONS = 8 ;
6161};
6262
63-
64-
65- // TODO: replace sqrt with rsqrt
66-
67- template <typename Type>
68- inline CUDA_CALLABLE
69- Type accurateSqrt (Type x)
63+ template <typename Type> inline CUDA_CALLABLE Type recipSqrt (Type x)
7064{
71- return x / sqrt (x);
65+ #if defined(__CUDA_ARCH__)
66+ return ::rsqrt (x);
67+ #else
68+ return Type (1 ) / sqrt (x);
69+ #endif
7270}
7371
72+ template <> inline CUDA_CALLABLE wp::half recipSqrt (wp::half x) { return wp::half (1 ) / sqrt (x); }
73+
7474template <typename Type>
7575inline CUDA_CALLABLE
7676void condSwap (bool c, Type &X, Type &Y)
@@ -175,7 +175,7 @@ void approximateGivensQuaternion(Type a11, Type a12, Type a22, Type &ch, Type &s
175175 ch = Type (2 )*(a11-a22);
176176 sh = a12;
177177 bool b = Type (_gamma)*sh*sh < ch*ch;
178- Type w = Type ( 1 ) / sqrt (ch*ch+sh*sh);
178+ Type w = recipSqrt (ch*ch+sh*sh);
179179 ch=b?w*ch:Type (_cstar);
180180 sh=b?w*sh:Type (_sstar);
181181}
@@ -304,13 +304,13 @@ void QRGivensQuaternion(Type a1, Type a2, Type &ch, Type &sh)
304304 // a1 = pivot point on diagonal
305305 // a2 = lower triangular entry we want to annihilate
306306 const Type epsilon = _svd_config<Type>::QR_GIVENS_EPSILON ;
307- Type rho = accurateSqrt (a1*a1 + a2*a2);
307+ Type rho = sqrt (a1*a1 + a2*a2);
308308
309309 sh = rho > epsilon ? a2 : Type (0 );
310310 ch = abs (a1) + max (rho,epsilon);
311311 bool b = a1 < Type (0 );
312312 condSwap (b,sh,ch);
313- Type w = Type ( 1 ) / sqrt (ch*ch+sh*sh);
313+ Type w = recipSqrt (ch*ch+sh*sh);
314314 ch *= w;
315315 sh *= w;
316316}
@@ -432,21 +432,15 @@ void _svd(// input A
432432 );
433433}
434434
435-
436- template <typename Type>
437- inline CUDA_CALLABLE
438- void _svd_2 (// input A
439- Type a11, Type a12,
440- Type a21, Type a22,
441- // output U
442- Type &u11, Type &u12,
443- Type &u21, Type &u22,
444- // output S
445- Type &s11, Type &s12,
446- Type &s21, Type &s22,
447- // output V
448- Type &v11, Type &v12,
449- Type &v21, Type &v22)
435+ template <typename Type>
436+ inline CUDA_CALLABLE void _svd_2 ( // input A
437+ Type a11, Type a12, Type a21, Type a22,
438+ // output U
439+ Type& u11, Type& u12, Type& u21, Type& u22,
440+ // output S
441+ Type& s1, Type& s2,
442+ // output V
443+ Type& v11, Type& v12, Type& v21, Type& v22)
450444{
451445 // Step 1: Compute ATA
452446 Type ATA11 = a11 * a11 + a21 * a21;
@@ -455,39 +449,56 @@ void _svd_2(// input A
455449
456450 // Step 2: Eigenanalysis
457451 Type trace = ATA11 + ATA22 ;
458- Type det = ATA11 * ATA22 - ATA12 * ATA12 ;
459- Type sqrt_term = sqrt (trace * trace - Type (4.0 ) * det);
460- Type lambda1 = (trace + sqrt_term) * Type (0.5 );
461- Type lambda2 = (trace - sqrt_term) * Type (0.5 );
452+ Type diff = ATA11 - ATA22 ;
453+ Type discriminant = diff * diff + Type (4 ) * ATA12 * ATA12 ;
462454
463455 // Step 3: Singular values
464- Type sigma1 = sqrt (lambda1);
456+ if (discriminant == Type (0 ))
457+ {
458+ // Duplicate eigenvalue, A ~ s Id
459+ s1 = s2 = sqrt (Type (0.5 ) * trace);
460+ u11 = v11 = Type (1 );
461+ u12 = v12 = Type (0 );
462+ u21 = v21 = Type (0 );
463+ u22 = v22 = Type (1 );
464+ return ;
465+ }
466+
467+ // General case
468+ Type sqrt_term = sqrt (discriminant);
469+ Type lambda1 = (trace + sqrt_term) * Type (0.5 );
470+ Type lambda2 = (trace - sqrt_term) * Type (0.5 );
471+ Type inv_sigma1 = recipSqrt (lambda1);
472+ Type sigma1 = Type (1 ) / inv_sigma1;
465473 Type sigma2 = sqrt (lambda2);
466474
467475 // Step 4: Eigenvectors (find V)
468- Type v1x = ATA12 , v1y = lambda1 - ATA11 ; // For first eigenvector
469- Type v2x = ATA12 , v2y = lambda2 - ATA11 ; // For second eigenvector
470- Type norm1 = sqrt (v1x * v1x + v1y * v1y);
471- Type norm2 = sqrt (v2x * v2x + v2y * v2y);
472-
473- v11 = v1x / norm1; v12 = v2x / norm2;
474- v21 = v1y / norm1; v22 = v2y / norm2;
476+ Type v1y = diff - sqrt_term + Type (2 ) * ATA12 , v1x = diff + sqrt_term - Type (2 ) * ATA12 ;
477+ Type len1_sq = v1x * v1x + v1y * v1y;
478+ if (len1_sq == Type (0 )) {
479+ v11 = Type (0.707106781186547524401 ); // M_SQRT1_2
480+ v21 = v11;
481+ } else {
482+ Type inv_len1 = recipSqrt (len1_sq);
483+ v11 = v1x * inv_len1;
484+ v21 = v1y * inv_len1;
485+ }
486+ v12 = -v21;
487+ v22 = v11;
475488
476489 // Step 5: Compute U
477- Type inv_sigma1 = (sigma1 > Type (1e-6 )) ? Type (1.0 ) / sigma1 : Type (0.0 );
478- Type inv_sigma2 = (sigma2 > Type (1e-6 )) ? Type (1.0 ) / sigma2 : Type (0.0 );
479-
480490 u11 = (a11 * v11 + a12 * v21) * inv_sigma1;
481- u12 = (a11 * v12 + a12 * v22) * inv_sigma2;
482491 u21 = (a21 * v11 + a22 * v21) * inv_sigma1;
483- u22 = (a21 * v12 + a22 * v22) * inv_sigma2;
492+ // sigma2 may be zero, but we can complete U orthogonally up to determinant's sign
493+ Type det_sign = wp::sign (a11 * a22 - a12 * a21);
494+ u12 = -u21 * det_sign;
495+ u22 = u11 * det_sign;
484496
485497 // Step 6: Set S
486- s11 = sigma1; s12 = Type ( 0.0 ) ;
487- s21 = Type ( 0.0 ); s22 = sigma2;
498+ s1 = sigma1;
499+ s2 = sigma2;
488500}
489501
490-
491502template <typename Type>
492503inline CUDA_CALLABLE void svd3 (const mat_t <3 ,3 ,Type>& A, mat_t <3 ,3 ,Type>& U, vec_t <3 ,Type>& sigma, mat_t <3 ,3 ,Type>& V) {
493504 Type s12, s13, s21, s23, s31, s32;
@@ -550,15 +561,14 @@ inline CUDA_CALLABLE void adj_svd3(const mat_t<3,3,Type>& A,
550561
551562template <typename Type>
552563inline CUDA_CALLABLE void svd2 (const mat_t <2 ,2 ,Type>& A, mat_t <2 ,2 ,Type>& U, vec_t <2 ,Type>& sigma, mat_t <2 ,2 ,Type>& V) {
553- Type s12, s21;
554564 _svd_2 (A.data [0 ][0 ], A.data [0 ][1 ],
555565 A.data [1 ][0 ], A.data [1 ][1 ],
556566
557567 U.data [0 ][0 ], U.data [0 ][1 ],
558568 U.data [1 ][0 ], U.data [1 ][1 ],
559569
560- sigma[0 ], s12,
561- s21, sigma[1 ],
570+ sigma[0 ],
571+ sigma[1 ],
562572
563573 V.data [0 ][0 ], V.data [0 ][1 ],
564574 V.data [1 ][0 ], V.data [1 ][1 ]);
0 commit comments