Skip to content

Commit 79fefa2

Browse files
authored
Perplexity: add clipping and from_logits (#47)
It was pointed out that Perplexity returns NaNs for negative values. This is because our implementation did not clip logit values to [0, 1], whereas the Keras implementation does. [1] Even with that fix, the tests were failing because Keras defaults to the TensorFlow version of the metric, which applies softmax to the outputs unconditionally [2], unlike the JAX implementation which does not. [3] I also added a `from_logits` arg, similar to Keras, for users who want to pass raw logits and have us apply softmax internally. [1] https://github.com/keras-team/keras/blob/3f8b065e82b17884bd43fcfbd4bd79f18a7019fe/keras/src/backend/jax/nn.py#L582 [2] https://www.tensorflow.org/api_docs/python/tf/nn/sparse_softmax_cross_entropy_with_logits [3] https://github.com/keras-team/keras/blob/3f8b065e82b17884bd43fcfbd4bd79f18a7019fe/keras/src/backend/jax/nn.py#L578-L579
1 parent c698d95 commit 79fefa2

File tree

5 files changed

+46
-9
lines changed

5 files changed

+46
-9
lines changed

src/metrax/base_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414

1515
"""Tests for metrax base utilities."""
1616

17+
import os
18+
os.environ['KERAS_BACKEND'] = 'jax'
19+
1720
from absl.testing import absltest
1821
from absl.testing import parameterized
1922
import jax.numpy as jnp

src/metrax/nlp_metrics.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def from_model_output(
6868
predictions: jax.Array,
6969
labels: jax.Array,
7070
sample_weights: jax.Array | None = None,
71+
from_logits: bool = False,
7172
) -> 'Perplexity':
7273
"""Updates the metric.
7374
@@ -78,6 +79,9 @@ def from_model_output(
7879
labels: True value. The shape should be (batch_size, seq_len).
7980
sample_weights: An optional tensor representing the
8081
weight of each token. The shape should be (batch_size, seq_len).
82+
from_logits: Whether the predictions are logits. If True, the predictions
83+
are converted to probabilities using a softmax. If False, all values
84+
outside of [0, 1] are clipped to 0 or 1.
8185
8286
Returns:
8387
Updated Perplexity metric.
@@ -86,11 +90,17 @@ def from_model_output(
8690
ValueError: If type of `labels` is wrong or the shapes of `predictions`
8791
and `labels` are incompatible.
8892
"""
89-
predictions = base.divide_no_nan(
90-
predictions, jnp.sum(predictions, axis=-1, keepdims=True)
91-
)
93+
if from_logits:
94+
log_prob = jax.nn.log_softmax(predictions, axis=-1)
95+
else:
96+
predictions = base.divide_no_nan(
97+
predictions, jnp.sum(predictions, axis=-1, keepdims=True)
98+
)
99+
epsilon = 1e-7
100+
predictions = jnp.clip(predictions, epsilon, 1.0 - epsilon)
101+
log_prob = jnp.log(predictions)
102+
92103
labels_one_hot = jax.nn.one_hot(labels, predictions.shape[-1], axis=-1)
93-
log_prob = jnp.log(predictions)
94104
crossentropy = -jnp.sum(labels_one_hot * log_prob, axis=-1)
95105

96106
# Sum across sequence length dimension first.
@@ -227,4 +237,4 @@ def _levenshtein_distance(prediction: list, reference: list) -> int:
227237
distance_matrix[i - 1][j - 1] + cost, # substitution
228238
)
229239

230-
return distance_matrix[m][n]
240+
return distance_matrix[m][n]

src/metrax/nlp_metrics_test.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,18 @@
1414

1515
"""Tests for metrax nlp metrics."""
1616

17+
import os
18+
os.environ['KERAS_BACKEND'] = 'jax'
19+
1720
from absl.testing import absltest
1821
from absl.testing import parameterized
1922
import jax.numpy as jnp
2023
import keras_hub
2124
import metrax
2225
import numpy as np
2326

27+
np.random.seed(42)
28+
2429

2530
class NlpMetricsTest(parameterized.TestCase):
2631

@@ -42,17 +47,33 @@ def test_wer_empty(self):
4247
np.random.randint(10, size=[2, 5, 10]),
4348
np.random.uniform(size=(2, 5, 10, 20)),
4449
None,
50+
False,
4551
),
4652
(
4753
'weighted',
4854
np.random.randint(10, size=[2, 5, 10]),
4955
np.random.uniform(size=(2, 5, 10, 20)),
5056
np.random.randint(2, size=(2, 5, 10)).astype(np.float32),
57+
False,
58+
),
59+
(
60+
'negative_values',
61+
np.random.randint(10, size=[2, 5, 10]),
62+
np.random.uniform(size=(2, 5, 10, 20), low=-2, high=2),
63+
None,
64+
False,
65+
),
66+
(
67+
'from_logits',
68+
np.random.randint(10, size=[2, 5, 10]),
69+
np.random.uniform(size=(2, 5, 10, 20), low=-2, high=2),
70+
None,
71+
True,
5172
),
5273
)
53-
def test_perplexity(self, y_true, y_pred, sample_weights):
74+
def test_perplexity(self, y_true, y_pred, sample_weights, from_logits):
5475
"""Test that `Perplexity` Metric computes correct values."""
55-
keras_metric = keras_hub.metrics.Perplexity()
76+
keras_metric = keras_hub.metrics.Perplexity(from_logits=from_logits)
5677
metrax_metric = None
5778
for index, (labels, logits) in enumerate(zip(y_true, y_pred)):
5879
weights = sample_weights[index] if sample_weights is not None else None
@@ -61,6 +82,7 @@ def test_perplexity(self, y_true, y_pred, sample_weights):
6182
predictions=logits,
6283
labels=labels,
6384
sample_weights=weights,
85+
from_logits=from_logits,
6486
)
6587
metrax_metric = update if metrax_metric is None else metrax_metric.merge(
6688
update
@@ -119,4 +141,4 @@ def test_wer(self):
119141

120142

121143
if __name__ == '__main__':
122-
absltest.main()
144+
absltest.main()

src/metrax/regression_metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def merge(self, other: 'RSQUARED') -> 'RSQUARED':
203203
)
204204

205205
def compute(self) -> jax.Array:
206-
"""Computes the r-squared score.
206+
r"""Computes the r-squared score.
207207
208208
Since we don't know the mean of the labels before we aggregate all of the
209209
data, we will manipulate the formula to be:

src/metrax/regression_metrics_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
"""Tests for metrax regression metrics."""
1616

1717
import os
18+
os.environ['KERAS_BACKEND'] = 'jax'
19+
1820
from absl.testing import absltest
1921
from absl.testing import parameterized
2022
import jax

0 commit comments

Comments (0)