Skip to content

Commit 7ac0ffb

Browse files
committed
use keras_rs metrics
1 parent 0deb624 commit 7ac0ffb

File tree

2 files changed

+20
-51
lines changed

2 files changed

+20
-51
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ clu
33
jax[cpu]
44
keras-hub
55
keras-nlp
6+
keras-rs
67
pytest
78
rouge-score
89
scikit-learn

src/metrax/ranking_metrics_test.py

Lines changed: 19 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from absl.testing import parameterized
1919
import jax.numpy as jnp
2020
import metrax
21+
import keras_rs
2122
import numpy as np
2223

2324
np.random.seed(42)
@@ -47,45 +48,6 @@
4748
2,
4849
size=(BATCH_SIZE, 1),
4950
).astype(np.float32)
50-
# TODO(jiwonshin): Replace with keras metric once it is available in OSS.
51-
MAP_FROM_KERAS = np.array([
52-
0.2083333432674408,
53-
0.4791666865348816,
54-
0.4791666865348816,
55-
0.5416666865348816,
56-
0.574999988079071,
57-
0.637499988079071,
58-
])
59-
MAP_FROM_KERAS_VS1 = np.array([0.75, 0.75, 0.75, 0.75, 0.75, 0.75])
60-
P_FROM_KERAS = np.array([0.75, 0.875, 0.58333337306976320, 0.5625, 0.5, 0.5])
61-
P_FROM_KERAS_VS1 = np.array([0.75, 0.75, 0.75, 0.75, 0.75, 0.75])
62-
R_FROM_KERAS = np.array([
63-
0.2083333432674408,
64-
0.5416666865348816,
65-
0.5416666865348816,
66-
0.625,
67-
0.6666666865348816,
68-
0.75,
69-
])
70-
R_FROM_KERAS_VS1 = np.array([0.75, 0.75, 0.75, 0.75, 0.75, 0.75])
71-
DCG_FROM_KERAS = np.array([
72-
0.25,
73-
0.880929708480835,
74-
1.255929708480835,
75-
1.5789371728897095,
76-
1.8690768480300903,
77-
2.04718017578125,
78-
])
79-
DCG_FROM_KERAS_VS1 = np.array([0.75, 0.75, 0.75, 0.75, 0.75, 0.75])
80-
NDGC_FROM_KERAS = np.array([
81-
0.25,
82-
0.5401396155357361,
83-
0.5893810987472534,
84-
0.6163855791091919,
85-
0.6469491124153137,
86-
0.6560885906219482,
87-
])
88-
NDGC_FROM_KERAS_VS1 = np.array([0.75, 0.75, 0.75, 0.75, 0.75, 0.75])
8951

9052

9153
class RankingMetricsTest(parameterized.TestCase):
@@ -94,75 +56,75 @@ class RankingMetricsTest(parameterized.TestCase):
9456
(
9557
'averageprecisionatk_basic',
9658
metrax.AveragePrecisionAtK,
59+
keras_rs.metrics.MeanAveragePrecision,
9760
OUTPUT_LABELS,
9861
OUTPUT_PREDS,
99-
MAP_FROM_KERAS,
10062
),
10163
(
10264
'averageprecisionatk_vocab_size_one',
10365
metrax.AveragePrecisionAtK,
66+
keras_rs.metrics.MeanAveragePrecision,
10467
OUTPUT_LABELS_VS1,
10568
OUTPUT_PREDS_VS1,
106-
MAP_FROM_KERAS_VS1,
10769
),
10870
(
10971
'precisionatk_basic',
11072
metrax.PrecisionAtK,
73+
keras_rs.metrics.PrecisionAtK,
11174
OUTPUT_LABELS,
11275
OUTPUT_PREDS,
113-
P_FROM_KERAS,
11476
),
11577
(
11678
'precisionatk_vocab_size_one',
11779
metrax.PrecisionAtK,
80+
keras_rs.metrics.PrecisionAtK,
11881
OUTPUT_LABELS_VS1,
11982
OUTPUT_PREDS_VS1,
120-
P_FROM_KERAS_VS1,
12183
),
12284
(
12385
'recallatk_basic',
12486
metrax.RecallAtK,
87+
keras_rs.metrics.RecallAtK,
12588
OUTPUT_LABELS,
12689
OUTPUT_PREDS,
127-
R_FROM_KERAS,
12890
),
12991
(
13092
'recallatk_vocab_size_one',
13193
metrax.RecallAtK,
94+
keras_rs.metrics.RecallAtK,
13295
OUTPUT_LABELS_VS1,
13396
OUTPUT_PREDS_VS1,
134-
R_FROM_KERAS_VS1,
13597
),
13698
(
13799
'dcgatk_basic',
138100
metrax.DCGAtK,
101+
keras_rs.metrics.DCG,
139102
OUTPUT_RELEVANCES,
140103
OUTPUT_PREDS,
141-
DCG_FROM_KERAS,
142104
),
143105
(
144106
'dcgatk_vocab_size_one',
145107
metrax.DCGAtK,
108+
keras_rs.metrics.DCG,
146109
OUTPUT_RELEVANCES_VS1,
147110
OUTPUT_PREDS_VS1,
148-
DCG_FROM_KERAS_VS1,
149111
),
150112
(
151113
'ndcgatk_basic',
152114
metrax.NDCGAtK,
115+
keras_rs.metrics.NDCG,
153116
OUTPUT_RELEVANCES,
154117
OUTPUT_PREDS,
155-
NDGC_FROM_KERAS,
156118
),
157119
(
158120
'ndcgatk_vocab_size_one',
159121
metrax.NDCGAtK,
122+
keras_rs.metrics.NDCG,
160123
OUTPUT_RELEVANCES_VS1,
161124
OUTPUT_PREDS_VS1,
162-
NDGC_FROM_KERAS_VS1,
163125
),
164126
)
165-
def test_ranking_metrics(self, metric, y_true, y_pred, map_from_keras):
127+
def test_ranking_metrics(self, metric, keras_metric, y_true, y_pred):
166128
"""Test that `NDCGAtK` Metric computes correct values."""
167129
ks = jnp.array([1, 2, 3, 4, 5, 6])
168130
metric = metric.from_model_output(
@@ -171,9 +133,15 @@ def test_ranking_metrics(self, metric, y_true, y_pred, map_from_keras):
171133
ks=ks,
172134
)
173135

136+
keras_metrics = [keras_metric(k=n+1) for n in range(6)]
137+
results = []
138+
for keras_metric in keras_metrics:
139+
keras_metric.update_state(y_true, y_pred)
140+
results.append(keras_metric.result())
141+
174142
np.testing.assert_allclose(
175143
metric.compute(),
176-
map_from_keras,
144+
jnp.array(results),
177145
rtol=1e-05,
178146
atol=1e-05,
179147
)

0 commit comments

Comments
 (0)