@@ -536,4 +536,76 @@ def compute(self) -> jax.Array:
536536 self .false_positives , self .false_positives + self .true_negatives
537537 )
538538 # Threshold goes from 0 to 1, so trapezoid is negative.
539- return jnp .trapezoid (tp_rate , fp_rate ) * - 1
539+ return jnp .trapezoid (tp_rate , fp_rate ) * - 1
540+
541+
@flax.struct.dataclass
class Perplexity(clu_metrics.Metric):
  """Perplexity metric for sequence-generation models.

  Perplexity is the exponentiated cross-entropy between the model's
  predicted token distributions and the true tokens. Lower is better:
  it can be read as the weighted average branching factor of the model,
  i.e. how many equally likely tokens the model is choosing between at
  each step.

  State is kept as a running sum of per-batch mean cross-entropies
  (weighted by batch size) plus a running sample count, so shards can be
  merged and the final value computed once at the end.
  """

  # Sum over batches of (batch_size * mean cross-entropy of that batch).
  aggregate_crossentropy: jax.Array
  # Total number of sequences accumulated so far.
  num_samples: jax.Array

  @classmethod
  def from_model_output(
      cls,
      predictions: jax.Array,
      labels: jax.Array,
      sample_weights: jax.Array | None = None,
  ) -> 'Perplexity':
    """Builds a metric update from one batch of model output.

    Args:
      predictions: Predicted token distributions of shape
        (batch_size, seq_len, vocab_size). They are renormalized along the
        vocab axis, so they need not sum to exactly 1.
        NOTE(review): a zero probability at the true token yields
        log(0) = -inf here — assumed not to occur; confirm upstream.
      labels: Integer token ids of shape (batch_size, seq_len).
      sample_weights: Optional per-token weights of shape
        (batch_size, seq_len). When given, the cross-entropy is the
        weighted mean over all tokens in the batch.

    Returns:
      A Perplexity update carrying this batch's aggregate cross-entropy
      and sample count.
    """
    # Renormalize so every vocab-axis slice is a proper distribution.
    probs = predictions / jnp.sum(predictions, axis=-1, keepdims=True)
    # Select the log-probability of each true token with a one-hot mask.
    one_hot = jax.nn.one_hot(labels, probs.shape[-1], axis=-1)
    token_nll = -jnp.sum(one_hot * jnp.log(probs), axis=-1)

    if sample_weights is None:
      # Unweighted: plain mean over every token in the batch.
      mean_nll = jnp.mean(token_nll)
    else:
      # Weighted mean over tokens, normalized by the total weight.
      # NOTE(review): all-zero weights would divide by zero — confirm
      # callers guarantee a positive weight sum.
      weighted = token_nll * sample_weights
      mean_nll = jnp.sum(weighted) / jnp.sum(sample_weights)

    count = jnp.array(labels.shape[0])
    return cls(
        aggregate_crossentropy=count * mean_nll,
        num_samples=count,
    )

  def merge(self, other: 'Perplexity') -> 'Perplexity':
    """Combines two accumulated states by summing both running totals."""
    combined_ce = self.aggregate_crossentropy + other.aggregate_crossentropy
    combined_n = self.num_samples + other.num_samples
    return type(self)(
        aggregate_crossentropy=combined_ce,
        num_samples=combined_n,
    )

  def compute(self) -> jax.Array:
    """Returns exp of the average per-sequence cross-entropy."""
    return jnp.exp(self.aggregate_crossentropy / self.num_samples)
0 commit comments