2 changes: 2 additions & 0 deletions .cspell/library_terms.txt
@@ -72,6 +72,8 @@ logpdf
lstsq
mathbb
mathbf
+JMMD
+jmmd
mathcal
mathrm
matplotlib
2 changes: 1 addition & 1 deletion coreax/__init__.py
@@ -40,7 +40,7 @@
SteinKernel,
UniCompositeKernel,
)
-from coreax.metrics import KSD, MMD
+from coreax.metrics import JMMD, KSD, MMD
from coreax.score_matching import KernelDensityMatching, SlicedScoreMatching

__all__ = [
128 changes: 124 additions & 4 deletions coreax/metrics.py
@@ -33,9 +33,9 @@
import jax.tree_util as jtu
from jaxtyping import Array, Shaped

-import coreax.kernels
import coreax.util
-from coreax.data import Data
+from coreax.data import Data, SupervisedData
+from coreax.kernels import ScalarValuedKernel
from coreax.score_matching import ScoreMatching, convert_stein_kernel

_Data = TypeVar("_Data", bound=Data)
@@ -84,7 +84,7 @@ class MMD(Metric[Data]):
:math:`k: \mathbb{R}^d \times \mathbb{R}^d \rightarrow \mathbb{R}`
"""

-    kernel: coreax.kernels.ScalarValuedKernel
+    kernel: ScalarValuedKernel

def compute(
self,
@@ -194,7 +194,7 @@ class KSD(Metric[Data]):
are rounded to zero (accommodates precision loss)
"""

-    kernel: coreax.kernels.ScalarValuedKernel
+    kernel: ScalarValuedKernel
score_matching: ScoreMatching | None = None
precision_threshold: float = 1e-12

@@ -276,3 +276,123 @@ def _laplace_positive(x_: Shaped[Array, " m d"]) -> Shaped[Array, ""]:
self.precision_threshold,
)
return jnp.sqrt(squared_ksd_threshold_applied)


class JMMD(Metric[SupervisedData]):
r"""
Definition and calculation of the (weighted) joint maximum mean discrepancy metric.

    For a dataset :math:`\mathcal{D}_1 = \{(x_i, y_i)\}_{i=1}^n` with
    :math:`x\in\mathbb{R}^d` and :math:`y\in\mathbb{R}^p`, and another dataset
    :math:`\mathcal{D}_2 = \{(x^\prime_i, y^\prime_i)\}_{i=1}^m`
    with :math:`x^\prime\in\mathbb{R}^d` and :math:`y^\prime\in\mathbb{R}^p`,
    the joint maximum mean discrepancy is given by:

.. math::
\text{JMMD}^2(\mathcal{D}_1,\mathcal{D}_2) = \mathbb{E}(r(\mathcal{D}_1,
\mathcal{D}_1)) + \mathbb{E}(r(\mathcal{D}_2,\mathcal{D}_2))
- 2\mathbb{E}(r(\mathcal{D}_1,\mathcal{D}_2))

where :math:`r` is a tensor-product kernel, defined as the product of a feature
kernel and a response kernel, and the expectation is with respect to
the normalised data weights.
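
    Explicitly, writing :math:`k` for the feature kernel and :math:`\ell` for
    the response kernel, the tensor-product kernel acts on pairs as

    .. math::
        r\big((x, y), (x^\prime, y^\prime)\big) = k(x, x^\prime)\,\ell(y, y^\prime)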

.. note::
Assuming that the feature and response kernels are characteristic
(:cite:`muandet2016rkhs`), it can be shown
        that :math:`\text{JMMD}^2(\mathcal{D}_1,\mathcal{D}_2) = 0` if and only if
        :math:`\mathbb{P}^{(1)}_{(X, Y)} = \mathbb{P}^{(2)}_{(X, Y)}`, i.e. the joint
        distributions are the same. The JMMD therefore provides a way to test
        whether two supervised datasets share the same joint distribution.

Common uses of JMMD include comparing a reduced representation of a dataset to the
original dataset, comparing different original datasets to one another, or
comparing reduced representations of different original datasets to one another.

:param feature_kernel: :class:`~coreax.kernels.ScalarValuedKernel` instance
implementing a kernel function
:math:`k: \mathbb{R}^d \times \mathbb{R}^d \rightarrow \mathbb{R}` on the
feature space
:param response_kernel: :class:`~coreax.kernels.ScalarValuedKernel` instance
implementing a kernel function
:math:`k: \mathbb{R}^p \times \mathbb{R}^p \rightarrow \mathbb{R}` on the
response space
:param precision_threshold: Threshold above which negative values of the squared
JMMD are rounded to zero (accommodates precision loss)
"""

feature_kernel: ScalarValuedKernel
response_kernel: ScalarValuedKernel
precision_threshold: float = 1e-12

def compute(
self,
reference_data: SupervisedData,
comparison_data: SupervisedData,
**kwargs,
) -> Array:
r"""
Compute the (weighted) joint maximum mean discrepancy.

.. math::
            \text{JMMD}^2(\mathcal{D}_1,\mathcal{D}_2) = \mathbb{E}(r(\mathcal{D}_1,
            \mathcal{D}_1)) + \mathbb{E}(r(\mathcal{D}_2,\mathcal{D}_2))
            - 2\mathbb{E}(r(\mathcal{D}_1,\mathcal{D}_2))
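
        With normalised weight vectors :math:`w_1` and :math:`w_2`, and writing
        :math:`K` for a feature-kernel Gram matrix, :math:`L` for a
        response-kernel Gram matrix and :math:`\circ` for the elementwise
        product, the computation below evaluates

        .. math::
            w_1^T (K_{11} \circ L_{11}) w_1 + w_2^T (K_{22} \circ L_{22}) w_2
            - 2 w_1^T (K_{12} \circ L_{12}) w_2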

:param reference_data: Supervised dataset :math:`\mathcal{D}_1 =
\{(x_i, y_i)\}_{i=1}^n` with :math:`x \in \mathbb{R}^d` and
:math:`y \in \mathbb{R}^p`
        :param comparison_data: Supervised dataset
            :math:`\mathcal{D}_2 = \{(x^\prime_i, y^\prime_i)\}_{i=1}^m` with
            :math:`x^\prime \in \mathbb{R}^d` and
            :math:`y^\prime \in \mathbb{R}^p`
:return: Joint maximum mean discrepancy as a 0-dimensional array
"""
del kwargs

# Normalise the weights to allow for computation of weighted means
reference_data = reference_data.normalize(preserve_zeros=True)
comparison_data = comparison_data.normalize(preserve_zeros=True)

# Variable rename allows for nicer automatic formatting
x1, y1, w1 = (
reference_data.data,
reference_data.supervision,
reference_data.weights,
)
x2, y2, w2 = (
comparison_data.data,
comparison_data.supervision,
comparison_data.weights,
)

        # E[r(D1, D1)]: weighted mean of the tensor-product Gram matrix on the
        # reference data, evaluated as w1^T (K_11 * L_11) w1
        kernel_1_mean = jnp.dot(
            jnp.dot(
                w1,
                self.feature_kernel.compute(x1, x1)
                * self.response_kernel.compute(y1, y1),
            ),
            w1,
        )
        # E[r(D2, D2)]: weighted mean on the comparison data
        kernel_2_mean = jnp.dot(
            jnp.dot(
                w2,
                self.feature_kernel.compute(x2, x2)
                * self.response_kernel.compute(y2, y2),
            ),
            w2,
        )
        # E[r(D1, D2)]: weighted cross-term between the two datasets
        kernel_12_mean = jnp.dot(
            jnp.dot(
                w1,
                self.feature_kernel.compute(x1, x2)
                * self.response_kernel.compute(y1, y2),
            ),
            w2,
        )

        # Apply the documented precision threshold: small negative values of
        # the squared JMMD arising from precision loss are rounded to zero
        # before taking the square root
        squared_jmmd_threshold_applied = coreax.util.apply_negative_precision_threshold(
            kernel_1_mean + kernel_2_mean - 2 * kernel_12_mean,
            self.precision_threshold,
        )
        return jnp.sqrt(squared_jmmd_threshold_applied)
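
A minimal usage sketch of the new metric, assuming coreax's existing SquaredExponentialKernel and that SupervisedData can be constructed from data and supervision arrays (inferred from the attribute names used in compute; the toy dataset is illustrative):

import jax.numpy as jnp
from jax import random

from coreax.data import SupervisedData
from coreax.kernels import SquaredExponentialKernel
from coreax.metrics import JMMD

# Toy supervised dataset: features in R^2, responses in R^1
key_x, key_noise = random.split(random.PRNGKey(0))
x = random.normal(key_x, (100, 2))
y = x @ jnp.array([[1.0], [-0.5]]) + 0.1 * random.normal(key_noise, (100, 1))

# NOTE: SupervisedData(data=..., supervision=...) and the kernel defaults are
# assumed from the surrounding diff, not verified against this coreax version
full = SupervisedData(data=x, supervision=y)
# A crude "reduced representation": simply the first 20 pairs
subset = SupervisedData(data=x[:20], supervision=y[:20])

metric = JMMD(
    feature_kernel=SquaredExponentialKernel(),
    response_kernel=SquaredExponentialKernel(),
)

# 0-dimensional array; smaller values indicate closer joint distributions
print(metric.compute(full, subset))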