Skip to content

Commit fec6999

Browse files
hyunn9973copybara-github
authored andcommitted
Add MSLE and RMSLE.
PiperOrigin-RevId: 842027777
1 parent ca6afce commit fec6999

File tree

6 files changed

+165
-0
lines changed

6 files changed

+165
-0
lines changed

src/metrax/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,14 @@
3434
MAE = regression_metrics.MAE
3535
MRR = ranking_metrics.MRR
3636
MSE = regression_metrics.MSE
37+
MSLE = regression_metrics.MSLE
3738
NDCGAtK = ranking_metrics.NDCGAtK
3839
Perplexity = nlp_metrics.Perplexity
3940
Precision = classification_metrics.Precision
4041
PrecisionAtK = ranking_metrics.PrecisionAtK
4142
PSNR = image_metrics.PSNR
4243
RMSE = regression_metrics.RMSE
44+
RMSLE = regression_metrics.RMSLE
4345
RSQUARED = regression_metrics.RSQUARED
4446
Recall = classification_metrics.Recall
4547
RecallAtK = ranking_metrics.RecallAtK
@@ -66,12 +68,14 @@
6668
"MAE",
6769
"MRR",
6870
"MSE",
71+
"MSLE",
6972
"NDCGAtK",
7073
"Perplexity",
7174
"Precision",
7275
"PrecisionAtK",
7376
"PSNR",
7477
"RMSE",
78+
"RMSLE",
7579
"RSQUARED",
7680
"SpearmanRankCorrelation",
7781
"Recall",

src/metrax/metrax_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,12 @@ class MetraxTest(parameterized.TestCase):
148148
metrax.MSE,
149149
{'predictions': OUTPUT_LABELS, 'labels': OUTPUT_PREDS},
150150
),
151+
(
152+
'msle',
153+
metrax.MSLE,
154+
{'predictions': OUTPUT_LABELS, 'labels': OUTPUT_PREDS},
155+
),
156+
151157
(
152158
'ndcgAtK',
153159
metrax.NDCGAtK,
@@ -190,6 +196,11 @@ class MetraxTest(parameterized.TestCase):
190196
metrax.RMSE,
191197
{'predictions': OUTPUT_LABELS, 'labels': OUTPUT_PREDS},
192198
),
199+
(
200+
'rmsle',
201+
metrax.RMSLE,
202+
{'predictions': OUTPUT_LABELS, 'labels': OUTPUT_PREDS},
203+
),
193204
(
194205
'rsquared',
195206
metrax.RSQUARED,

src/metrax/nnx/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,14 @@
2828
MAE = nnx_metrics.MAE
2929
MRR = nnx_metrics.MRR
3030
MSE = nnx_metrics.MSE
31+
MSLE = nnx_metrics.MSLE
3132
NDCGAtK = nnx_metrics.NDCGAtK
3233
Perplexity = nnx_metrics.Perplexity
3334
Precision = nnx_metrics.Precision
3435
PrecisionAtK = nnx_metrics.PrecisionAtK
3536
PSNR = nnx_metrics.PSNR
3637
RMSE = nnx_metrics.RMSE
38+
RMSLE = nnx_metrics.RMSLE
3739
RSQUARED = nnx_metrics.RSQUARED
3840
Recall = nnx_metrics.Recall
3941
RecallAtK = nnx_metrics.RecallAtK

src/metrax/nnx/nnx_metrics.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,13 @@ def __init__(self):
114114
super().__init__(metrax.MSE)
115115

116116

117+
class MSLE(NnxWrapper):
118+
"""An NNX class for the Metrax metric MSLE."""
119+
120+
def __init__(self):
121+
super().__init__(metrax.MSLE)
122+
123+
117124
class NDCGAtK(NnxWrapper):
118125
"""An NNX class for the Metrax metric NDCGAtK."""
119126

@@ -170,6 +177,13 @@ def __init__(self):
170177
super().__init__(metrax.RMSE)
171178

172179

180+
class RMSLE(NnxWrapper):
181+
"""An NNX class for the Metrax metric RMSLE."""
182+
183+
def __init__(self):
184+
super().__init__(metrax.RMSLE)
185+
186+
173187
class RougeL(NnxWrapper):
174188
"""An NNX class for the Metrax metric RougeL."""
175189

src/metrax/regression_metrics.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,68 @@ def compute(self) -> jax.Array:
160160
return jnp.sqrt(super().compute())
161161

162162

163+
@flax.struct.dataclass
164+
class MSLE(base.Average):
165+
r"""Computes the mean squared logarithmic error for regression problems given `predictions` and `labels`.
166+
167+
The mean squared logarithmic error is defined as:
168+
169+
.. math::
170+
MSLE = \frac{1}{N} \sum_{i=1}^{N} (ln(y_i + 1) - ln(\hat{y}_i + 1))^2
171+
172+
where:
173+
- :math:`y_i` are true values
174+
- :math:`\hat{y}_i` are predictions
175+
- :math:`N` is the number of samples
176+
"""
177+
178+
@classmethod
179+
def from_model_output(
180+
cls,
181+
predictions: jax.Array,
182+
labels: jax.Array,
183+
) -> 'MSLE':
184+
"""Updates the metric.
185+
186+
Args:
187+
predictions: A floating point 1D vector representing the prediction
188+
generated from the model. The shape should be (batch_size,).
189+
labels: True value. The shape should be (batch_size,).
190+
191+
Returns:
192+
Updated MSLE metric. The shape should be a single scalar.
193+
"""
194+
log_predictions = jnp.log(predictions + 1)
195+
log_labels = jnp.log(labels + 1)
196+
squared_error = jnp.square(log_predictions - log_labels)
197+
count = jnp.ones_like(labels, dtype=jnp.int32)
198+
return cls(
199+
total=squared_error.sum(),
200+
count=count.sum(),
201+
)
202+
203+
204+
@flax.struct.dataclass
205+
class RMSLE(MSLE):
206+
r"""Computes the root mean squared logarithmic error for regression problems given `predictions` and `labels`.
207+
208+
The root mean squared logarithmic error is defined as:
209+
210+
.. math::
211+
RMSLE = \sqrt{\frac{1}{N} \sum_{i=1}^{N}
212+
(ln(y_i + 1) - ln(\hat{y}_i + 1))^2
213+
}
214+
215+
where:
216+
- :math:`y_i` are true values
217+
- :math:`\hat{y}_i` are predictions
218+
- :math:`N` is the number of samples
219+
"""
220+
221+
def compute(self) -> jax.Array:
222+
return jnp.sqrt(super().compute())
223+
224+
163225
@flax.struct.dataclass
164226
class RSQUARED(clu_metrics.Metric):
165227
r"""Computes the r-squared score of a scalar or a batch of tensors.

src/metrax/regression_metrics_test.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,78 @@ def test_rmse(self, y_true, y_pred, sample_weights):
282282
atol=atol,
283283
)
284284

285+
@parameterized.named_parameters(
286+
('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16),
287+
('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32),
288+
('basic_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16),
289+
('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1),
290+
('weighted_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16),
291+
('weighted_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32),
292+
('weighted_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16),
293+
)
294+
def test_msle(self, y_true, y_pred):
295+
"""Test that `MSLE` Metric computes correct values."""
296+
y_true = y_true.astype(y_pred.dtype)
297+
y_pred = y_pred.astype(y_true.dtype)
298+
299+
metric = None
300+
for labels, logits in zip(y_true, y_pred):
301+
update = metrax.MSLE.from_model_output(
302+
predictions=logits,
303+
labels=labels,
304+
)
305+
metric = update if metric is None else metric.merge(update)
306+
307+
expected = sklearn_metrics.mean_squared_log_error(
308+
y_true.astype('float32').flatten(),
309+
y_pred.astype('float32').flatten(),
310+
)
311+
# Use lower tolerance for lower precision dtypes.
312+
rtol = 1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-05
313+
atol = 1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-05
314+
np.testing.assert_allclose(
315+
metric.compute(),
316+
expected,
317+
rtol=rtol,
318+
atol=atol,
319+
)
320+
321+
@parameterized.named_parameters(
322+
('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16),
323+
('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32),
324+
('basic_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16),
325+
('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1),
326+
('weighted_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16),
327+
('weighted_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32),
328+
('weighted_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16),
329+
)
330+
def test_rmsle(self, y_true, y_pred):
331+
"""Test that `RMSLE` Metric computes correct values."""
332+
y_true = y_true.astype(y_pred.dtype)
333+
y_pred = y_pred.astype(y_true.dtype)
334+
335+
metric = None
336+
for labels, logits in zip(y_true, y_pred):
337+
update = metrax.RMSLE.from_model_output(
338+
predictions=logits,
339+
labels=labels,
340+
)
341+
metric = update if metric is None else metric.merge(update)
342+
343+
expected = sklearn_metrics.root_mean_squared_log_error(
344+
y_true.astype('float32').flatten(),
345+
y_pred.astype('float32').flatten(),
346+
)
347+
# Use lower tolerance for lower precision dtypes.
348+
rtol = 1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-05
349+
atol = 1e-2 if y_true.dtype in (jnp.float16, jnp.bfloat16) else 1e-05
350+
np.testing.assert_allclose(
351+
metric.compute(),
352+
expected,
353+
rtol=rtol,
354+
atol=atol,
355+
)
356+
285357
@parameterized.named_parameters(
286358
('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, None),
287359
('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, None),

0 commit comments

Comments
 (0)