1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- """Tests for metrax.metrax ."""
15+ """Tests for metrax metrics ."""
1616
1717from absl .testing import absltest
1818from absl .testing import parameterized
1919import jax
2020import jax .numpy as jnp
21+ import keras
2122import keras_hub
2223import metrax
2324import numpy as np
4748
4849class MetricsTest (parameterized .TestCase ):
4950
50- def setUp (self ):
51- super ().setUp ()
52-
53- # TODO(jeffcarp): Merge these into generated fixtures.
54- self .model_outputs = (
55- dict (
56- logits = jnp .array (
57- [0.34 , 0.89 , 0.12 , 0.67 , 0.98 , 0.23 , 0.56 , 0.71 , 0.45 , 0.08 ]
58- ),
59- labels = jnp .array ([1 , 0 , 1 , 1 , 0 , 0 , 1 , 0 , 1 , 1 ]),
60- ),
61- dict (
62- logits = jnp .array (
63- [0.23 , 0.89 , 0.57 , 0.11 , 0.99 , 0.38 , 0.76 , 0.05 , 0.62 , 0.44 ]
64- ),
65- labels = jnp .array ([0 , 0 , 1 , 0 , 1 , 1 , 0 , 1 , 0 , 0 ]),
66- ),
67- dict (
68- logits = jnp .array (
69- [0.67 , 0.21 , 0.95 , 0.03 , 0.88 , 0.51 , 0.34 , 0.79 , 0.15 , 0.42 ]
70- ),
71- labels = jnp .array ([1 , 1 , 0 , 1 , 0 , 1 , 1 , 0 , 0 , 1 ]),
72- ),
73- dict (
74- logits = jnp .array (
75- [0.91 , 0.37 , 0.18 , 0.75 , 0.59 , 0.02 , 0.83 , 0.26 , 0.64 , 0.48 ]
76- ),
77- labels = jnp .array ([0 , 1 , 1 , 0 , 0 , 1 , 0 , 1 , 1 , 0 ]),
78- ),
79- )
80- self .model_outputs_batch_size_one = (
81- dict (
82- logits = jnp .array ([[0.32 ]]),
83- labels = jnp .array ([1 ]),
84- ),
85- dict (
86- logits = jnp .array ([[0.74 ]]),
87- labels = jnp .array ([1 ]),
88- ),
89- dict (
90- logits = jnp .array ([[0.86 ]]),
91- labels = jnp .array ([1 ]),
92- ),
93- dict (
94- logits = jnp .array ([[0.21 ]]),
95- labels = jnp .array ([1 ]),
96- ),
97- )
98- self .sample_weights = jnp .array ([0.5 , 1 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ])
99-
100- def compute_aucpr (self , model_outputs , sample_weights = None ):
101- metric = None
102- for model_output in model_outputs :
103- update = metrax .AUCPR .from_model_output (
104- predictions = model_output .get ('logits' ),
105- labels = model_output .get ('labels' ),
106- sample_weights = sample_weights ,
107- )
108- metric = update if metric is None else metric .merge (update )
109- return metric .compute ()
110-
111- def compute_aucroc (self , model_outputs , sample_weights = None ):
112- metric = None
113- for model_output in model_outputs :
114- update = metrax .AUCROC .from_model_output (
115- predictions = model_output .get ('logits' ),
116- labels = model_output .get ('labels' ),
117- sample_weights = sample_weights ,
118- )
119- metric = update if metric is None else metric .merge (update )
120- return metric .compute ()
121-
12251 def test_mse_empty (self ):
12352 """Tests the `empty` method of the `MSE` class."""
12453 m = metrax .MSE .empty ()
@@ -197,10 +126,13 @@ def sharded_r2(logits, labels):
197126 metric = jax .jit (sharded_r2 )(y_pred , y_true )
198127 metric = metric .reduce ()
199128
200- expected = sklearn_metrics .r2_score (
201- y_true .flatten (),
202- y_pred .flatten (),
203- )
129+ keras_r2 = keras .metrics .R2Score ()
130+ for labels , logits in zip (y_true , y_pred ):
131+ keras_r2 .update_state (
132+ labels [:, jnp .newaxis ],
133+ logits [:, jnp .newaxis ],
134+ )
135+ expected = keras_r2 .result ()
204136 np .testing .assert_allclose (
205137 metric .compute (),
206138 expected ,
@@ -215,10 +147,12 @@ def sharded_r2(logits, labels):
215147 ('batch_size_one' , OUTPUT_LABELS_BS1 , OUTPUT_PREDS_BS1 , 0.5 ),
216148 )
217149 def test_precision (self , y_true , y_pred , threshold ):
218- """Test that Precision metric computes correct values."""
150+ """Test that ` Precision` metric computes correct values."""
219151 y_true = y_true .reshape ((- 1 ,))
220152 y_pred = jnp .where (y_pred .reshape ((- 1 ,)) >= threshold , 1 , 0 )
221- expected = sklearn_metrics .precision_score (y_true , y_pred )
153+ keras_precision = keras .metrics .Precision (thresholds = threshold )
154+ keras_precision .update_state (y_true , y_pred )
155+ expected = keras_precision .result ()
222156
223157 metric = None
224158 for logits , labels in zip (y_pred , y_true ):
@@ -241,10 +175,12 @@ def test_precision(self, y_true, y_pred, threshold):
241175 ('batch_size_one' , OUTPUT_LABELS_BS1 , OUTPUT_PREDS_BS1 , 0.5 ),
242176 )
243177 def test_recall (self , y_true , y_pred , threshold ):
244- """Test that Recall metric computes correct values."""
178+ """Test that ` Recall` metric computes correct values."""
245179 y_true = y_true .reshape ((- 1 ,))
246180 y_pred = jnp .where (y_pred .reshape ((- 1 ,)) >= threshold , 1 , 0 )
247- expected = sklearn_metrics .recall_score (y_true , y_pred )
181+ keras_recall = keras .metrics .Recall (thresholds = threshold )
182+ keras_recall .update_state (y_true , y_pred )
183+ expected = keras_recall .result ()
248184
249185 metric = None
250186 for logits , labels in zip (y_pred , y_true ):
@@ -260,64 +196,64 @@ def test_recall(self, y_true, y_pred, threshold):
260196 expected ,
261197 )
262198
263- def test_aucpr (self ):
264- """Test that AUC-PR Metric computes correct values."""
265- np .testing .assert_allclose (
266- self .compute_aucpr (self .model_outputs ),
267- jnp .array (0.41513795 , dtype = jnp .float32 ),
268- )
199+ @parameterized .named_parameters (
200+ ('basic' , OUTPUT_LABELS , OUTPUT_PREDS , None ),
201+ ('batch_size_one' , OUTPUT_LABELS_BS1 , OUTPUT_PREDS_BS1 , None ),
202+ ('weighted' , OUTPUT_LABELS , OUTPUT_PREDS , SAMPLE_WEIGHTS ),
203+ )
204+ def test_aucpr (self , y_true , y_pred , sample_weights ):
205+ """Test that `AUC-PR` Metric computes correct values."""
206+ if sample_weights is None :
207+ sample_weights = np .ones_like (y_true )
269208
270- def test_aucpr_with_sample_weight (self ):
271- """Test that AUC-PR Metric computes correct values when using sample weights."""
272- np .testing .assert_allclose (
273- self .compute_aucpr (self .model_outputs , self .sample_weights ),
274- jnp .array (0.32785615 , dtype = jnp .float32 ),
275- )
209+ metric = None
210+ for labels , logits , weights in zip (y_true , y_pred , sample_weights ):
211+ update = metrax .AUCPR .from_model_output (
212+ predictions = logits ,
213+ labels = labels ,
214+ sample_weights = weights ,
215+ )
216+ metric = update if metric is None else metric .merge (update )
276217
277- def test_aucpr_with_batch_size_one (self ):
278- """Test that AUC-PR Metric computes correct values with batch size one."""
218+ keras_aucpr = keras .metrics .AUC (curve = 'PR' )
219+ for labels , logits , weights in zip (y_true , y_pred , sample_weights ):
220+ keras_aucpr .update_state (labels , logits , sample_weight = weights )
221+ expected = keras_aucpr .result ()
279222 np .testing .assert_allclose (
280- self .compute_aucpr (self .model_outputs_batch_size_one ),
281- jnp .array (1.0 , dtype = jnp .float32 ),
223+ metric .compute (),
224+ expected ,
225+ rtol = 1e-07 ,
226+ atol = 1e-07 ,
282227 )
283228
284- def test_aucroc (self ):
285- """Test that AUC-ROC Metric computes correct values."""
286- # Concatenate logits and labels
287- all_logits = jnp .concatenate (
288- [model_output ['logits' ] for model_output in self .model_outputs ]
289- )
290- all_labels = jnp .concatenate (
291- [model_output ['labels' ] for model_output in self .model_outputs ]
292- )
293- np .testing .assert_allclose (
294- self .compute_aucroc (self .model_outputs ),
295- sklearn_metrics .roc_auc_score (all_labels , all_logits ),
296- )
229+ @parameterized .named_parameters (
230+ ('basic' , OUTPUT_LABELS , OUTPUT_PREDS , None ),
231+ ('batch_size_one' , OUTPUT_LABELS_BS1 , OUTPUT_PREDS_BS1 , None ),
232+ ('weighted' , OUTPUT_LABELS , OUTPUT_PREDS , SAMPLE_WEIGHTS ),
233+ )
234+ def test_aucroc (self , y_true , y_pred , sample_weights ):
235+ """Test that `AUC-ROC` Metric computes correct values."""
236+ if sample_weights is None :
237+ sample_weights = np .ones_like (y_true )
297238
298- def test_aucroc_with_sample_weight (self ):
299- """Test that AUC-ROC Metric computes correct values when using sample weights."""
300- # Concatenate logits and labels
301- all_logits = jnp .concatenate (
302- [model_output ['logits' ] for model_output in self .model_outputs ]
303- )
304- all_labels = jnp .concatenate (
305- [model_output ['labels' ] for model_output in self .model_outputs ]
306- )
307- sample_weights = jnp .concatenate (
308- [self .sample_weights ] * len (self .model_outputs )
309- )
239+ metric = None
240+ for labels , logits , weights in zip (y_true , y_pred , sample_weights ):
241+ update = metrax .AUCROC .from_model_output (
242+ predictions = logits ,
243+ labels = labels ,
244+ sample_weights = weights ,
245+ )
246+ metric = update if metric is None else metric .merge (update )
247+
248+ keras_aucroc = keras .metrics .AUC (curve = 'ROC' )
249+ for labels , logits , weights in zip (y_true , y_pred , sample_weights ):
250+ keras_aucroc .update_state (labels , logits , sample_weight = weights )
251+ expected = keras_aucroc .result ()
310252 np .testing .assert_allclose (
311- jnp .array (
312- self .compute_aucroc (self .model_outputs , self .sample_weights ),
313- dtype = jnp .float16 ,
314- ),
315- jnp .array (
316- sklearn_metrics .roc_auc_score (
317- all_labels , all_logits , sample_weight = sample_weights
318- ),
319- dtype = jnp .float16 ,
320- ),
253+ metric .compute (),
254+ expected ,
255+ rtol = 1e-07 ,
256+ atol = 1e-07 ,
321257 )
322258
323259 @parameterized .named_parameters (
@@ -326,6 +262,7 @@ def test_aucroc_with_sample_weight(self):
326262 ('weighted' , OUTPUT_LABELS , OUTPUT_PREDS , SAMPLE_WEIGHTS ),
327263 )
328264 def test_mse (self , y_true , y_pred , sample_weights ):
265+ """Test that `MSE` Metric computes correct values."""
329266 if sample_weights is None :
330267 sample_weights = np .ones_like (y_true )
331268
@@ -338,6 +275,8 @@ def test_mse(self, y_true, y_pred, sample_weights):
338275 )
339276 metric = update if metric is None else metric .merge (update )
340277
278+ # TODO(jiwonshin): Use `keras.metrics.MeanSquaredError` once it supports
279+ # sample weights.
341280 expected = sklearn_metrics .mean_squared_error (
342281 y_true .flatten (),
343282 y_pred .flatten (),
@@ -356,6 +295,7 @@ def test_mse(self, y_true, y_pred, sample_weights):
356295 ('weighted' , OUTPUT_LABELS , OUTPUT_PREDS , SAMPLE_WEIGHTS ),
357296 )
358297 def test_rmse (self , y_true , y_pred , sample_weights ):
298+ """Test that `RMSE` Metric computes correct values."""
359299 if sample_weights is None :
360300 sample_weights = np .ones_like (y_true )
361301
@@ -368,13 +308,10 @@ def test_rmse(self, y_true, y_pred, sample_weights):
368308 )
369309 metric = update if metric is None else metric .merge (update )
370310
371- # `sklearn_metrics.root_mean_squared_error` is not available.
372- expected = jnp .sqrt (
373- jnp .average (
374- jnp .square (y_pred .flatten () - y_true .flatten ()),
375- weights = sample_weights .flatten (),
376- ),
377- )
311+ keras_rmse = keras .metrics .RootMeanSquaredError ()
312+ for labels , logits , weights in zip (y_true , y_pred , sample_weights ):
313+ keras_rmse .update_state (labels , logits , sample_weight = weights )
314+ expected = keras_rmse .result ()
378315 np .testing .assert_allclose (
379316 metric .compute (),
380317 expected ,
@@ -388,6 +325,7 @@ def test_rmse(self, y_true, y_pred, sample_weights):
388325 ('weighted' , OUTPUT_LABELS , OUTPUT_PREDS , SAMPLE_WEIGHTS ),
389326 )
390327 def test_rsquared (self , y_true , y_pred , sample_weights ):
328+ """Test that `RSQUARED` Metric computes correct values."""
391329 if sample_weights is None :
392330 sample_weights = np .ones_like (y_true )
393331
@@ -400,11 +338,14 @@ def test_rsquared(self, y_true, y_pred, sample_weights):
400338 )
401339 metric = update if metric is None else metric .merge (update )
402340
403- expected = sklearn_metrics .r2_score (
404- y_true .flatten (),
405- y_pred .flatten (),
406- sample_weight = sample_weights .flatten (),
407- )
341+ keras_r2 = keras .metrics .R2Score ()
342+ for labels , logits , weights in zip (y_true , y_pred , sample_weights ):
343+ keras_r2 .update_state (
344+ labels [:, jnp .newaxis ],
345+ logits [:, jnp .newaxis ],
346+ sample_weight = weights [:, jnp .newaxis ],
347+ )
348+ expected = keras_r2 .result ()
408349 np .testing .assert_allclose (
409350 metric .compute (),
410351 expected ,
@@ -427,6 +368,7 @@ def test_rsquared(self, y_true, y_pred, sample_weights):
427368 ),
428369 )
429370 def test_perplexity (self , y_true , y_pred , sample_weights ):
371+ """Test that `Perplexity` Metric computes correct values."""
430372 keras_metric = keras_hub .metrics .Perplexity ()
431373 metrax_metric = None
432374 for index , (labels , logits ) in enumerate (zip (y_true , y_pred )):
@@ -454,4 +396,4 @@ def test_perplexity(self, y_true, y_pred, sample_weights):
454396 os .environ ['XLA_FLAGS' ] = (
455397 '--xla_force_host_platform_device_count=4' # Use 4 CPU devices
456398 )
457- absltest .main ()
399+ absltest .main ()
0 commit comments