@@ -536,4 +536,76 @@ def compute(self) -> jax.Array:
536536 self .false_positives , self .false_positives + self .true_negatives
537537 )
538538 # Threshold goes from 0 to 1, so trapezoid is negative.
539- return jnp .trapezoid (tp_rate , fp_rate ) * - 1
539+ return jnp .trapezoid (tp_rate , fp_rate ) * - 1
540+
541+
@flax.struct.dataclass
class Perplexity(clu_metrics.Metric):
  """Computes perplexity for sequence generation.

  Perplexity is a measurement of how well a probability distribution predicts a
  sample. It is defined as the exponentiation of the cross-entropy. A low
  perplexity indicates the probability distribution is good at predicting the
  sample.

  For language models, it can be interpreted as the weighted average branching
  factor of the model - how many equally likely words can be selected at each
  step.
  """

  # Sum of per-batch mean cross-entropies, weighted by batch size, so that
  # merged shards average correctly in compute().
  aggregate_crossentropy: jax.Array
  # Total number of sequences (batch rows) accumulated so far.
  num_samples: jax.Array

  @classmethod
  def from_model_output(
      cls,
      predictions: jax.Array,
      labels: jax.Array,
      sample_weights: jax.Array | None = None,
  ) -> 'Perplexity':
    """Updates the metric.

    Args:
      predictions: A floating point 3D tensor of per-token class scores
        generated by the model. The shape should be
        (batch_size, seq_len, vocab_size); scores are renormalized over the
        last (vocabulary) axis before use, so they need not already sum to 1.
      labels: True token ids. The shape should be (batch_size, seq_len).
      sample_weights: An optional floating point 2D tensor representing the
        weight of each token. The shape should be (batch_size, seq_len).

    Returns:
      Updated Perplexity metric. The shape should be a single scalar.

    Raises:
      ValueError: If type of `labels` is wrong or the shapes of `predictions`
        and `labels` are incompatible.
    """
    batch_size = jnp.array(labels.shape[0])
    # Renormalize scores into probabilities along the vocabulary axis, then
    # clip away 0 and 1 so the log below is finite.
    predictions = predictions / jnp.sum(predictions, axis=-1, keepdims=True)
    predictions = jnp.clip(predictions, 1e-5, 1.0 - 1e-5)
    log_prob = jnp.log(predictions)
    # One-hot against the vocabulary axis picks out the log-probability of the
    # true token; crossentropy has shape (batch_size, seq_len).
    labels = jax.nn.one_hot(labels, predictions.shape[-1], axis=-1)
    crossentropy = -jnp.sum(labels * log_prob, axis=-1)

    if sample_weights is not None:
      # Single weighted mean over ALL tokens in the batch (not per-sequence):
      # total weighted cross-entropy divided by total weight.
      crossentropy = crossentropy * sample_weights
      crossentropy = jnp.sum(crossentropy) / jnp.sum(sample_weights)
    else:
      # Unweighted mean over all tokens.
      crossentropy = jnp.mean(crossentropy)

    return cls(
        # Scale by batch size so compute() recovers the per-sequence mean
        # after merging shards with different batch sizes.
        aggregate_crossentropy=(batch_size * crossentropy),
        num_samples=batch_size,
    )

  def merge(self, other: 'Perplexity') -> 'Perplexity':
    """Accumulates another shard of the metric into this one."""
    return type(self)(
        aggregate_crossentropy=(
            self.aggregate_crossentropy + other.aggregate_crossentropy
        ),
        num_samples=self.num_samples + other.num_samples,
    )

  def compute(self) -> jax.Array:
    """Returns exp of the accumulated mean cross-entropy (the perplexity)."""
    return jnp.exp(self.aggregate_crossentropy / self.num_samples)
0 commit comments