@@ -282,6 +282,86 @@ def test_rmse(self, y_true, y_pred, sample_weights):
282282 atol = atol ,
283283 )
284284
285+ @parameterized .named_parameters (
286+ ('basic_f16' , OUTPUT_LABELS , OUTPUT_PREDS_F16 , None ),
287+ ('basic_f32' , OUTPUT_LABELS , OUTPUT_PREDS_F32 , None ),
288+ ('basic_bf16' , OUTPUT_LABELS , OUTPUT_PREDS_BF16 , None ),
289+ ('batch_size_one' , OUTPUT_LABELS_BS1 , OUTPUT_PREDS_BS1 , None ),
290+ ('weighted_f16' , OUTPUT_LABELS , OUTPUT_PREDS_F16 , SAMPLE_WEIGHTS ),
291+ ('weighted_f32' , OUTPUT_LABELS , OUTPUT_PREDS_F32 , SAMPLE_WEIGHTS ),
292+ ('weighted_bf16' , OUTPUT_LABELS , OUTPUT_PREDS_BF16 , SAMPLE_WEIGHTS ),
293+ )
294+ def test_msle (self , y_true , y_pred , sample_weights ):
295+ """Test that `MSLE` Metric computes correct values."""
296+ y_true = y_true .astype (y_pred .dtype )
297+ y_pred = y_pred .astype (y_true .dtype )
298+ if sample_weights is None :
299+ sample_weights = np .ones_like (y_true )
300+
301+ metric = None
302+ for labels , logits , weights in zip (y_true , y_pred , sample_weights ):
303+ update = metrax .MSLE .from_model_output (
304+ predictions = logits ,
305+ labels = labels ,
306+ sample_weights = weights ,
307+ )
308+ metric = update if metric is None else metric .merge (update )
309+
310+ expected = sklearn_metrics .mean_squared_log_error (
311+ y_true .astype ('float32' ).flatten (),
312+ y_pred .astype ('float32' ).flatten (),
313+ sample_weight = sample_weights .astype ('float32' ).flatten (),
314+ )
315+ # Use lower tolerance for lower precision dtypes.
316+ rtol = 1e-2 if y_true .dtype in (jnp .float16 , jnp .bfloat16 ) else 1e-05
317+ atol = 1e-2 if y_true .dtype in (jnp .float16 , jnp .bfloat16 ) else 1e-05
318+ np .testing .assert_allclose (
319+ metric .compute (),
320+ expected ,
321+ rtol = rtol ,
322+ atol = atol ,
323+ )
324+
325+ @parameterized .named_parameters (
326+ ('basic_f16' , OUTPUT_LABELS , OUTPUT_PREDS_F16 , None ),
327+ ('basic_f32' , OUTPUT_LABELS , OUTPUT_PREDS_F32 , None ),
328+ ('basic_bf16' , OUTPUT_LABELS , OUTPUT_PREDS_BF16 , None ),
329+ ('batch_size_one' , OUTPUT_LABELS_BS1 , OUTPUT_PREDS_BS1 , None ),
330+ ('weighted_f16' , OUTPUT_LABELS , OUTPUT_PREDS_F16 , SAMPLE_WEIGHTS ),
331+ ('weighted_f32' , OUTPUT_LABELS , OUTPUT_PREDS_F32 , SAMPLE_WEIGHTS ),
332+ ('weighted_bf16' , OUTPUT_LABELS , OUTPUT_PREDS_BF16 , SAMPLE_WEIGHTS ),
333+ )
334+ def test_rmsle (self , y_true , y_pred , sample_weights ):
335+ """Test that `RMSLE` Metric computes correct values."""
336+ y_true = y_true .astype (y_pred .dtype )
337+ y_pred = y_pred .astype (y_true .dtype )
338+ if sample_weights is None :
339+ sample_weights = np .ones_like (y_true )
340+
341+ metric = None
342+ for labels , logits , weights in zip (y_true , y_pred , sample_weights ):
343+ update = metrax .RMSLE .from_model_output (
344+ predictions = logits ,
345+ labels = labels ,
346+ sample_weights = weights ,
347+ )
348+ metric = update if metric is None else metric .merge (update )
349+
350+ expected = sklearn_metrics .root_mean_squared_log_error (
351+ y_true .astype ('float32' ).flatten (),
352+ y_pred .astype ('float32' ).flatten (),
353+ sample_weight = sample_weights .astype ('float32' ).flatten (),
354+ )
355+ # Use lower tolerance for lower precision dtypes.
356+ rtol = 1e-2 if y_true .dtype in (jnp .float16 , jnp .bfloat16 ) else 1e-05
357+ atol = 1e-2 if y_true .dtype in (jnp .float16 , jnp .bfloat16 ) else 1e-05
358+ np .testing .assert_allclose (
359+ metric .compute (),
360+ expected ,
361+ rtol = rtol ,
362+ atol = atol ,
363+ )
364+
285365 @parameterized .named_parameters (
286366 ('basic_f16' , OUTPUT_LABELS , OUTPUT_PREDS_F16 , None ),
287367 ('basic_f32' , OUTPUT_LABELS , OUTPUT_PREDS_F32 , None ),
0 commit comments