Skip to content

Commit f1973a2

Browse files
committed
add changes for ranking_metrics_test
1 parent 7ba3bd9 commit f1973a2

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

src/metrax/ranking_metrics_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@
4848
0.637499988079071,
4949
])
5050
MAP_FROM_KERAS_VS1 = np.array([0.75, 0.75, 0.75, 0.75, 0.75, 0.75])
51+
P_FROM_KERAS = np.array([0.75, 0.875, 0.58333337306976320, 0.5625, 0.5, 0.5])
52+
P_FROM_KERAS_VS1 = np.array([0.75, 0.75, 0.75, 0.75, 0.75, 0.75])
5153

5254

5355
class RankingMetricsTest(parameterized.TestCase):
@@ -89,6 +91,31 @@ def test_averageprecisionatk(self, y_true, y_pred, map_from_keras, jitted):
8991
atol=1e-05,
9092
)
9193

94+
@parameterized.named_parameters(
95+
('basic', OUTPUT_LABELS, OUTPUT_PREDS, P_FROM_KERAS),
96+
(
97+
'vocab_size_one',
98+
OUTPUT_LABELS_VS1,
99+
OUTPUT_PREDS_VS1,
100+
P_FROM_KERAS_VS1,
101+
),
102+
)
103+
def test_precisionatk(self, y_true, y_pred, map_from_keras):
104+
"""Test that `PrecisionAtK` Metric computes correct values."""
105+
ks = jnp.array([1, 2, 3, 4, 5, 6])
106+
metric = metrax.PrecisionAtK.from_model_output(
107+
predictions=y_pred,
108+
labels=y_true,
109+
ks=ks,
110+
)
111+
112+
np.testing.assert_allclose(
113+
metric.compute(),
114+
map_from_keras,
115+
rtol=1e-05,
116+
atol=1e-05,
117+
)
118+
92119

93120
if __name__ == '__main__':
94121
absltest.main()

0 commit comments

Comments
 (0)