Skip to content

Commit 9d50348

Browse files
authored
docs: add mathematical formulas to metric docstrings (#13)
1 parent b8ca720 commit 9d50348

File tree

1 file changed

+150
-8
lines changed

1 file changed

+150
-8
lines changed

src/metrax/metrics.py

Lines changed: 150 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,24 @@ def _divide_no_nan(x: jax.Array, y: jax.Array) -> jax.Array:
4949

5050
@flax.struct.dataclass
5151
class MSE(clu_metrics.Average):
52-
"""Computes the mean squared error for regression problems given `predictions` and `labels`."""
52+
r"""Computes the mean squared error for regression problems given `predictions` and `labels`.
53+
54+
The mean squared error without sample weights is defined as:
55+
56+
.. math::
57+
MSE = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2
58+
59+
When sample weights :math:`w_i` are provided, the weighted mean squared error is:
60+
61+
.. math::
62+
MSE = \frac{\sum_{i=1}^{N} w_i(y_i - \hat{y}_i)^2}{\sum_{i=1}^{N} w_i}
63+
64+
where:
65+
- :math:`y_i` are true values
66+
- :math:`\hat{y}_i` are predictions
67+
- :math:`w_i` are sample weights
68+
- :math:`N` is the number of samples
69+
"""
5370

5471
@classmethod
5572
def from_model_output(
@@ -87,20 +104,61 @@ def from_model_output(
87104

88105
@flax.struct.dataclass
89106
class RMSE(MSE):
90-
"""Computes the root mean squared error for regression problems given `predictions` and `labels`."""
107+
r"""Computes the root mean squared error for regression problems given `predictions` and `labels`.
108+
109+
The root mean squared error without sample weights is defined as:
110+
111+
.. math::
112+
RMSE = \sqrt{\frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2}
113+
114+
When sample weights :math:`w_i` are provided, the weighted root mean squared error is:
115+
116+
.. math::
117+
RMSE = \sqrt{\frac{\sum_{i=1}^{N} w_i(y_i - \hat{y}_i)^2}{\sum_{i=1}^{N} w_i}}
118+
119+
where:
120+
- :math:`y_i` are true values
121+
- :math:`\hat{y}_i` are predictions
122+
- :math:`w_i` are sample weights
123+
- :math:`N` is the number of samples
124+
"""
91125

92126
def compute(self) -> jax.Array:
93127
return jnp.sqrt(super().compute())
94128

95129

96130
@flax.struct.dataclass
97131
class RSQUARED(clu_metrics.Metric):
98-
"""Computes the r-squared score of a scalar or a batch of tensors.
132+
r"""Computes the r-squared score of a scalar or a batch of tensors.
99133
100134
R-squared is a measure of how well the regression model fits the data. It
101135
measures the proportion of the variance in the dependent variable that is
102136
explained by the independent variable(s). It is defined as 1 - SSE / SST,
103137
where SSE is the sum of squared errors and SST is the total sum of squares.
138+
139+
.. math::
140+
R^2 = 1 - \frac{SSE}{SST}
141+
142+
where:
143+
.. math::
144+
SSE = \sum_{i=1}^{N} (y_i - \hat{y}_i)^2
145+
.. math::
146+
SST = \sum_{i=1}^{N} (y_i - \bar{y})^2
147+
148+
When sample weights :math:`w_i` are provided:
149+
150+
.. math::
151+
R^2 = 1 - \frac{\sum_{i=1}^{N} w_i(y_i - \hat{y}_i)^2}{\sum_{i=1}^{N} w_i(y_i - \bar{y})^2}
152+
153+
where:
154+
- :math:`y_i` are true values
155+
- :math:`\hat{y}_i` are predictions
156+
- :math:`\bar{y}` is the mean of true values
157+
- :math:`w_i` are sample weights
158+
- :math:`N` is the number of samples
159+
160+
The score ranges from -∞ to 1, where 1 indicates perfect prediction and 0 indicates
161+
that the model performs no better than always predicting the mean of the true values.
104162
"""
105163

106164
total: jax.Array
@@ -177,7 +235,19 @@ def compute(self) -> jax.Array:
177235

178236
@flax.struct.dataclass
179237
class Precision(clu_metrics.Metric):
180-
"""Computes precision for binary classification given `predictions` and `labels`.
238+
r"""Computes precision for binary classification given `predictions` and `labels`.
239+
240+
It is calculated as:
241+
242+
.. math::
243+
Precision = \frac{TP}{TP + FP}
244+
245+
where:
246+
- TP (True Positives): Number of correctly predicted positive cases
247+
- FP (False Positives): Number of incorrectly predicted positive cases
248+
249+
A threshold parameter (default 0.5) is used to convert probability predictions
250+
to binary predictions.
181251
182252
Attributes:
183253
true_positives: The count of true positive instances from the given data,
@@ -232,7 +302,19 @@ def compute(self) -> jax.Array:
232302

233303
@flax.struct.dataclass
234304
class Recall(clu_metrics.Metric):
235-
"""Computes recall for binary classification given `predictions` and `labels`.
305+
r"""Computes recall for binary classification given `predictions` and `labels`.
306+
307+
It is calculated as:
308+
309+
.. math::
310+
Recall = \frac{TP}{TP + FN}
311+
312+
where:
313+
- TP (True Positives): Number of correctly predicted positive cases
314+
- FN (False Negatives): Number of incorrectly predicted negative cases
315+
316+
A threshold parameter (default 0.5) is used to convert probability predictions
317+
to binary predictions.
236318
237319
Attributes:
238320
true_positives: The count of true positive instances from the given data,
@@ -284,7 +366,29 @@ def compute(self) -> jax.Array:
284366

285367
@flax.struct.dataclass
286368
class AUCPR(clu_metrics.Metric):
287-
"""Computes area under the precision-recall curve for binary classification given `predictions` and `labels`.
369+
r"""Computes area under the precision-recall curve for binary classification given `predictions` and `labels`.
370+
371+
The Precision-Recall curve shows the tradeoff between precision and recall at different
372+
classification thresholds. The area under this curve (AUC-PR) provides a single score
373+
that represents the model's ability to identify positive cases across
374+
all possible classification thresholds, particularly in imbalanced datasets.
375+
376+
For each threshold :math:`t`, precision and recall are calculated as:
377+
378+
.. math::
379+
Precision(t) = \frac{TP(t)}{TP(t) + FP(t)}
380+
381+
Recall(t) = \frac{TP(t)}{TP(t) + FN(t)}
382+
383+
The AUC-PR is then computed using interpolation:
384+
385+
.. math::
386+
\text{AUC-PR} = \sum_{i=1}^{n-1} (R_{i+1} - R_i) \cdot \frac{P_i + P_{i+1}}{2}
387+
388+
where:
389+
- :math:`P_i` is precision at threshold i
390+
- :math:`R_i` is recall at threshold i
391+
- :math:`n` is the number of thresholds
288392
289393
AUC-PR Curve metric have a number of known issues so use it with caution.
290394
- PR curves are highly class balance sensitive.
@@ -448,7 +552,27 @@ def compute(self) -> jax.Array:
448552

449553
@flax.struct.dataclass
450554
class AUCROC(clu_metrics.Metric):
451-
"""Computes area under the receiver operation characteristic curve for binary classification given `predictions` and `labels`.
555+
r"""Computes area under the receiver operating characteristic curve for binary classification given `predictions` and `labels`.
556+
557+
The ROC curve shows the tradeoff between the true positive rate (TPR) and false positive
558+
rate (FPR) at different classification thresholds. The area under this curve (AUC-ROC)
559+
provides a single score that represents the model's ability to discriminate between
560+
positive and negative cases across all possible classification thresholds,
561+
regardless of class imbalance.
562+
563+
For each threshold :math:`t`, TPR and FPR are calculated as:
564+
565+
.. math::
566+
TPR(t) = \frac{TP(t)}{TP(t) + FN(t)}
567+
568+
FPR(t) = \frac{FP(t)}{FP(t) + TN(t)}
569+
570+
The AUC-ROC is then computed using the trapezoidal rule:
571+
572+
.. math::
573+
\text{AUC-ROC} = \int_{0}^{1} TPR(FPR^{-1}(x)) \, dx
574+
575+
A score of 1 represents perfect classification, while 0.5 represents random guessing.
452576
453577
Attributes:
454578
true_positives: The count of true positive instances from the given data and
@@ -541,7 +665,7 @@ def compute(self) -> jax.Array:
541665

542666
@flax.struct.dataclass
543667
class Perplexity(clu_metrics.Metric):
544-
"""Computes perplexity for sequence generation.
668+
r"""Computes perplexity for sequence generation.
545669
546670
Perplexity is a measurement of how well a probability distribution predicts a
547671
sample. It is defined as the exponentiation of the cross-entropy. A low
@@ -551,6 +675,24 @@ class Perplexity(clu_metrics.Metric):
551675
For language models, it can be interpreted as the weighted average branching
552676
factor of the model - how many equally likely words can be selected at each
553677
step.
678+
679+
Given a sequence of :math:`N` tokens, perplexity is calculated as:
680+
681+
.. math::
682+
Perplexity = \exp\left(-\frac{1}{N}\sum_{i=1}^{N} \log P(x_i|x_{<i})\right)
683+
684+
When sample weights :math:`w_i` are provided:
685+
686+
.. math::
687+
Perplexity = \exp\left(-\frac{\sum_{i=1}^{N} w_i\log P(x_i|x_{<i})}{\sum_{i=1}^{N} w_i}\right)
688+
689+
where:
690+
- :math:`P(x_i|x_{<i})` is the predicted probability of token :math:`x_i`
691+
given previous tokens
692+
- :math:`w_i` are sample weights
693+
- :math:`N` is the sequence length
694+
695+
Lower perplexity indicates better prediction - the model is less "perplexed" by the data.
554696
"""
555697

556698
aggregate_crossentropy: jax.Array

0 commit comments

Comments
 (0)