@@ -49,7 +49,24 @@ def _divide_no_nan(x: jax.Array, y: jax.Array) -> jax.Array:
4949
5050@flax .struct .dataclass
5151class MSE (clu_metrics .Average ):
52- """Computes the mean squared error for regression problems given `predictions` and `labels`."""
52+ r"""Computes the mean squared error for regression problems given `predictions` and `labels`.
53+
54+ The mean squared error without sample weights is defined as:
55+
56+ .. math::
57+ MSE = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2
58+
59+ When sample weights :math:`w_i` are provided, the weighted mean squared error is:
60+
61+ .. math::
62+ MSE = \frac{\sum_{i=1}^{N} w_i(y_i - \hat{y}_i)^2}{\sum_{i=1}^{N} w_i}
63+
64+ where:
65+ - :math:`y_i` are true values
66+ - :math:`\hat{y}_i` are predictions
67+ - :math:`w_i` are sample weights
68+ - :math:`N` is the number of samples
69+ """
5370
5471 @classmethod
5572 def from_model_output (
@@ -87,20 +104,61 @@ def from_model_output(
87104
88105@flax .struct .dataclass
89106class RMSE (MSE ):
90- """Computes the root mean squared error for regression problems given `predictions` and `labels`."""
107+ r"""Computes the root mean squared error for regression problems given `predictions` and `labels`.
108+
109+ The root mean squared error without sample weights is defined as:
110+
111+ .. math::
112+ RMSE = \sqrt{\frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2}
113+
114+ When sample weights :math:`w_i` are provided, the weighted root mean squared error is:
115+
116+ .. math::
117+ RMSE = \sqrt{\frac{\sum_{i=1}^{N} w_i(y_i - \hat{y}_i)^2}{\sum_{i=1}^{N} w_i}}
118+
119+ where:
120+ - :math:`y_i` are true values
121+ - :math:`\hat{y}_i` are predictions
122+ - :math:`w_i` are sample weights
123+ - :math:`N` is the number of samples
124+ """
91125
92126 def compute (self ) -> jax .Array :
93127 return jnp .sqrt (super ().compute ())
94128
95129
96130@flax .struct .dataclass
97131class RSQUARED (clu_metrics .Metric ):
98- """Computes the r-squared score of a scalar or a batch of tensors.
132+ r """Computes the r-squared score of a scalar or a batch of tensors.
99133
100134 R-squared is a measure of how well the regression model fits the data. It
101135 measures the proportion of the variance in the dependent variable that is
102136 explained by the independent variable(s). It is defined as 1 - SSE / SST,
103137 where SSE is the sum of squared errors and SST is the total sum of squares.
138+
139+ .. math::
140+ R^2 = 1 - \frac{SSE}{SST}
141+
142+ where:
143+ .. math::
144+ SSE = \sum_{i=1}^{N} (y_i - \hat{y}_i)^2
145+ .. math::
146+ SST = \sum_{i=1}^{N} (y_i - \bar{y})^2
147+
148+ When sample weights :math:`w_i` are provided:
149+
150+ .. math::
151+ R^2 = 1 - \frac{\sum_{i=1}^{N} w_i(y_i - \hat{y}_i)^2}{\sum_{i=1}^{N} w_i(y_i - \bar{y})^2}
152+
153+ where:
154+ - :math:`y_i` are true values
155+ - :math:`\hat{y}_i` are predictions
156+ - :math:`\bar{y}` is the mean of true values
157+ - :math:`w_i` are sample weights
158+ - :math:`N` is the number of samples
159+
160+ The score ranges from -∞ to 1, where 1 indicates perfect prediction and 0 indicates
161+ that the model performs no better than always predicting the mean of the true values.
104162 """
105163
106164 total : jax .Array
@@ -177,7 +235,19 @@ def compute(self) -> jax.Array:
177235
178236@flax .struct .dataclass
179237class Precision (clu_metrics .Metric ):
180- """Computes precision for binary classification given `predictions` and `labels`.
238+ r"""Computes precision for binary classification given `predictions` and `labels`.
239+
240+ It is calculated as:
241+
242+ .. math::
243+ Precision = \frac{TP}{TP + FP}
244+
245+ where:
246+ - TP (True Positives): Number of correctly predicted positive cases
247+ - FP (False Positives): Number of incorrectly predicted positive cases
248+
249+ A threshold parameter (default 0.5) is used to convert probability predictions
250+ to binary predictions.
181251
182252 Attributes:
183253 true_positives: The count of true positive instances from the given data,
@@ -232,7 +302,19 @@ def compute(self) -> jax.Array:
232302
233303@flax .struct .dataclass
234304class Recall (clu_metrics .Metric ):
235- """Computes recall for binary classification given `predictions` and `labels`.
305+ r"""Computes recall for binary classification given `predictions` and `labels`.
306+
307+ It is calculated as:
308+
309+ .. math::
310+ Recall = \frac{TP}{TP + FN}
311+
312+ where:
313+ - TP (True Positives): Number of correctly predicted positive cases
314+ - FN (False Negatives): Number of incorrectly predicted negative cases
315+
316+ A threshold parameter (default 0.5) is used to convert probability predictions
317+ to binary predictions.
236318
237319 Attributes:
238320 true_positives: The count of true positive instances from the given data,
@@ -284,7 +366,29 @@ def compute(self) -> jax.Array:
284366
285367@flax .struct .dataclass
286368class AUCPR (clu_metrics .Metric ):
287- """Computes area under the precision-recall curve for binary classification given `predictions` and `labels`.
369+ r"""Computes area under the precision-recall curve for binary classification given `predictions` and `labels`.
370+
371+ The Precision-Recall curve shows the tradeoff between precision and recall at different
372+ classification thresholds. The area under this curve (AUC-PR) provides a single score
373+ that represents the model's ability to identify positive cases across
374+ all possible classification thresholds, particularly in imbalanced datasets.
375+
376+ For each threshold :math:`t`, precision and recall are calculated as:
377+
378+ .. math::
379+ Precision(t) = \frac{TP(t)}{TP(t) + FP(t)}
380+
381+ Recall(t) = \frac{TP(t)}{TP(t) + FN(t)}
382+
383+ The AUC-PR is then computed using interpolation:
384+
385+ .. math::
386+ \text{AUC-PR} = \sum_{i=1}^{n-1} (R_{i+1} - R_i) \cdot \frac{P_i + P_{i+1}}{2}
387+
388+ where:
389+ - :math:`P_i` is precision at threshold i
390+ - :math:`R_i` is recall at threshold i
391+ - :math:`n` is the number of thresholds
288392
289393 AUC-PR Curve metric have a number of known issues so use it with caution.
290394 - PR curves are highly class balance sensitive.
@@ -448,7 +552,27 @@ def compute(self) -> jax.Array:
448552
449553@flax .struct .dataclass
450554class AUCROC (clu_metrics .Metric ):
451- """Computes area under the receiver operation characteristic curve for binary classification given `predictions` and `labels`.
555+ r"""Computes area under the receiver operation characteristic curve for binary classification given `predictions` and `labels`.
556+
557+ The ROC curve shows the tradeoff between the true positive rate (TPR) and false positive
558+ rate (FPR) at different classification thresholds. The area under this curve (AUC-ROC)
559+ provides a single score that represents the model's ability to discriminate between
560+ positive and negative cases across all possible classification thresholds,
561+ regardless of class imbalance.
562+
563+ For each threshold :math:`t`, TPR and FPR are calculated as:
564+
565+ .. math::
566+ TPR(t) = \frac{TP(t)}{TP(t) + FN(t)}
567+
568+ FPR(t) = \frac{FP(t)}{FP(t) + TN(t)}
569+
570+ The AUC-ROC is the area under this curve (approximated numerically, e.g. with the trapezoidal rule):
571+
572+ .. math::
573+ \text{AUC-ROC} = \int_{0}^{1} TPR(FPR^{-1}(x)) \, dx
574+
575+ A score of 1 represents perfect classification, while 0.5 represents random guessing.
452576
453577 Attributes:
454578 true_positives: The count of true positive instances from the given data and
@@ -541,7 +665,7 @@ def compute(self) -> jax.Array:
541665
542666@flax .struct .dataclass
543667class Perplexity (clu_metrics .Metric ):
544- """Computes perplexity for sequence generation.
668+ r """Computes perplexity for sequence generation.
545669
546670 Perplexity is a measurement of how well a probability distribution predicts a
547671 sample. It is defined as the exponentiation of the cross-entropy. A low
@@ -551,6 +675,24 @@ class Perplexity(clu_metrics.Metric):
551675 For language models, it can be interpreted as the weighted average branching
552676 factor of the model - how many equally likely words can be selected at each
553677 step.
678+
679+ Given a sequence of :math:`N` tokens, perplexity is calculated as:
680+
681+ .. math::
682+ Perplexity = \exp\left(-\frac{1}{N}\sum_{i=1}^{N} \log P(x_i|x_{<i})\right)
683+
684+ When sample weights :math:`w_i` are provided:
685+
686+ .. math::
687+ Perplexity = \exp\left(-\frac{\sum_{i=1}^{N} w_i\log P(x_i|x_{<i})}{\sum_{i=1}^{N} w_i}\right)
688+
689+ where:
690+ - :math:`P(x_i|x_{<i})` is the predicted probability of token :math:`x_i`
691+ given previous tokens
692+ - :math:`w_i` are sample weights
693+ - :math:`N` is the sequence length
694+
695+ Lower perplexity indicates better prediction - the model is less "perplexed" by the data.
554696 """
555697
556698 aggregate_crossentropy : jax .Array
0 commit comments