Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 150 additions & 8 deletions src/metrax/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,24 @@ def _divide_no_nan(x: jax.Array, y: jax.Array) -> jax.Array:

@flax.struct.dataclass
class MSE(clu_metrics.Average):
"""Computes the mean squared error for regression problems given `predictions` and `labels`."""
r"""Computes the mean squared error for regression problems given `predictions` and `labels`.

The mean squared error without sample weights is defined as:

.. math::
MSE = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2

When sample weights :math:`w_i` are provided, the weighted mean squared error is:

.. math::
MSE = \frac{\sum_{i=1}^{N} w_i(y_i - \hat{y}_i)^2}{\sum_{i=1}^{N} w_i}

where:
- :math:`y_i` are true values
- :math:`\hat{y}_i` are predictions
- :math:`w_i` are sample weights
- :math:`N` is the number of samples
"""

@classmethod
def from_model_output(
Expand Down Expand Up @@ -87,20 +104,61 @@ def from_model_output(

@flax.struct.dataclass
class RMSE(MSE):
"""Computes the root mean squared error for regression problems given `predictions` and `labels`."""
r"""Computes the root mean squared error for regression problems given `predictions` and `labels`.

The root mean squared error without sample weights is defined as:

.. math::
RMSE = \sqrt{\frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2}

When sample weights :math:`w_i` are provided, the weighted root mean squared error is:

.. math::
RMSE = \sqrt{\frac{\sum_{i=1}^{N} w_i(y_i - \hat{y}_i)^2}{\sum_{i=1}^{N} w_i}}

where:
- :math:`y_i` are true values
- :math:`\hat{y}_i` are predictions
- :math:`w_i` are sample weights
- :math:`N` is the number of samples
"""

def compute(self) -> jax.Array:
return jnp.sqrt(super().compute())


@flax.struct.dataclass
class RSQUARED(clu_metrics.Metric):
"""Computes the r-squared score of a scalar or a batch of tensors.
r"""Computes the r-squared score of a scalar or a batch of tensors.

R-squared is a measure of how well the regression model fits the data. It
measures the proportion of the variance in the dependent variable that is
explained by the independent variable(s). It is defined as 1 - SSE / SST,
where SSE is the sum of squared errors and SST is the total sum of squares.

.. math::
R^2 = 1 - \frac{SSE}{SST}

where:
.. math::
SSE = \sum_{i=1}^{N} (y_i - \hat{y}_i)^2
.. math::
SST = \sum_{i=1}^{N} (y_i - \bar{y})^2

When sample weights :math:`w_i` are provided:

.. math::
R^2 = 1 - \frac{\sum_{i=1}^{N} w_i(y_i - \hat{y}_i)^2}{\sum_{i=1}^{N} w_i(y_i - \bar{y})^2}

where:
- :math:`y_i` are true values
- :math:`\hat{y}_i` are predictions
- :math:`\bar{y}` is the mean of true values
- :math:`w_i` are sample weights
- :math:`N` is the number of samples

The score ranges from -∞ to 1, where 1 indicates perfect prediction and 0 indicates
that the model performs no better than a horizontal line.
"""

total: jax.Array
Expand Down Expand Up @@ -177,7 +235,19 @@ def compute(self) -> jax.Array:

@flax.struct.dataclass
class Precision(clu_metrics.Metric):
"""Computes precision for binary classification given `predictions` and `labels`.
r"""Computes precision for binary classification given `predictions` and `labels`.

It is calculated as:

.. math::
Precision = \frac{TP}{TP + FP}

where:
- TP (True Positives): Number of correctly predicted positive cases
- FP (False Positives): Number of incorrectly predicted positive cases

A threshold parameter (default 0.5) is used to convert probability predictions
to binary predictions.

Attributes:
true_positives: The count of true positive instances from the given data,
Expand Down Expand Up @@ -232,7 +302,19 @@ def compute(self) -> jax.Array:

@flax.struct.dataclass
class Recall(clu_metrics.Metric):
"""Computes recall for binary classification given `predictions` and `labels`.
r"""Computes recall for binary classification given `predictions` and `labels`.

It is calculated as:

.. math::
Recall = \frac{TP}{TP + FN}

where:
- TP (True Positives): Number of correctly predicted positive cases
- FN (False Negatives): Number of incorrectly predicted negative cases

A threshold parameter (default 0.5) is used to convert probability predictions
to binary predictions.

Attributes:
true_positives: The count of true positive instances from the given data,
Expand Down Expand Up @@ -284,7 +366,29 @@ def compute(self) -> jax.Array:

@flax.struct.dataclass
class AUCPR(clu_metrics.Metric):
"""Computes area under the precision-recall curve for binary classification given `predictions` and `labels`.
r"""Computes area under the precision-recall curve for binary classification given `predictions` and `labels`.

The Precision-Recall curve shows the tradeoff between precision and recall at different
classification thresholds. The area under this curve (AUC-PR) provides a single score
that represents the model's ability to identify positive cases across
all possible classification thresholds, particularly in imbalanced datasets.

For each threshold :math:`t`, precision and recall are calculated as:

.. math::
Precision(t) = \frac{TP(t)}{TP(t) + FP(t)}

Recall(t) = \frac{TP(t)}{TP(t) + FN(t)}

The AUC-PR is then computed using interpolation:

.. math::
AUC-PR = \sum_{i=1}^{n-1} (R_{i+1} - R_i) \cdot \frac{P_i + P_{i+1}}{2}

where:
- :math:`P_i` is precision at threshold i
- :math:`R_i` is recall at threshold i
- :math:`n` is the number of thresholds

AUC-PR Curve metric have a number of known issues so use it with caution.
- PR curves are highly class balance sensitive.
Expand Down Expand Up @@ -448,7 +552,27 @@ def compute(self) -> jax.Array:

@flax.struct.dataclass
class AUCROC(clu_metrics.Metric):
"""Computes area under the receiver operation characteristic curve for binary classification given `predictions` and `labels`.
r"""Computes area under the receiver operation characteristic curve for binary classification given `predictions` and `labels`.

The ROC curve shows the tradeoff between the true positive rate (TPR) and false positive
rate (FPR) at different classification thresholds. The area under this curve (AUC-ROC)
provides a single score that represents the model's ability to discriminate between
positive and negative cases across all possible classification thresholds,
regardless of class imbalance.

For each threshold :math:`t`, TPR and FPR are calculated as:

.. math::
TPR(t) = \frac{TP(t)}{TP(t) + FN(t)}

FPR(t) = \frac{FP(t)}{FP(t) + TN(t)}

The AUC-ROC is then computed using the trapezoidal rule:

.. math::
AUC-ROC = \int_{0}^{1} TPR(FPR^{-1}(x)) dx

A score of 1 represents perfect classification, while 0.5 represents random guessing.

Attributes:
true_positives: The count of true positive instances from the given data and
Expand Down Expand Up @@ -541,7 +665,7 @@ def compute(self) -> jax.Array:

@flax.struct.dataclass
class Perplexity(clu_metrics.Metric):
"""Computes perplexity for sequence generation.
r"""Computes perplexity for sequence generation.

Perplexity is a measurement of how well a probability distribution predicts a
sample. It is defined as the exponentiation of the cross-entropy. A low
Expand All @@ -551,6 +675,24 @@ class Perplexity(clu_metrics.Metric):
For language models, it can be interpreted as the weighted average branching
factor of the model - how many equally likely words can be selected at each
step.

Given a sequence of :math:`N` tokens, perplexity is calculated as:

.. math::
Perplexity = \exp\left(-\frac{1}{N}\sum_{i=1}^{N} \log P(x_i|x_{<i})\right)

When sample weights :math:`w_i` are provided:

.. math::
Perplexity = \exp\left(-\frac{\sum_{i=1}^{N} w_i\log P(x_i|x_{<i})}{\sum_{i=1}^{N} w_i}\right)

where:
- :math:`P(x_i|x_{<i})` is the predicted probability of token :math:`x_i`
given previous tokens
- :math:`w_i` are sample weights
- :math:`N` is the sequence length

Lower perplexity indicates better prediction - the model is less "perplexed" by the data.
"""

aggregate_crossentropy: jax.Array
Expand Down