Skip to content

Commit 778e00a

Browse files
hyunn9973 authored and copybara-github committed
Add MSLE and RMSLE.
PiperOrigin-RevId: 842027777
1 parent ca6afce commit 778e00a

File tree

6 files changed

+179
-0
lines changed

6 files changed

+179
-0
lines changed

src/metrax/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,14 @@
3434
MAE = regression_metrics.MAE
3535
MRR = ranking_metrics.MRR
3636
MSE = regression_metrics.MSE
37+
MSLE = regression_metrics.MSLE
3738
NDCGAtK = ranking_metrics.NDCGAtK
3839
Perplexity = nlp_metrics.Perplexity
3940
Precision = classification_metrics.Precision
4041
PrecisionAtK = ranking_metrics.PrecisionAtK
4142
PSNR = image_metrics.PSNR
4243
RMSE = regression_metrics.RMSE
44+
RMSLE = regression_metrics.RMSLE
4345
RSQUARED = regression_metrics.RSQUARED
4446
Recall = classification_metrics.Recall
4547
RecallAtK = ranking_metrics.RecallAtK
@@ -66,12 +68,14 @@
6668
"MAE",
6769
"MRR",
6870
"MSE",
71+
"MSLE",
6972
"NDCGAtK",
7073
"Perplexity",
7174
"Precision",
7275
"PrecisionAtK",
7376
"PSNR",
7477
"RMSE",
78+
"RMSLE",
7579
"RSQUARED",
7680
"SpearmanRankCorrelation",
7781
"Recall",

src/metrax/metrax_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,12 @@ class MetraxTest(parameterized.TestCase):
148148
metrax.MSE,
149149
{'predictions': OUTPUT_LABELS, 'labels': OUTPUT_PREDS},
150150
),
151+
(
152+
'msle',
153+
metrax.MSLE,
154+
{'predictions': OUTPUT_LABELS, 'labels': OUTPUT_PREDS},
155+
),
156+
151157
(
152158
'ndcgAtK',
153159
metrax.NDCGAtK,
@@ -190,6 +196,11 @@ class MetraxTest(parameterized.TestCase):
190196
metrax.RMSE,
191197
{'predictions': OUTPUT_LABELS, 'labels': OUTPUT_PREDS},
192198
),
199+
(
200+
'rmsle',
201+
metrax.RMSLE,
202+
{'predictions': OUTPUT_LABELS, 'labels': OUTPUT_PREDS},
203+
),
193204
(
194205
'rsquared',
195206
metrax.RSQUARED,

src/metrax/nnx/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,14 @@
2828
MAE = nnx_metrics.MAE
2929
MRR = nnx_metrics.MRR
3030
MSE = nnx_metrics.MSE
31+
MSLE = nnx_metrics.MSLE
3132
NDCGAtK = nnx_metrics.NDCGAtK
3233
Perplexity = nnx_metrics.Perplexity
3334
Precision = nnx_metrics.Precision
3435
PrecisionAtK = nnx_metrics.PrecisionAtK
3536
PSNR = nnx_metrics.PSNR
3637
RMSE = nnx_metrics.RMSE
38+
RMSLE = nnx_metrics.RMSLE
3739
RSQUARED = nnx_metrics.RSQUARED
3840
Recall = nnx_metrics.Recall
3941
RecallAtK = nnx_metrics.RecallAtK

src/metrax/nnx/nnx_metrics.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,13 @@ def __init__(self):
114114
super().__init__(metrax.MSE)
115115

116116

117+
class MSLE(NnxWrapper):
  """NNX wrapper exposing the Metrax MSLE (mean squared logarithmic error) metric."""

  def __init__(self) -> None:
    # Delegate all metric state handling to the shared NNX wrapper base.
    super().__init__(metrax.MSLE)
122+
123+
117124
class NDCGAtK(NnxWrapper):
118125
"""An NNX class for the Metrax metric NDCGAtK."""
119126

@@ -170,6 +177,13 @@ def __init__(self):
170177
super().__init__(metrax.RMSE)
171178

172179

180+
class RMSLE(NnxWrapper):
  """NNX wrapper exposing the Metrax RMSLE (root mean squared logarithmic error) metric."""

  def __init__(self) -> None:
    # Delegate all metric state handling to the shared NNX wrapper base.
    super().__init__(metrax.RMSLE)
185+
186+
173187
class RougeL(NnxWrapper):
174188
"""An NNX class for the Metrax metric RougeL."""
175189

src/metrax/regression_metrics.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,74 @@ def compute(self) -> jax.Array:
160160
return jnp.sqrt(super().compute())
161161

162162

163+
@flax.struct.dataclass
class MSLE(base.Average):
  r"""Mean squared logarithmic error for regression problems.

  Accumulates the squared difference of the log-transformed `predictions`
  and `labels` so that merged batches yield the exact dataset-level mean:

  .. math::
      MSLE = \frac{1}{N} \sum_{i=1}^{N} (ln(y_i + 1) - ln(\hat{y}_i + 1))^2

  where:
    - :math:`y_i` are true values
    - :math:`\hat{y}_i` are predictions
    - :math:`N` is the number of samples
  """

  @classmethod
  def from_model_output(
      cls,
      predictions: jax.Array,
      labels: jax.Array,
      sample_weights: jax.Array | None = None,
  ) -> 'MSLE':
    """Builds an MSLE update from a batch of model outputs.

    Args:
      predictions: A floating point 1D vector of model predictions with
        shape (batch_size,). Values are assumed non-negative so that
        log1p is well defined — TODO confirm against callers.
      labels: True values with shape (batch_size,).
      sample_weights: Optional floating point 1D per-sample weights with
        shape (batch_size,).

    Returns:
      An MSLE metric instance holding this batch's weighted error total
      and (weighted) sample count; `compute()` yields a single scalar.
    """
    # log1p(x) == ln(x + 1), matching the MSLE definition and numerically
    # stable for small inputs. The sign of the difference is irrelevant
    # because the result is squared.
    log_error = jnp.log1p(labels) - jnp.log1p(predictions)
    squared_log_error = jnp.square(log_error)
    observation_count = jnp.ones_like(labels, dtype=jnp.int32)
    if sample_weights is not None:
      # Weighted mean: scale both the error and the effective count.
      squared_log_error = squared_log_error * sample_weights
      observation_count = observation_count * sample_weights
    return cls(
        total=jnp.sum(squared_log_error),
        count=jnp.sum(observation_count),
    )
208+
209+
210+
@flax.struct.dataclass
class RMSLE(MSLE):
  r"""Root mean squared logarithmic error for regression problems.

  Reuses the MSLE accumulation and applies a square root at read time:

  .. math::
      RMSLE = \sqrt{\frac{1}{N} \sum_{i=1}^{N}
      (ln(y_i + 1) - ln(\hat{y}_i + 1))^2
      }

  where:
    - :math:`y_i` are true values
    - :math:`\hat{y}_i` are predictions
    - :math:`N` is the number of samples
  """

  def compute(self) -> jax.Array:
    """Returns the square root of the accumulated mean squared log error."""
    mean_squared_log_error = super().compute()
    return jnp.sqrt(mean_squared_log_error)
229+
230+
163231
@flax.struct.dataclass
164232
class RSQUARED(clu_metrics.Metric):
165233
r"""Computes the r-squared score of a scalar or a batch of tensors.

src/metrax/regression_metrics_test.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,86 @@ def test_rmse(self, y_true, y_pred, sample_weights):
282282
atol=atol,
283283
)
284284

285+
@parameterized.named_parameters(
    ('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, None),
    ('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, None),
    ('basic_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, None),
    ('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None),
    ('weighted_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, SAMPLE_WEIGHTS),
    ('weighted_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, SAMPLE_WEIGHTS),
    ('weighted_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, SAMPLE_WEIGHTS),
)
def test_msle(self, y_true, y_pred, sample_weights):
  """Checks `MSLE` against sklearn's mean_squared_log_error."""
  # Align dtypes before comparing against the float32 sklearn reference.
  y_true = y_true.astype(y_pred.dtype)
  y_pred = y_pred.astype(y_true.dtype)
  if sample_weights is None:
    sample_weights = np.ones_like(y_true)

  # Fold per-batch updates together to exercise merge().
  updates = (
      metrax.MSLE.from_model_output(
          predictions=batch_preds,
          labels=batch_labels,
          sample_weights=batch_weights,
      )
      for batch_labels, batch_preds, batch_weights in zip(
          y_true, y_pred, sample_weights
      )
  )
  metric = next(updates)
  for update in updates:
    metric = metric.merge(update)

  expected = sklearn_metrics.mean_squared_log_error(
      y_true.astype('float32').flatten(),
      y_pred.astype('float32').flatten(),
      sample_weight=sample_weights.astype('float32').flatten(),
  )
  # Lower-precision dtypes get a looser tolerance.
  is_low_precision = y_true.dtype in (jnp.float16, jnp.bfloat16)
  tolerance = 1e-2 if is_low_precision else 1e-05
  np.testing.assert_allclose(
      metric.compute(),
      expected,
      rtol=tolerance,
      atol=tolerance,
  )
324+
325+
@parameterized.named_parameters(
    ('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, None),
    ('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, None),
    ('basic_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, None),
    ('batch_size_one', OUTPUT_LABELS_BS1, OUTPUT_PREDS_BS1, None),
    ('weighted_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, SAMPLE_WEIGHTS),
    ('weighted_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, SAMPLE_WEIGHTS),
    ('weighted_bf16', OUTPUT_LABELS, OUTPUT_PREDS_BF16, SAMPLE_WEIGHTS),
)
def test_rmsle(self, y_true, y_pred, sample_weights):
  """Checks `RMSLE` against sklearn's root_mean_squared_log_error."""
  # Align dtypes before comparing against the float32 sklearn reference.
  y_true = y_true.astype(y_pred.dtype)
  y_pred = y_pred.astype(y_true.dtype)
  if sample_weights is None:
    sample_weights = np.ones_like(y_true)

  # Fold per-batch updates together to exercise merge().
  updates = (
      metrax.RMSLE.from_model_output(
          predictions=batch_preds,
          labels=batch_labels,
          sample_weights=batch_weights,
      )
      for batch_labels, batch_preds, batch_weights in zip(
          y_true, y_pred, sample_weights
      )
  )
  metric = next(updates)
  for update in updates:
    metric = metric.merge(update)

  expected = sklearn_metrics.root_mean_squared_log_error(
      y_true.astype('float32').flatten(),
      y_pred.astype('float32').flatten(),
      sample_weight=sample_weights.astype('float32').flatten(),
  )
  # Lower-precision dtypes get a looser tolerance.
  is_low_precision = y_true.dtype in (jnp.float16, jnp.bfloat16)
  tolerance = 1e-2 if is_low_precision else 1e-05
  np.testing.assert_allclose(
      metric.compute(),
      expected,
      rtol=tolerance,
      atol=tolerance,
  )
364+
285365
@parameterized.named_parameters(
286366
('basic_f16', OUTPUT_LABELS, OUTPUT_PREDS_F16, None),
287367
('basic_f32', OUTPUT_LABELS, OUTPUT_PREDS_F32, None),

0 commit comments

Comments
 (0)