Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/source/likelihoods.rst
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ reduce the variance when computing approximate GP objective functions.
.. autoclass:: StudentTLikelihood
:members:

:hidden:`OrdinalLikelihood`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: OrdinalLikelihood
:members:


Multi-Dimensional Likelihoods
-----------------------------
Expand Down
2 changes: 2 additions & 0 deletions gpytorch/likelihoods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .multitask_gaussian_likelihood import _MultitaskGaussianLikelihoodBase, MultitaskGaussianLikelihood
from .negative_binomial_likelihood import NegativeBinomialLikelihood
from .noise_models import HeteroskedasticNoise
from .ordinal_likelihood import OrdinalLikelihood
from .poisson_likelihood import PoissonLikelihood
from .softmax_likelihood import SoftmaxLikelihood
from .student_t_likelihood import StudentTLikelihood
Expand All @@ -39,6 +40,7 @@
"LikelihoodList",
"MultitaskGaussianLikelihood",
"NegativeBinomialLikelihood",
"OrdinalLikelihood",
"PoissonLikelihood",
"SoftmaxLikelihood",
"StudentTLikelihood",
Expand Down
108 changes: 108 additions & 0 deletions gpytorch/likelihoods/ordinal_likelihood.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from typing import Any

import torch
from torch import Tensor
from torch.distributions import Categorical

from ..constraints import Interval, Positive
from ..priors import Prior
from .likelihood import _OneDimensionalLikelihood


def inv_probit(x, jitter=1e-3):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def inv_probit(x, jitter=1e-3):
def inv_probit(x: Tensor, jitter: float = 1e-3):

Let's annotate these variables.

"""
Inverse probit function (standard normal CDF) with jitter for numerical stability.

Args:
x: Input tensor
jitter: Small constant to ensure outputs are strictly between 0 and 1

Returns:
Probabilities between jitter and 1-jitter
"""
return 0.5 * (1.0 + torch.erf(x / torch.sqrt(torch.tensor(2.0)))) * (1 - 2 * jitter) + jitter
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If x is a GPU tensor, then this line will trigger a device error as torch.tensor(2.0) is always a CPU tensor.

Suggested change
return 0.5 * (1.0 + torch.erf(x / torch.sqrt(torch.tensor(2.0)))) * (1 - 2 * jitter) + jitter
return 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) * (1 - 2 * jitter) + jitter

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed it to `torch.tensor(2.0, device=x.device)' to keep torch speed + fix device error



class OrdinalLikelihood(_OneDimensionalLikelihood):
r"""
An ordinal likelihood for regressing over ordinal data.

The data are integer values from :math:`0` to :math:`k`, and the user must specify :math:`(k-1)`
'bin edges' which define the points at which the labels switch. Let the bin
edges be :math:`[a_0, a_1, ... a_{k-1}]`, then the likelihood is

.. math::
p(Y=0|F) &= \Phi((a_0 - F) / \sigma)

p(Y=1|F) &= \Phi((a_1 - F) / \sigma) - \Phi((a_0 - F) / \sigma)

p(Y=2|F) &= \Phi((a_2 - F) / \sigma) - \Phi((a_1 - F) / \sigma)

...

p(Y=K|F) &= 1 - \Phi((a_{k-1} - F) / \sigma)

where :math:`\Phi` is the cumulative density function of a Gaussian (the inverse probit
function) and :math:`\sigma` is a parameter to be learned.

From Chu et Ghahramani, Journal of Machine Learning Research, 2005
[https://www.jmlr.org/papers/volume6/chu05a/chu05a.pdf].

:param bin_edges: A tensor of shape :math:`(k-1)` containing the bin edges.
:param batch_shape: The batch shape of the learned sigma parameter (default: []).
:param sigma_prior: Prior for sigma parameter :math:`\sigma`.
:param sigma_constraint: Constraint for sigma parameter :math:`\sigma`.

:ivar torch.Tensor bin_edges: :math:`\{a_i\}_{i=0}^{k-1}` bin edges
:ivar torch.Tensor sigma: :math:`\sigma` parameter (scale)
"""

def __init__(
self,
bin_edges: Tensor,
batch_shape: torch.Size = torch.Size([]),
sigma_prior: Prior | None = None,
sigma_constraint: Interval | None = None,
) -> None:
super().__init__()

self.num_bins = len(bin_edges) + 1
self.register_parameter("bin_edges", torch.nn.Parameter(bin_edges, requires_grad=False))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.register_parameter("bin_edges", torch.nn.Parameter(bin_edges, requires_grad=False))
self.register_buffer("bin_edges", bin_edges)

nit: I think it makes more sense to register this as a buffer instead since we won't update the bin edges?

On the flip side, does it make sense to set requires_grad=True so that we learn the bin edges during model fitting? (Some packages choose to do so; see here.) IIUC, we only learn sigma here but the bin edges are fixed. I am wondering if this could limit the expressiveness of the likelihood.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed the code to allow for learnable edges but default to fixed


if sigma_constraint is None:
sigma_constraint = Positive()

self.raw_sigma = torch.nn.Parameter(torch.ones(*batch_shape, 1))
if sigma_prior is not None:
self.register_prior("sigma_prior", sigma_prior, lambda m: m.sigma, lambda m, v: m._set_sigma(v))

self.register_constraint("raw_sigma", sigma_constraint)

@property
def sigma(self) -> Tensor:
return self.raw_sigma_constraint.transform(self.raw_sigma)

@sigma.setter
def sigma(self, value: Tensor) -> None:
self._set_sigma(value)

def _set_sigma(self, value: Tensor) -> None:
if not torch.is_tensor(value):
value = torch.as_tensor(value).to(self.raw_sigma)
self.initialize(raw_sigma=self.raw_sigma_constraint.inverse_transform(value))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: We've already annotated value as tensor. So we could drop the if-statement here? Also, maybe we could merge this method with the sigma setter method above?


def forward(self, function_samples: Tensor, *args: Any, data: dict[str, Tensor] = {}, **kwargs: Any) -> Categorical:
# Compute scaled bin edges
scaled_edges = self.bin_edges / self.sigma
scaled_edges_left = torch.cat([scaled_edges, torch.tensor([torch.inf], device=scaled_edges.device)], dim=-1)
scaled_edges_right = torch.cat([torch.tensor([-torch.inf], device=scaled_edges.device), scaled_edges])

# Calculate cumulative probabilities using standard normal CDF (probit function)
function_samples = function_samples.unsqueeze(-1)
scaled_edges_left = scaled_edges_left.reshape(1, -1)
scaled_edges_right = scaled_edges_right.reshape(1, -1)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will these two lines work in batch settings where the batch shape is non-empty?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed + added test to confirm

probs = inv_probit(scaled_edges_left - function_samples / self.sigma) - inv_probit(
scaled_edges_right - function_samples / self.sigma
)

return Categorical(probs=probs)
30 changes: 30 additions & 0 deletions test/likelihoods/test_ordinal_likelihood.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#!/usr/bin/env python3

import unittest

import torch
from torch.distributions import Distribution

from gpytorch.likelihoods import OrdinalLikelihood
from gpytorch.test.base_likelihood_test_case import BaseLikelihoodTestCase


class TestOrdinalLikelihood(BaseLikelihoodTestCase, unittest.TestCase):
seed = 0

def create_likelihood(self):
bin_edges = torch.tensor([-0.5, 0.5])
return OrdinalLikelihood(bin_edges)

def _create_targets(self, batch_shape=torch.Size([])):
return torch.distributions.Categorical(probs=torch.tensor([1 / 3, 1 / 3, 1 / 3])).sample(
torch.Size([*batch_shape, 5])
)

def _test_marginal(self, batch_shape):
likelihood = self.create_likelihood()
input = self._create_marginal_input(batch_shape)
output = likelihood(input)

self.assertTrue(isinstance(output, Distribution))
self.assertEqual(output.sample().shape[-len(batch_shape) - 1 :], torch.Size([*batch_shape, 5]))