Skip to content

Commit aa21972

Browse files
committed
Add Perplexity metric
1 parent e640329 commit aa21972

File tree

5 files changed

+124
-7
lines changed

5 files changed

+124
-7
lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,15 @@ jobs:
1111
runs-on: ubuntu-latest
1212
steps:
1313
- uses: actions/checkout@v4
14-
- name: Set up Python
14+
- name: Set up Python 3.12
1515
uses: actions/setup-python@v5
1616
with:
17-
python-version: '3.x'
17+
python-version: 3.12
1818
- name: Install dependencies
1919
run: |
2020
python -m pip install --upgrade pip
2121
pip install -r requirements.txt
22-
- name: Test with pytest
22+
- name: Run Unit Tests
2323
run: |
24-
pip install pytest
2524
pytest ./src/
2625

requirements.txt

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,8 @@
1-
clu==0.0.12
2-
scikit-learn==1.6.1
1+
absl-py
2+
clu
3+
jax[cpu]
4+
keras-hub
5+
pytest
6+
scikit-learn
7+
tensorflow-cpu~=2.18
8+
tensorflow-text~=2.18

src/metrax/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
AUCPR,
1717
AUCROC,
1818
MSE,
19+
Perplexity,
1920
Precision,
2021
RMSE,
2122
RSQUARED,
@@ -26,6 +27,7 @@
2627
"MSE",
2728
"RMSE",
2829
"RSQUARED",
30+
"Perplexity",
2931
"Precision",
3032
"Recall",
3133
"AUCPR",

src/metrax/metrics.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -536,4 +536,76 @@ def compute(self) -> jax.Array:
536536
self.false_positives, self.false_positives + self.true_negatives
537537
)
538538
# Threshold goes from 0 to 1, so trapezoid is negative.
539-
return jnp.trapezoid(tp_rate, fp_rate) * -1
539+
return jnp.trapezoid(tp_rate, fp_rate) * -1
540+
541+
542+
@flax.struct.dataclass
class Perplexity(clu_metrics.Metric):
  """Computes perplexity for sequence generation.

  Perplexity is a measurement of how well a probability distribution predicts a
  sample. It is defined as the exponentiation of the cross-entropy. A low
  perplexity indicates the probability distribution is good at predicting the
  sample.

  For language models, it can be interpreted as the weighted average branching
  factor of the model - how many equally likely words can be selected at each
  step.
  """

  # Running sum of per-batch mean cross-entropies, each weighted by that
  # batch's leading-dimension size.
  aggregate_crossentropy: jax.Array
  # Running total of the leading-dimension sizes of all batches seen so far.
  num_samples: jax.Array

  @classmethod
  def from_model_output(
      cls,
      predictions: jax.Array,
      labels: jax.Array,
      sample_weights: jax.Array | None = None,
  ) -> 'Perplexity':
    """Builds a Perplexity update from one batch of model output.

    Args:
      predictions: A floating point tensor representing the prediction
        generated from the model. The shape should be (batch_size, seq_len,
        vocab_size).
      labels: True value. The shape should be (batch_size, seq_len).
      sample_weights: An optional tensor representing the weight of each
        token. The shape should be (batch_size, seq_len).

    Returns:
      Updated Perplexity metric.

    Raises:
      ValueError: If type of `labels` is wrong or the shapes of `predictions`
        and `labels` are incompatible.
    """
    # Re-normalize along the vocabulary axis so each row is a distribution.
    probs = predictions / jnp.sum(predictions, axis=-1, keepdims=True)
    one_hot_labels = jax.nn.one_hot(labels, probs.shape[-1], axis=-1)
    # Per-token cross-entropy: -log p(label) selected via the one-hot mask.
    token_crossentropy = -jnp.sum(one_hot_labels * jnp.log(probs), axis=-1)

    if sample_weights is None:
      batch_crossentropy = jnp.mean(token_crossentropy)
    else:
      # Weight each token, then normalize by the total weight so padded or
      # masked tokens do not contribute to the average.
      weighted = token_crossentropy * sample_weights
      batch_crossentropy = jnp.sum(weighted) / jnp.sum(sample_weights)

    count = jnp.array(labels.shape[0])
    # Store batch_size * mean so merged batches average correctly in compute.
    return cls(
        aggregate_crossentropy=count * batch_crossentropy,
        num_samples=count,
    )

  def merge(self, other: 'Perplexity') -> 'Perplexity':
    """Accumulates state from another Perplexity instance."""
    combined_crossentropy = (
        self.aggregate_crossentropy + other.aggregate_crossentropy
    )
    combined_samples = self.num_samples + other.num_samples
    return type(self)(
        aggregate_crossentropy=combined_crossentropy,
        num_samples=combined_samples,
    )

  def compute(self) -> jax.Array:
    """Returns exp of the sample-weighted mean cross-entropy seen so far."""
    return jnp.exp(self.aggregate_crossentropy / self.num_samples)

src/metrax/metrics_test.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from absl.testing import absltest
1818
from absl.testing import parameterized
1919
import jax.numpy as jnp
20+
import keras_hub
2021
import metrax
2122
import numpy as np
2223
from sklearn import metrics as sklearn_metrics
@@ -320,6 +321,43 @@ def test_rsquared(self, y_true, y_pred, sample_weights):
320321
atol=1e-05,
321322
)
322323

324+
@parameterized.named_parameters(
    (
        'basic',
        # Seeded generators: unseeded np.random here runs at class-body
        # evaluation time, so failures would be irreproducible between runs.
        np.random.default_rng(0).integers(10, size=(2, 5, 10)),
        np.random.default_rng(1).uniform(size=(2, 5, 10, 20)),
        None,
    ),
    (
        'weighted',
        np.random.default_rng(2).integers(10, size=(2, 5, 10)),
        np.random.default_rng(3).uniform(size=(2, 5, 10, 20)),
        np.random.default_rng(4).integers(2, size=(2, 5, 10)).astype(
            np.float32
        ),
    ),
)
def test_perplexity(self, y_true, y_pred, sample_weights):
  """Checks metrax Perplexity against keras_hub.metrics.Perplexity.

  Feeds the same batches to both implementations, merging the metrax
  updates along the way, and asserts the final results agree.
  """
  keras_metric = keras_hub.metrics.Perplexity()
  metrax_metric = None
  for index, (labels, logits) in enumerate(zip(y_true, y_pred)):
    weights = sample_weights[index] if sample_weights is not None else None
    keras_metric.update_state(labels, logits, sample_weight=weights)
    update = metrax.Perplexity.from_model_output(
        predictions=logits,
        labels=labels,
        sample_weights=weights,
    )
    metrax_metric = update if metrax_metric is None else metrax_metric.merge(
        update
    )

  expected = keras_metric.result()
  np.testing.assert_allclose(
      metrax_metric.compute(),
      expected,
      rtol=1e-05,
      atol=1e-05,
  )
360+
323361

324362
if __name__ == '__main__':
325363
absltest.main()

0 commit comments

Comments
 (0)