Skip to content

Commit 93e3dd1

Browse files
committed
Merge branch 'fix/GH-679-svd2' into 'main'
Fix GH-679: improve svd2 robustness and accuracy See merge request omniverse/warp!1285
2 parents cedd3b3 + 446a1b2 commit 93e3dd1

3 files changed

Lines changed: 106 additions & 65 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
- Fix 2D tile load when source array and tile have incompatible strides
5858
([GH-688](https://github.com/NVIDIA/warp/issues/688)).
5959
- Fixed inconsistency in orientation of 2D geometry side normals ([GH-629](https://github.com/NVIDIA/warp/issues/629)).
60+
- Fixed `wp.svd2()` with duplicate singular values and improved accuracy ([GH-679](https://github.com/NVIDIA/warp/issues/679)).
6061

6162
## [1.7.1] - 2025-04-30
6263

warp/native/svd.h

Lines changed: 59 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -60,17 +60,17 @@ struct _svd_config<double> {
6060
static constexpr int JACOBI_ITERATIONS = 8;
6161
};
6262

63-
64-
65-
// TODO: replace sqrt with rsqrt
66-
67-
template<typename Type>
68-
inline CUDA_CALLABLE
69-
Type accurateSqrt(Type x)
63+
template <typename Type> inline CUDA_CALLABLE Type recipSqrt(Type x)
7064
{
71-
return x / sqrt(x);
65+
#if defined(__CUDA_ARCH__)
66+
return ::rsqrt(x);
67+
#else
68+
return Type(1) / sqrt(x);
69+
#endif
7270
}
7371

72+
template <> inline CUDA_CALLABLE wp::half recipSqrt(wp::half x) { return wp::half(1) / sqrt(x); }
73+
7474
template<typename Type>
7575
inline CUDA_CALLABLE
7676
void condSwap(bool c, Type &X, Type &Y)
@@ -175,7 +175,7 @@ void approximateGivensQuaternion(Type a11, Type a12, Type a22, Type &ch, Type &s
175175
ch = Type(2)*(a11-a22);
176176
sh = a12;
177177
bool b = Type(_gamma)*sh*sh < ch*ch;
178-
Type w = Type(1) / sqrt(ch*ch+sh*sh);
178+
Type w = recipSqrt(ch*ch+sh*sh);
179179
ch=b?w*ch:Type(_cstar);
180180
sh=b?w*sh:Type(_sstar);
181181
}
@@ -304,13 +304,13 @@ void QRGivensQuaternion(Type a1, Type a2, Type &ch, Type &sh)
304304
// a1 = pivot point on diagonal
305305
// a2 = lower triangular entry we want to annihilate
306306
const Type epsilon = _svd_config<Type>::QR_GIVENS_EPSILON;
307-
Type rho = accurateSqrt(a1*a1 + a2*a2);
307+
Type rho = sqrt(a1*a1 + a2*a2);
308308

309309
sh = rho > epsilon ? a2 : Type(0);
310310
ch = abs(a1) + max(rho,epsilon);
311311
bool b = a1 < Type(0);
312312
condSwap(b,sh,ch);
313-
Type w = Type(1) / sqrt(ch*ch+sh*sh);
313+
Type w = recipSqrt(ch*ch+sh*sh);
314314
ch *= w;
315315
sh *= w;
316316
}
@@ -432,21 +432,15 @@ void _svd(// input A
432432
);
433433
}
434434

435-
436-
template<typename Type>
437-
inline CUDA_CALLABLE
438-
void _svd_2(// input A
439-
Type a11, Type a12,
440-
Type a21, Type a22,
441-
// output U
442-
Type &u11, Type &u12,
443-
Type &u21, Type &u22,
444-
// output S
445-
Type &s11, Type &s12,
446-
Type &s21, Type &s22,
447-
// output V
448-
Type &v11, Type &v12,
449-
Type &v21, Type &v22)
435+
template <typename Type>
436+
inline CUDA_CALLABLE void _svd_2( // input A
437+
Type a11, Type a12, Type a21, Type a22,
438+
// output U
439+
Type& u11, Type& u12, Type& u21, Type& u22,
440+
// output S
441+
Type& s1, Type& s2,
442+
// output V
443+
Type& v11, Type& v12, Type& v21, Type& v22)
450444
{
451445
// Step 1: Compute ATA
452446
Type ATA11 = a11 * a11 + a21 * a21;
@@ -455,39 +449,56 @@ void _svd_2(// input A
455449

456450
// Step 2: Eigenanalysis
457451
Type trace = ATA11 + ATA22;
458-
Type det = ATA11 * ATA22 - ATA12 * ATA12;
459-
Type sqrt_term = sqrt(trace * trace - Type(4.0) * det);
460-
Type lambda1 = (trace + sqrt_term) * Type(0.5);
461-
Type lambda2 = (trace - sqrt_term) * Type(0.5);
452+
Type diff = ATA11 - ATA22;
453+
Type discriminant = diff * diff + Type(4) * ATA12 * ATA12;
462454

463455
// Step 3: Singular values
464-
Type sigma1 = sqrt(lambda1);
456+
if (discriminant == Type(0))
457+
{
458+
// Duplicate eigenvalue, A ~ s Id
459+
s1 = s2 = sqrt(Type(0.5) * trace);
460+
u11 = v11 = Type(1);
461+
u12 = v12 = Type(0);
462+
u21 = v21 = Type(0);
463+
u22 = v22 = Type(1);
464+
return;
465+
}
466+
467+
// General case
468+
Type sqrt_term = sqrt(discriminant);
469+
Type lambda1 = (trace + sqrt_term) * Type(0.5);
470+
Type lambda2 = (trace - sqrt_term) * Type(0.5);
471+
Type inv_sigma1 = recipSqrt(lambda1);
472+
Type sigma1 = Type(1) / inv_sigma1;
465473
Type sigma2 = sqrt(lambda2);
466474

467475
// Step 4: Eigenvectors (find V)
468-
Type v1x = ATA12, v1y = lambda1 - ATA11; // For first eigenvector
469-
Type v2x = ATA12, v2y = lambda2 - ATA11; // For second eigenvector
470-
Type norm1 = sqrt(v1x * v1x + v1y * v1y);
471-
Type norm2 = sqrt(v2x * v2x + v2y * v2y);
472-
473-
v11 = v1x / norm1; v12 = v2x / norm2;
474-
v21 = v1y / norm1; v22 = v2y / norm2;
476+
Type v1y = diff - sqrt_term + Type(2) * ATA12, v1x = diff + sqrt_term - Type(2) * ATA12;
477+
Type len1_sq = v1x * v1x + v1y * v1y;
478+
if (len1_sq == Type(0)) {
479+
v11 = Type(0.707106781186547524401); // M_SQRT1_2
480+
v21 = v11;
481+
} else {
482+
Type inv_len1 = recipSqrt(len1_sq);
483+
v11 = v1x * inv_len1;
484+
v21 = v1y * inv_len1;
485+
}
486+
v12 = -v21;
487+
v22 = v11;
475488

476489
// Step 5: Compute U
477-
Type inv_sigma1 = (sigma1 > Type(1e-6)) ? Type(1.0) / sigma1 : Type(0.0);
478-
Type inv_sigma2 = (sigma2 > Type(1e-6)) ? Type(1.0) / sigma2 : Type(0.0);
479-
480490
u11 = (a11 * v11 + a12 * v21) * inv_sigma1;
481-
u12 = (a11 * v12 + a12 * v22) * inv_sigma2;
482491
u21 = (a21 * v11 + a22 * v21) * inv_sigma1;
483-
u22 = (a21 * v12 + a22 * v22) * inv_sigma2;
492+
// sigma2 may be zero, but we can complete U orthogonally up to determinant's sign
493+
Type det_sign = wp::sign(a11 * a22 - a12 * a21);
494+
u12 = -u21 * det_sign;
495+
u22 = u11 * det_sign;
484496

485497
// Step 6: Set S
486-
s11 = sigma1; s12 = Type(0.0);
487-
s21 = Type(0.0); s22 = sigma2;
498+
s1 = sigma1;
499+
s2 = sigma2;
488500
}
489501

490-
491502
template<typename Type>
492503
inline CUDA_CALLABLE void svd3(const mat_t<3,3,Type>& A, mat_t<3,3,Type>& U, vec_t<3,Type>& sigma, mat_t<3,3,Type>& V) {
493504
Type s12, s13, s21, s23, s31, s32;
@@ -550,15 +561,14 @@ inline CUDA_CALLABLE void adj_svd3(const mat_t<3,3,Type>& A,
550561

551562
template<typename Type>
552563
inline CUDA_CALLABLE void svd2(const mat_t<2,2,Type>& A, mat_t<2,2,Type>& U, vec_t<2,Type>& sigma, mat_t<2,2,Type>& V) {
553-
Type s12, s21;
554564
_svd_2(A.data[0][0], A.data[0][1],
555565
A.data[1][0], A.data[1][1],
556566

557567
U.data[0][0], U.data[0][1],
558568
U.data[1][0], U.data[1][1],
559569

560-
sigma[0], s12,
561-
s21, sigma[1],
570+
sigma[0],
571+
sigma[1],
562572

563573
V.data[0][0], V.data[0][1],
564574
V.data[1][0], V.data[1][1]);

warp/tests/test_mat.py

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1061,15 +1061,21 @@ def check_mat_svd2(
10611061
Vout: wp.array(dtype=mat22),
10621062
outcomponents: wp.array(dtype=wptype),
10631063
):
1064+
tid = wp.tid()
1065+
10641066
U = mat22()
10651067
sigma = vec2()
10661068
V = mat22()
10671069

1068-
wp.svd2(m2[0], U, sigma, V) # Assuming there's a 2D SVD kernel
1070+
wp.svd2(m2[tid], U, sigma, V) # Assuming there's a 2D SVD kernel
10691071

1070-
Uout[0] = U
1071-
sigmaout[0] = sigma
1072-
Vout[0] = V
1072+
Uout[tid] = U
1073+
sigmaout[tid] = sigma
1074+
Vout[tid] = V
1075+
1076+
# backprop test only for first input
1077+
if tid > 0:
1078+
return
10731079

10741080
# multiply outputs by 2 so we've got something to backpropagate:
10751081
idx = 0
@@ -1094,22 +1100,46 @@ def check_mat_svd2(
10941100
if register_kernels:
10951101
return
10961102

1097-
m2 = wp.array(randvals(rng, [1, 2, 2], dtype) + np.eye(2), dtype=mat22, requires_grad=True, device=device)
1103+
mats = np.concatenate(
1104+
(
1105+
randvals(rng, [24, 2, 2], dtype) + np.eye(2),
1106+
# rng unlikely to hit edge cases, build them manually
1107+
[
1108+
np.zeros((2, 2)),
1109+
np.eye(2),
1110+
5.0 * np.eye(2),
1111+
np.array([[1.0, 0.0], [0.0, 0.0]]),
1112+
np.array([[0.0, 0.0], [0.0, 2.0]]),
1113+
np.array([[1.0, 1.0], [-1.0, -1.0]]),
1114+
np.array([[3.0, 0.0], [4.0, 5.0]]),
1115+
np.eye(2) + tol * np.array([[1.0, 1.0], [-1.0, -1.0]]),
1116+
],
1117+
),
1118+
axis=0,
1119+
)
1120+
M = len(mats)
1121+
m2 = wp.array(mats, dtype=mat22, requires_grad=True, device=device)
10981122

10991123
outcomponents = wp.zeros(2 * 2 * 2 + 2, dtype=wptype, requires_grad=True, device=device)
1100-
Uout = wp.zeros(1, dtype=mat22, requires_grad=True, device=device)
1101-
sigmaout = wp.zeros(1, dtype=vec2, requires_grad=True, device=device)
1102-
Vout = wp.zeros(1, dtype=mat22, requires_grad=True, device=device)
1124+
Uout = wp.zeros(M, dtype=mat22, requires_grad=True, device=device)
1125+
sigmaout = wp.zeros(M, dtype=vec2, requires_grad=True, device=device)
1126+
Vout = wp.zeros(M, dtype=mat22, requires_grad=True, device=device)
11031127

1104-
wp.launch(kernel, dim=1, inputs=[m2], outputs=[Uout, sigmaout, Vout, outcomponents], device=device)
1128+
wp.launch(kernel, dim=M, inputs=[m2], outputs=[Uout, sigmaout, Vout, outcomponents], device=device)
11051129

1106-
Uout_np = Uout.numpy()[0].astype(np.float64)
1107-
sigmaout_np = np.diag(sigmaout.numpy()[0].astype(np.float64))
1108-
Vout_np = Vout.numpy()[0].astype(np.float64)
1130+
Uout_np = Uout.numpy().astype(np.float64)
1131+
sigmaout_np = sigmaout.numpy().astype(np.float64)
1132+
Vout_np = Vout.numpy().astype(np.float64)
1133+
1134+
USVt_np = Uout_np @ (sigmaout_np[..., None] * np.transpose(Vout_np, axes=(0, 2, 1)))
11091135

11101136
assert_np_equal(
1111-
np.matmul(Uout_np, np.matmul(sigmaout_np, Vout_np.T)), m2.numpy()[0].astype(np.float64), tol=30 * tol
1137+
Uout_np @ np.transpose(Uout_np, axes=(0, 2, 1)), np.broadcast_to(np.eye(2), shape=(M, 2, 2)), tol=30 * tol
11121138
)
1139+
assert_np_equal(
1140+
Vout_np @ np.transpose(Vout_np, axes=(0, 2, 1)), np.broadcast_to(np.eye(2), shape=(M, 2, 2)), tol=30 * tol
1141+
)
1142+
assert_np_equal(USVt_np, m2.numpy().astype(np.float64), tol=30 * tol)
11131143

11141144
if dtype == np.float16:
11151145
# Skip gradient check for float16 due to rounding errors
@@ -1128,7 +1158,7 @@ def check_mat_svd2(
11281158

11291159
tape.zero()
11301160

1131-
dx = 0.0001
1161+
dx = 0.001
11321162
fdtol = 5.0e-4 if dtype == np.float64 else 2.0e-2
11331163
for ii in range(2):
11341164
for jj in range(2):
@@ -1163,9 +1193,9 @@ def test_qr(test, device, dtype, register_kernels=False):
11631193
rng = np.random.default_rng(123)
11641194

11651195
tol = {
1166-
np.float16: 2.0e-3,
1196+
np.float16: 2.5e-3,
11671197
np.float32: 1.0e-6,
1168-
np.float64: 1.0e-6,
1198+
np.float64: 1.0e-12,
11691199
}.get(dtype, 0)
11701200

11711201
wptype = wp.types.np_dtype_to_warp_type[np.dtype(dtype)]

0 commit comments

Comments
 (0)