Skip to content

Commit 2d9a1d9

Browse files
committed
update base_metrics to base
1 parent b0768c8 commit 2d9a1d9

File tree

1 file changed

+58
-65
lines changed

1 file changed

+58
-65
lines changed

src/metrax/base_test.py

Lines changed: 58 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Tests for metrax base metrics."""
15+
"""Tests for metrax base utilities."""
1616

1717
from absl.testing import absltest
1818
from absl.testing import parameterized
1919
import jax.numpy as jnp
2020
import keras
2121
import metrax
22-
from metrax import base_metrics
22+
from metrax import base
2323
import numpy as np
2424

2525
np.random.seed(42)
@@ -36,70 +36,63 @@
3636
)
3737

3838

39-
class BaseMetricsTest(parameterized.TestCase):
39+
class BaseTest(parameterized.TestCase):
4040

41-
def test_basic_division(self):
42-
x = jnp.array([10.0, 20.0, 30.0])
43-
y = jnp.array([2.0, 4.0, 5.0])
44-
expected = jnp.array([5.0, 5.0, 6.0])
45-
result = base_metrics.divide_no_nan(x, y)
46-
self.assertTrue(jnp.array_equal(result, expected))
47-
48-
def test_division_by_zero(self):
49-
x = jnp.array([10.0, 20.0, 30.0])
50-
y = jnp.array([2.0, 0.0, 5.0])
51-
expected = jnp.array([5.0, 0.0, 6.0])
52-
result = base_metrics.divide_no_nan(x, y)
53-
self.assertTrue(jnp.array_equal(result, expected))
54-
55-
def test_all_zeros_denominator(self):
56-
x = jnp.array([10.0, 20.0, 30.0])
57-
y = jnp.array([0.0, 0.0, 0.0])
58-
expected = jnp.array([0.0, 0.0, 0.0])
59-
result = base_metrics.divide_no_nan(x, y)
60-
self.assertTrue(jnp.array_equal(result, expected))
61-
62-
def test_all_zeros_numerator(self):
63-
x = jnp.array([0.0, 0.0, 0.0])
64-
y = jnp.array([2.0, 4.0, 5.0])
65-
expected = jnp.array([0.0, 0.0, 0.0])
66-
result = base_metrics.divide_no_nan(x, y)
67-
self.assertTrue(jnp.array_equal(result, expected))
68-
69-
def test_mixed_zeros(self):
70-
x = jnp.array([10.0, 0.0, 30.0, 0.0])
71-
y = jnp.array([2.0, 0.0, 5.0, 4.0])
72-
expected = jnp.array([5.0, 0.0, 6.0, 0.0])
73-
result = base_metrics.divide_no_nan(x, y)
74-
self.assertTrue(jnp.array_equal(result, expected))
75-
76-
def test_scalar_inputs(self):
77-
x = jnp.array(10.0)
78-
y = jnp.array(2.0)
79-
expected = jnp.array(5.0)
80-
result = base_metrics.divide_no_nan(x, y)
81-
self.assertTrue(jnp.array_equal(result, expected))
82-
83-
def test_scalar_denominator_zero(self):
84-
x = jnp.array(10.0)
85-
y = jnp.array(0.0)
86-
expected = jnp.array(0.0)
87-
result = base_metrics.divide_no_nan(x, y)
88-
self.assertTrue(jnp.array_equal(result, expected))
89-
90-
def test_negative_values(self):
91-
x = jnp.array([-10.0, 20.0, -30.0])
92-
y = jnp.array([2.0, -4.0, 5.0])
93-
expected = jnp.array([-5.0, -5.0, -6.0])
94-
result = base_metrics.divide_no_nan(x, y)
95-
self.assertTrue(jnp.array_equal(result, expected))
96-
97-
def test_negative_and_zero_values(self):
98-
x = jnp.array([-10.0, 20.0, -30.0, 10.0])
99-
y = jnp.array([2.0, -4.0, 0.0, 0.0])
100-
expected = jnp.array([-5.0, -5.0, 0.0, 0.0])
101-
result = base_metrics.divide_no_nan(x, y)
102-
self.assertTrue(jnp.array_equal(result, expected))
41+
@parameterized.named_parameters(
42+
(
43+
'basic_division',
44+
np.array([10.0, 20.0, 30.0]),
45+
np.array([2.0, 4.0, 5.0]),
46+
np.array([5.0, 5.0, 6.0]),
47+
),
48+
(
49+
'division_by_zero',
50+
np.array([10.0, 20.0, 30.0]),
51+
np.array([2.0, 0.0, 5.0]),
52+
np.array([5.0, 0.0, 6.0]),
53+
),
54+
(
55+
'all_zeros_denominator',
56+
np.array([10.0, 20.0, 30.0]),
57+
np.array([0.0, 0.0, 0.0]),
58+
np.array([0.0, 0.0, 0.0]),
59+
),
60+
(
61+
'all_zeros_numerator',
62+
np.array([0.0, 0.0, 0.0]),
63+
np.array([2.0, 4.0, 5.0]),
64+
np.array([0.0, 0.0, 0.0]),
65+
),
66+
(
67+
'mixed_zeros',
68+
np.array([10.0, 0.0, 30.0, 0.0]),
69+
np.array([2.0, 0.0, 5.0, 4.0]),
70+
np.array([5.0, 0.0, 6.0, 0.0]),
71+
),
72+
('scalar_inputs', np.array(10.0), np.array(2.0), np.array(5.0)),
73+
(
74+
'scalar_denominator_zero',
75+
np.array(10.0),
76+
np.array(0.0),
77+
np.array(0.0),
78+
),
79+
(
80+
'negative_values',
81+
np.array([-10.0, 20.0, -30.0]),
82+
np.array([2.0, -4.0, 5.0]),
83+
np.array([-5.0, -5.0, -6.0]),
84+
),
85+
(
86+
'negative_and_zero_values',
87+
np.array([-10.0, 20.0, -30.0, 10.0]),
88+
np.array([2.0, -4.0, 0.0, 0.0]),
89+
np.array([-5.0, -5.0, 0.0, 0.0]),
90+
),
91+
)
92+
def test_divide_no_nan(self, x, y, expected):
93+
"""Test that `divide_no_nan` functioncomputes correct values."""
94+
result = base.divide_no_nan(x, y)
95+
self.assertTrue(np.array_equal(result, expected))
10396

10497
@parameterized.named_parameters(
10598
('basic_f16', OUTPUT_F16, None),

0 commit comments

Comments
 (0)