|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | -"""Tests for metrax base metrics.""" |
| 15 | +"""Tests for metrax base utilities.""" |
16 | 16 |
|
17 | 17 | from absl.testing import absltest |
18 | 18 | from absl.testing import parameterized |
19 | 19 | import jax.numpy as jnp |
20 | 20 | import keras |
21 | 21 | import metrax |
22 | | -from metrax import base_metrics |
| 22 | +from metrax import base |
23 | 23 | import numpy as np |
24 | 24 |
|
25 | 25 | np.random.seed(42) |
|
36 | 36 | ) |
37 | 37 |
|
38 | 38 |
|
39 | | -class BaseMetricsTest(parameterized.TestCase): |
| 39 | +class BaseTest(parameterized.TestCase): |
40 | 40 |
|
41 | | - def test_basic_division(self): |
42 | | - x = jnp.array([10.0, 20.0, 30.0]) |
43 | | - y = jnp.array([2.0, 4.0, 5.0]) |
44 | | - expected = jnp.array([5.0, 5.0, 6.0]) |
45 | | - result = base_metrics.divide_no_nan(x, y) |
46 | | - self.assertTrue(jnp.array_equal(result, expected)) |
47 | | - |
48 | | - def test_division_by_zero(self): |
49 | | - x = jnp.array([10.0, 20.0, 30.0]) |
50 | | - y = jnp.array([2.0, 0.0, 5.0]) |
51 | | - expected = jnp.array([5.0, 0.0, 6.0]) |
52 | | - result = base_metrics.divide_no_nan(x, y) |
53 | | - self.assertTrue(jnp.array_equal(result, expected)) |
54 | | - |
55 | | - def test_all_zeros_denominator(self): |
56 | | - x = jnp.array([10.0, 20.0, 30.0]) |
57 | | - y = jnp.array([0.0, 0.0, 0.0]) |
58 | | - expected = jnp.array([0.0, 0.0, 0.0]) |
59 | | - result = base_metrics.divide_no_nan(x, y) |
60 | | - self.assertTrue(jnp.array_equal(result, expected)) |
61 | | - |
62 | | - def test_all_zeros_numerator(self): |
63 | | - x = jnp.array([0.0, 0.0, 0.0]) |
64 | | - y = jnp.array([2.0, 4.0, 5.0]) |
65 | | - expected = jnp.array([0.0, 0.0, 0.0]) |
66 | | - result = base_metrics.divide_no_nan(x, y) |
67 | | - self.assertTrue(jnp.array_equal(result, expected)) |
68 | | - |
69 | | - def test_mixed_zeros(self): |
70 | | - x = jnp.array([10.0, 0.0, 30.0, 0.0]) |
71 | | - y = jnp.array([2.0, 0.0, 5.0, 4.0]) |
72 | | - expected = jnp.array([5.0, 0.0, 6.0, 0.0]) |
73 | | - result = base_metrics.divide_no_nan(x, y) |
74 | | - self.assertTrue(jnp.array_equal(result, expected)) |
75 | | - |
76 | | - def test_scalar_inputs(self): |
77 | | - x = jnp.array(10.0) |
78 | | - y = jnp.array(2.0) |
79 | | - expected = jnp.array(5.0) |
80 | | - result = base_metrics.divide_no_nan(x, y) |
81 | | - self.assertTrue(jnp.array_equal(result, expected)) |
82 | | - |
83 | | - def test_scalar_denominator_zero(self): |
84 | | - x = jnp.array(10.0) |
85 | | - y = jnp.array(0.0) |
86 | | - expected = jnp.array(0.0) |
87 | | - result = base_metrics.divide_no_nan(x, y) |
88 | | - self.assertTrue(jnp.array_equal(result, expected)) |
89 | | - |
90 | | - def test_negative_values(self): |
91 | | - x = jnp.array([-10.0, 20.0, -30.0]) |
92 | | - y = jnp.array([2.0, -4.0, 5.0]) |
93 | | - expected = jnp.array([-5.0, -5.0, -6.0]) |
94 | | - result = base_metrics.divide_no_nan(x, y) |
95 | | - self.assertTrue(jnp.array_equal(result, expected)) |
96 | | - |
97 | | - def test_negative_and_zero_values(self): |
98 | | - x = jnp.array([-10.0, 20.0, -30.0, 10.0]) |
99 | | - y = jnp.array([2.0, -4.0, 0.0, 0.0]) |
100 | | - expected = jnp.array([-5.0, -5.0, 0.0, 0.0]) |
101 | | - result = base_metrics.divide_no_nan(x, y) |
102 | | - self.assertTrue(jnp.array_equal(result, expected)) |
| 41 | + @parameterized.named_parameters( |
| 42 | + ( |
| 43 | + 'basic_division', |
| 44 | + np.array([10.0, 20.0, 30.0]), |
| 45 | + np.array([2.0, 4.0, 5.0]), |
| 46 | + np.array([5.0, 5.0, 6.0]), |
| 47 | + ), |
| 48 | + ( |
| 49 | + 'division_by_zero', |
| 50 | + np.array([10.0, 20.0, 30.0]), |
| 51 | + np.array([2.0, 0.0, 5.0]), |
| 52 | + np.array([5.0, 0.0, 6.0]), |
| 53 | + ), |
| 54 | + ( |
| 55 | + 'all_zeros_denominator', |
| 56 | + np.array([10.0, 20.0, 30.0]), |
| 57 | + np.array([0.0, 0.0, 0.0]), |
| 58 | + np.array([0.0, 0.0, 0.0]), |
| 59 | + ), |
| 60 | + ( |
| 61 | + 'all_zeros_numerator', |
| 62 | + np.array([0.0, 0.0, 0.0]), |
| 63 | + np.array([2.0, 4.0, 5.0]), |
| 64 | + np.array([0.0, 0.0, 0.0]), |
| 65 | + ), |
| 66 | + ( |
| 67 | + 'mixed_zeros', |
| 68 | + np.array([10.0, 0.0, 30.0, 0.0]), |
| 69 | + np.array([2.0, 0.0, 5.0, 4.0]), |
| 70 | + np.array([5.0, 0.0, 6.0, 0.0]), |
| 71 | + ), |
| 72 | + ('scalar_inputs', np.array(10.0), np.array(2.0), np.array(5.0)), |
| 73 | + ( |
| 74 | + 'scalar_denominator_zero', |
| 75 | + np.array(10.0), |
| 76 | + np.array(0.0), |
| 77 | + np.array(0.0), |
| 78 | + ), |
| 79 | + ( |
| 80 | + 'negative_values', |
| 81 | + np.array([-10.0, 20.0, -30.0]), |
| 82 | + np.array([2.0, -4.0, 5.0]), |
| 83 | + np.array([-5.0, -5.0, -6.0]), |
| 84 | + ), |
| 85 | + ( |
| 86 | + 'negative_and_zero_values', |
| 87 | + np.array([-10.0, 20.0, -30.0, 10.0]), |
| 88 | + np.array([2.0, -4.0, 0.0, 0.0]), |
| 89 | + np.array([-5.0, -5.0, 0.0, 0.0]), |
| 90 | + ), |
| 91 | + ) |
| 92 | + def test_divide_no_nan(self, x, y, expected): |
| 93 | + """Test that `divide_no_nan` functioncomputes correct values.""" |
| 94 | + result = base.divide_no_nan(x, y) |
| 95 | + self.assertTrue(np.array_equal(result, expected)) |
103 | 96 |
|
104 | 97 | @parameterized.named_parameters( |
105 | 98 | ('basic_f16', OUTPUT_F16, None), |
|
0 commit comments