|
15 | 15 | """Tests for partial_lanczos.py."""
|
16 | 16 |
|
17 | 17 | import jax
|
18 |
| -from jax import config |
19 | 18 | import jax.numpy as jnp
|
20 | 19 | import numpy as np
|
21 | 20 | from tensorflow_probability.python.experimental.fastgp import mbcg
|
22 | 21 | from tensorflow_probability.python.experimental.fastgp import partial_lanczos
|
23 |
| -from absl.testing import absltest |
| 22 | +from tensorflow_probability.substrates.jax.internal import test_util |
24 | 23 |
|
25 | 24 | # pylint: disable=invalid-name
|
26 | 25 |
|
27 | 26 |
|
28 |
| -class _PartialLanczosTest(absltest.TestCase): |
| 27 | +class _PartialLanczosTest(test_util.TestCase): |
29 | 28 |
|
30 | 29 | def test_gram_schmidt(self):
|
31 | 30 | w = jnp.ones((5, 1), dtype=self.dtype)
|
@@ -59,18 +58,17 @@ def test_partial_lanczos_identity(self):
|
59 | 58 | )
|
60 | 59 |
|
61 | 60 | def test_diagonal_matrix_heavily_imbalanced(self):
|
| 61 | + if self.dtype == np.float32: |
| 62 | + self.skipTest("Numerically unstable") |
62 | 63 | A = jnp.diag(
|
63 | 64 | jnp.array([1e-3, 1.0, 2.0, 3.0, 4.0, 10000.0], dtype=self.dtype)
|
64 | 65 | )
|
65 | 66 | v = jnp.ones((6, 1)).astype(self.dtype)
|
66 | 67 | Q, T = partial_lanczos.partial_lanczos(
|
67 |
| - lambda x: A @ x, v, jax.random.PRNGKey(9), 6 |
| 68 | + lambda x: A @ x, v, test_util.test_seed(), 6 |
68 | 69 | )
|
69 | 70 | atol = 1e-6
|
70 | 71 | det_rtol = 1e-6
|
71 |
| - if self.dtype == np.float32: |
72 |
| - atol = 2e-3 |
73 |
| - det_rtol = 0.26 |
74 | 72 | np.testing.assert_allclose(jnp.identity(6), Q[0] @ Q[0].T, atol=atol)
|
75 | 73 | np.testing.assert_allclose(
|
76 | 74 | mbcg.tridiagonal_det(T.diag[0, :], T.off_diag[0, :]),
|
@@ -140,23 +138,6 @@ def test_make_lanczos_preconditioner(self):
|
140 | 138 | out = preconditioner.solve(jnp.identity(100))
|
141 | 139 | np.testing.assert_allclose(out, jnp.identity(100), atol=0.2)
|
142 | 140 |
|
143 |
| - def test_preconditioner_preserves_psd(self): |
144 |
| - M = jnp.array([ |
145 |
| - [2.6452732, -1.4553788, -0.5272188, 0.524349], |
146 |
| - [-1.4553788, 4.4274387, 0.21998158, 1.8666775], |
147 |
| - [-0.5272188, 0.21998158, 2.4756536, -0.5257966], |
148 |
| - [0.524349, 1.8666775, -0.5257966, 2.889879], |
149 |
| - ]).astype(self.dtype) |
150 |
| - orig_eigenvalues = jnp.linalg.eigvalsh(M) |
151 |
| - self.assertFalse((orig_eigenvalues < 0).any()) |
152 |
| - |
153 |
| - preconditioner = partial_lanczos.make_lanczos_preconditioner( |
154 |
| - M, jax.random.PRNGKey(7) |
155 |
| - ) |
156 |
| - preconditioned_M = preconditioner.solve(M) |
157 |
| - after_eigenvalues = jnp.linalg.eigvalsh(preconditioned_M) |
158 |
| - self.assertFalse((after_eigenvalues < 0).any()) |
159 |
| - |
160 | 141 | def test_my_tridiagonal_solve(self):
|
161 | 142 | empty = jnp.array([]).astype(self.dtype)
|
162 | 143 | self.assertEqual(
|
@@ -220,5 +201,4 @@ class PartialLanczosTestFloat64(_PartialLanczosTest):
|
220 | 201 |
|
221 | 202 |
|
222 | 203 | if __name__ == "__main__":
|
223 |
| - config.update("jax_enable_x64", True) |
224 |
| - absltest.main() |
| 204 | + test_util.main() |
0 commit comments