Skip to content

Commit 280120c

Browse files
committed
Add Perplexity metric
1 parent e640329 commit 280120c

File tree

5 files changed

+128
-7
lines changed

5 files changed

+128
-7
lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,15 @@ jobs:
1111
runs-on: ubuntu-latest
1212
steps:
1313
- uses: actions/checkout@v4
14-
- name: Set up Python
14+
- name: Set up Python 3.9
1515
uses: actions/setup-python@v5
1616
with:
17-
python-version: '3.x'
17+
python-version: 3.9
1818
- name: Install dependencies
1919
run: |
2020
python -m pip install --upgrade pip
2121
pip install -r requirements.txt
22-
- name: Test with pytest
22+
- name: Run Unit Tests
2323
run: |
24-
pip install pytest
2524
pytest ./src/
2625

requirements.txt

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,12 @@
1-
clu==0.0.12
2-
scikit-learn==1.6.1
1+
# Debugging
2+
tensorflow-cpu~=2.18
3+
tensorflow-text~=2.18
4+
jax[cpu]
5+
6+
absl-py
7+
clu
8+
# jax==0.4.37
9+
scikit-learn
10+
# keras==3.8.0
11+
keras-hub
12+
pytest

src/metrax/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
AUCPR,
1717
AUCROC,
1818
MSE,
19+
Perplexity,
1920
Precision,
2021
RMSE,
2122
RSQUARED,
@@ -26,6 +27,7 @@
2627
"MSE",
2728
"RMSE",
2829
"RSQUARED",
30+
"Perplexity",
2931
"Precision",
3032
"Recall",
3133
"AUCPR",

src/metrax/metrics.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -536,4 +536,76 @@ def compute(self) -> jax.Array:
536536
self.false_positives, self.false_positives + self.true_negatives
537537
)
538538
# Threshold goes from 0 to 1, so trapezoid is negative.
539-
return jnp.trapezoid(tp_rate, fp_rate) * -1
539+
return jnp.trapezoid(tp_rate, fp_rate) * -1
540+
541+
542+
@flax.struct.dataclass
class Perplexity(clu_metrics.Metric):
  """Computes perplexity for sequence generation.

  Perplexity is a measurement of how well a probability distribution predicts a
  sample. It is defined as the exponentiation of the cross-entropy. A low
  perplexity indicates the probability distribution is good at predicting the
  sample.

  For language models, it can be interpreted as the weighted average branching
  factor of the model - how many equally likely words can be selected at each
  step.

  Attributes:
    aggregate_crossentropy: Running sum, over all merged updates, of
      `batch_size * mean_token_crossentropy` for each update.
    num_samples: Running sum of the batch sizes of all merged updates that
      contributed at least one weighted token.
  """

  aggregate_crossentropy: jax.Array
  num_samples: jax.Array

  @classmethod
  def from_model_output(
      cls,
      predictions: jax.Array,
      labels: jax.Array,
      sample_weights: jax.Array | None = None,
  ) -> 'Perplexity':
    """Builds the metric state for a single batch.

    Args:
      predictions: Floating point scores over the vocabulary for every token.
        The last axis is the vocabulary axis, i.e. the shape should be
        (batch_size, seq_len, vocab_size). Scores do not need to be
        normalized; each row is divided by its sum before the log is taken.
      labels: Integer class ids of the true tokens. The shape should be
        (batch_size, seq_len).
      sample_weights: An optional floating point array with the weight of each
        token. The shape should be (batch_size, seq_len).

    Returns:
      A `Perplexity` instance holding this batch's contribution. Combine
      batches with `merge` and call `compute` for the scalar perplexity. A
      batch whose weights sum to zero contributes nothing (zero cross-entropy
      and zero samples) instead of propagating NaN into the aggregate.
    """
    batch_size = jnp.array(labels.shape[0])

    # Normalize scores into a distribution over the vocabulary. Rows that sum
    # to zero are divided by one instead, so they end up at the clip floor
    # rather than becoming 0/0 = NaN.
    row_totals = jnp.sum(predictions, axis=-1, keepdims=True)
    predictions = predictions / jnp.where(row_totals == 0, 1.0, row_totals)
    # Keep probabilities away from 0 so the log stays finite.
    predictions = jnp.clip(predictions, 1e-5, 1.0 - 1e-5)
    log_prob = jnp.log(predictions)

    # Select the log-probability of the true token via a one-hot mask.
    one_hot_labels = jax.nn.one_hot(labels, predictions.shape[-1], axis=-1)
    crossentropy = -jnp.sum(one_hot_labels * log_prob, axis=-1)

    if sample_weights is None:
      # Unweighted: plain mean over every token in the batch.
      return cls(
          aggregate_crossentropy=(batch_size * jnp.mean(crossentropy)),
          num_samples=batch_size,
      )

    # Weighted mean over all tokens of the batch (not per sequence).
    total_weight = jnp.sum(sample_weights)
    has_weight = total_weight != 0
    weighted_sum = jnp.sum(crossentropy * sample_weights)
    # `jnp.where` on the denominator avoids evaluating x / 0; the outer
    # `jnp.where` zeroes out the contribution of a fully masked batch.
    mean_crossentropy = jnp.where(
        has_weight,
        weighted_sum / jnp.where(has_weight, total_weight, 1.0),
        0.0,
    )
    return cls(
        aggregate_crossentropy=(batch_size * mean_crossentropy),
        # A fully masked batch must not dilute the average either.
        num_samples=jnp.where(has_weight, batch_size, 0),
    )

  def merge(self, other: 'Perplexity') -> 'Perplexity':
    """Returns a new state with both accumulators summed element-wise."""
    return type(self)(
        aggregate_crossentropy=(
            self.aggregate_crossentropy + other.aggregate_crossentropy
        ),
        num_samples=self.num_samples + other.num_samples,
    )

  def compute(self) -> jax.Array:
    """Returns exp(mean cross-entropy), or 0 when no samples were seen."""
    is_empty = self.num_samples == 0
    # Divide by one when empty so the unused branch never evaluates 0 / 0.
    safe_count = jnp.where(is_empty, 1, self.num_samples)
    return jnp.where(
        is_empty, 0.0, jnp.exp(self.aggregate_crossentropy / safe_count)
    )

src/metrax/metrics_test.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from absl.testing import absltest
1818
from absl.testing import parameterized
1919
import jax.numpy as jnp
20+
import keras_hub
2021
import metrax
2122
import numpy as np
2223
from sklearn import metrics as sklearn_metrics
@@ -320,6 +321,43 @@ def test_rsquared(self, y_true, y_pred, sample_weights):
320321
atol=1e-05,
321322
)
322323

324+
@parameterized.named_parameters(
    (
        'basic',
        np.random.randint(10, size=[2, 5, 10]),
        np.random.uniform(size=(2, 5, 10, 20)),
        None,
    ),
    (
        'weighted',
        np.random.randint(10, size=[2, 5, 10]),
        np.random.uniform(size=(2, 5, 10, 20)),
        np.random.randint(2, size=(2, 5, 10)).astype(np.float32),
    ),
)
def test_perplexity(self, y_true, y_pred, sample_weights):
  """Checks metrax Perplexity against keras_hub across merged batches.

  The leading axis of `y_true`, `y_pred` and `sample_weights` indexes the
  update step; every step is fed to both implementations and the final
  results are compared.
  """
  reference = keras_hub.metrics.Perplexity()
  accumulated = None

  for step, (batch_labels, batch_preds) in enumerate(zip(y_true, y_pred)):
    if sample_weights is None:
      batch_weights = None
    else:
      batch_weights = sample_weights[step]

    reference.update_state(
        batch_labels, batch_preds, sample_weight=batch_weights
    )
    partial = metrax.Perplexity.from_model_output(
        predictions=batch_preds,
        labels=batch_labels,
        sample_weights=batch_weights,
    )
    # The first step seeds the accumulator; later steps are merged into it.
    if accumulated is None:
      accumulated = partial
    else:
      accumulated = accumulated.merge(partial)

  expected = reference.result()
  actual = accumulated.compute()
  np.testing.assert_allclose(actual, expected, rtol=1e-05, atol=1e-05)
360+
323361

324362
if __name__ == '__main__':
325363
absltest.main()

0 commit comments

Comments
 (0)