@@ -282,6 +282,78 @@ def test_rmse(self, y_true, y_pred, sample_weights):
282282 atol = atol ,
283283 )
284284
285+ @parameterized .named_parameters (
286+ ('basic_f16' , OUTPUT_LABELS , OUTPUT_PREDS_F16 ),
287+ ('basic_f32' , OUTPUT_LABELS , OUTPUT_PREDS_F32 ),
288+ ('basic_bf16' , OUTPUT_LABELS , OUTPUT_PREDS_BF16 ),
289+ ('batch_size_one' , OUTPUT_LABELS_BS1 , OUTPUT_PREDS_BS1 ),
290+ ('weighted_f16' , OUTPUT_LABELS , OUTPUT_PREDS_F16 ),
291+ ('weighted_f32' , OUTPUT_LABELS , OUTPUT_PREDS_F32 ),
292+ ('weighted_bf16' , OUTPUT_LABELS , OUTPUT_PREDS_BF16 ),
293+ )
294+ def test_msle (self , y_true , y_pred ):
295+ """Test that `MSLE` Metric computes correct values."""
296+ y_true = y_true .astype (y_pred .dtype )
297+ y_pred = y_pred .astype (y_true .dtype )
298+
299+ metric = None
300+ for labels , logits in zip (y_true , y_pred ):
301+ update = metrax .MSLE .from_model_output (
302+ predictions = logits ,
303+ labels = labels ,
304+ )
305+ metric = update if metric is None else metric .merge (update )
306+
307+ expected = sklearn_metrics .mean_squared_log_error (
308+ y_true .astype ('float32' ).flatten (),
309+ y_pred .astype ('float32' ).flatten (),
310+ )
311+ # Use lower tolerance for lower precision dtypes.
312+ rtol = 1e-2 if y_true .dtype in (jnp .float16 , jnp .bfloat16 ) else 1e-05
313+ atol = 1e-2 if y_true .dtype in (jnp .float16 , jnp .bfloat16 ) else 1e-05
314+ np .testing .assert_allclose (
315+ metric .compute (),
316+ expected ,
317+ rtol = rtol ,
318+ atol = atol ,
319+ )
320+
321+ @parameterized .named_parameters (
322+ ('basic_f16' , OUTPUT_LABELS , OUTPUT_PREDS_F16 ),
323+ ('basic_f32' , OUTPUT_LABELS , OUTPUT_PREDS_F32 ),
324+ ('basic_bf16' , OUTPUT_LABELS , OUTPUT_PREDS_BF16 ),
325+ ('batch_size_one' , OUTPUT_LABELS_BS1 , OUTPUT_PREDS_BS1 ),
326+ ('weighted_f16' , OUTPUT_LABELS , OUTPUT_PREDS_F16 ),
327+ ('weighted_f32' , OUTPUT_LABELS , OUTPUT_PREDS_F32 ),
328+ ('weighted_bf16' , OUTPUT_LABELS , OUTPUT_PREDS_BF16 ),
329+ )
330+ def test_rmsle (self , y_true , y_pred ):
331+ """Test that `RMSLE` Metric computes correct values."""
332+ y_true = y_true .astype (y_pred .dtype )
333+ y_pred = y_pred .astype (y_true .dtype )
334+
335+ metric = None
336+ for labels , logits in zip (y_true , y_pred ):
337+ update = metrax .RMSLE .from_model_output (
338+ predictions = logits ,
339+ labels = labels ,
340+ )
341+ metric = update if metric is None else metric .merge (update )
342+
343+ expected = sklearn_metrics .root_mean_squared_log_error (
344+ y_true .astype ('float32' ).flatten (),
345+ y_pred .astype ('float32' ).flatten (),
346+ )
347+ # Use lower tolerance for lower precision dtypes.
348+ rtol = 1e-2 if y_true .dtype in (jnp .float16 , jnp .bfloat16 ) else 1e-05
349+ atol = 1e-2 if y_true .dtype in (jnp .float16 , jnp .bfloat16 ) else 1e-05
350+ np .testing .assert_allclose (
351+ metric .compute (),
352+ expected ,
353+ rtol = rtol ,
354+ atol = atol ,
355+ )
356+
285357 @parameterized .named_parameters (
286358 ('basic_f16' , OUTPUT_LABELS , OUTPUT_PREDS_F16 , None ),
287359 ('basic_f32' , OUTPUT_LABELS , OUTPUT_PREDS_F32 , None ),
0 commit comments