1818from absl .testing import parameterized
1919import jax .numpy as jnp
2020import metrax
21+ import keras_rs
2122import numpy as np
2223
2324np .random .seed (42 )
4748 2 ,
4849 size = (BATCH_SIZE , 1 ),
4950).astype (np .float32 )
50- # TODO(jiwonshin): Replace with keras metric once it is available in OSS.
51- MAP_FROM_KERAS = np .array ([
52- 0.2083333432674408 ,
53- 0.4791666865348816 ,
54- 0.4791666865348816 ,
55- 0.5416666865348816 ,
56- 0.574999988079071 ,
57- 0.637499988079071 ,
58- ])
59- MAP_FROM_KERAS_VS1 = np .array ([0.75 , 0.75 , 0.75 , 0.75 , 0.75 , 0.75 ])
60- P_FROM_KERAS = np .array ([0.75 , 0.875 , 0.58333337306976320 , 0.5625 , 0.5 , 0.5 ])
61- P_FROM_KERAS_VS1 = np .array ([0.75 , 0.75 , 0.75 , 0.75 , 0.75 , 0.75 ])
62- R_FROM_KERAS = np .array ([
63- 0.2083333432674408 ,
64- 0.5416666865348816 ,
65- 0.5416666865348816 ,
66- 0.625 ,
67- 0.6666666865348816 ,
68- 0.75 ,
69- ])
70- R_FROM_KERAS_VS1 = np .array ([0.75 , 0.75 , 0.75 , 0.75 , 0.75 , 0.75 ])
71- DCG_FROM_KERAS = np .array ([
72- 0.25 ,
73- 0.880929708480835 ,
74- 1.255929708480835 ,
75- 1.5789371728897095 ,
76- 1.8690768480300903 ,
77- 2.04718017578125 ,
78- ])
79- DCG_FROM_KERAS_VS1 = np .array ([0.75 , 0.75 , 0.75 , 0.75 , 0.75 , 0.75 ])
80- NDGC_FROM_KERAS = np .array ([
81- 0.25 ,
82- 0.5401396155357361 ,
83- 0.5893810987472534 ,
84- 0.6163855791091919 ,
85- 0.6469491124153137 ,
86- 0.6560885906219482 ,
87- ])
88- NDGC_FROM_KERAS_VS1 = np .array ([0.75 , 0.75 , 0.75 , 0.75 , 0.75 , 0.75 ])
8951
9052
9153class RankingMetricsTest (parameterized .TestCase ):
@@ -94,75 +56,75 @@ class RankingMetricsTest(parameterized.TestCase):
9456 (
9557 'averageprecisionatk_basic' ,
9658 metrax .AveragePrecisionAtK ,
59+ keras_rs .metrics .MeanAveragePrecision ,
9760 OUTPUT_LABELS ,
9861 OUTPUT_PREDS ,
99- MAP_FROM_KERAS ,
10062 ),
10163 (
10264 'averageprecisionatk_vocab_size_one' ,
10365 metrax .AveragePrecisionAtK ,
66+ keras_rs .metrics .MeanAveragePrecision ,
10467 OUTPUT_LABELS_VS1 ,
10568 OUTPUT_PREDS_VS1 ,
106- MAP_FROM_KERAS_VS1 ,
10769 ),
10870 (
10971 'precisionatk_basic' ,
11072 metrax .PrecisionAtK ,
73+ keras_rs .metrics .PrecisionAtK ,
11174 OUTPUT_LABELS ,
11275 OUTPUT_PREDS ,
113- P_FROM_KERAS ,
11476 ),
11577 (
11678 'precisionatk_vocab_size_one' ,
11779 metrax .PrecisionAtK ,
80+ keras_rs .metrics .PrecisionAtK ,
11881 OUTPUT_LABELS_VS1 ,
11982 OUTPUT_PREDS_VS1 ,
120- P_FROM_KERAS_VS1 ,
12183 ),
12284 (
12385 'recallatk_basic' ,
12486 metrax .RecallAtK ,
87+ keras_rs .metrics .RecallAtK ,
12588 OUTPUT_LABELS ,
12689 OUTPUT_PREDS ,
127- R_FROM_KERAS ,
12890 ),
12991 (
13092 'recallatk_vocab_size_one' ,
13193 metrax .RecallAtK ,
94+ keras_rs .metrics .RecallAtK ,
13295 OUTPUT_LABELS_VS1 ,
13396 OUTPUT_PREDS_VS1 ,
134- R_FROM_KERAS_VS1 ,
13597 ),
13698 (
13799 'dcgatk_basic' ,
138100 metrax .DCGAtK ,
101+ keras_rs .metrics .DCG ,
139102 OUTPUT_RELEVANCES ,
140103 OUTPUT_PREDS ,
141- DCG_FROM_KERAS ,
142104 ),
143105 (
144106 'dcgatk_vocab_size_one' ,
145107 metrax .DCGAtK ,
108+ keras_rs .metrics .DCG ,
146109 OUTPUT_RELEVANCES_VS1 ,
147110 OUTPUT_PREDS_VS1 ,
148- DCG_FROM_KERAS_VS1 ,
149111 ),
150112 (
151113 'ndcgatk_basic' ,
152114 metrax .NDCGAtK ,
115+ keras_rs .metrics .NDCG ,
153116 OUTPUT_RELEVANCES ,
154117 OUTPUT_PREDS ,
155- NDGC_FROM_KERAS ,
156118 ),
157119 (
158120 'ndcgatk_vocab_size_one' ,
159121 metrax .NDCGAtK ,
122+ keras_rs .metrics .NDCG ,
160123 OUTPUT_RELEVANCES_VS1 ,
161124 OUTPUT_PREDS_VS1 ,
162- NDGC_FROM_KERAS_VS1 ,
163125 ),
164126 )
165- def test_ranking_metrics (self , metric , y_true , y_pred , map_from_keras ):
127+ def test_ranking_metrics (self , metric , keras_metric , y_true , y_pred ):
166128 """Test that `NDCGAtK` Metric computes correct values."""
167129 ks = jnp .array ([1 , 2 , 3 , 4 , 5 , 6 ])
168130 metric = metric .from_model_output (
@@ -171,9 +133,15 @@ def test_ranking_metrics(self, metric, y_true, y_pred, map_from_keras):
171133 ks = ks ,
172134 )
173135
136+ keras_metrics = [keras_metric (k = n + 1 ) for n in range (6 )]
137+ results = []
138+ for keras_metric in keras_metrics :
139+ keras_metric .update_state (y_true , y_pred )
140+ results .append (keras_metric .result ())
141+
174142 np .testing .assert_allclose (
175143 metric .compute (),
176- map_from_keras ,
144+ jnp . array ( results ) ,
177145 rtol = 1e-05 ,
178146 atol = 1e-05 ,
179147 )
0 commit comments