Commit ae9cfa3

Merge pull request #405 from JaxGaussianProcesses/bump_cola

bump cola to v0.0.5

2 parents: 70838a8 + 392b3da

File tree: 15 files changed (+1802, -1811 lines)

docs/examples/collapsed_vi.py
Lines changed: 2 additions & 2 deletions

@@ -244,11 +244,11 @@
 full_rank_model = gpx.Prior(mean_function=gpx.Zero(), kernel=gpx.RBF()) * gpx.Gaussian(
     num_datapoints=D.n
 )
-negative_mll = jit(gpx.ConjugateMLL(negative=True))
+negative_mll = jit(gpx.ConjugateMLL(negative=True).step)
 # %timeit negative_mll(full_rank_model, D).block_until_ready()
 
 # %%
-negative_elbo = jit(gpx.CollapsedELBO(negative=True))
+negative_elbo = jit(gpx.CollapsedELBO(negative=True).step)
 # %timeit negative_elbo(q, D).block_until_ready()
 
 # %% [markdown]
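
Note: both objectives are now jitted through their `.step` method rather than through the objective object itself. A minimal sketch of the pattern, using a stand-in class rather than GPJax's actual objective:

```python
import jax
import jax.numpy as jnp


# Stand-in for an objective such as gpx.ConjugateMLL; not GPJax's real class.
class SquaredError:
    def step(self, params, targets):
        return jnp.sum((params - targets) ** 2)


# Jitting the bound method hands jax.jit a plain function of array arguments,
# mirroring jit(gpx.ConjugateMLL(negative=True).step) in the diff above.
loss = jax.jit(SquaredError().step)
print(loss(jnp.ones(3), jnp.zeros(3)))  # 3.0
```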

docs/examples/graph_kernels.py
Lines changed: 5 additions & 3 deletions

@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
 # %% [markdown]
 # # Graph Kernels
 #
@@ -119,7 +120,8 @@
     cmap=plt.cm.inferno, norm=plt.Normalize(vmin=vmin, vmax=vmax)
 )
 sm.set_array([])
-cbar = plt.colorbar(sm)
+ax = plt.gca()
+cbar = plt.colorbar(sm, ax=ax)
 
 # %% [markdown]
 #
@@ -201,8 +203,8 @@
 sm = plt.cm.ScalarMappable(
     cmap=plt.cm.inferno, norm=plt.Normalize(vmin=vmin, vmax=vmax)
 )
-sm.set_array([])
-cbar = plt.colorbar(sm)
+ax = plt.gca()
+cbar = plt.colorbar(sm, ax=ax)
 
 # %% [markdown]
 #
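
Note: the colorbar change tracks recent Matplotlib behaviour (3.6+), where `plt.colorbar` no longer infers the target axes from a standalone `ScalarMappable`; the axes must be passed explicitly. A minimal sketch of the fixed pattern, with made-up `vmin`/`vmax` values:

```python
import matplotlib.pyplot as plt

# A standalone mappable is not attached to any axes, so Matplotlib cannot
# guess where to steal space for the colorbar; pass the axes explicitly.
fig, ax = plt.subplots()
sm = plt.cm.ScalarMappable(cmap=plt.cm.inferno, norm=plt.Normalize(vmin=0.0, vmax=1.0))
sm.set_array([])
cbar = plt.colorbar(sm, ax=ax)  # explicit `ax`, as in the diff above
```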

gpjax/citation.py
Lines changed: 15 additions & 20 deletions

@@ -31,15 +31,15 @@
     NonConjugateMLL,
 )
 
-CitationType = Union[str, Dict[str, str]]
+CitationType = Union[None, str, Dict[str, str]]
 
 
 @dataclass(repr=False)
 class AbstractCitation:
-    citation_key: str = None
-    authors: str = None
-    title: str = None
-    year: str = None
+    citation_key: Union[str, None] = None
+    authors: Union[str, None] = None
+    title: Union[str, None] = None
+    year: Union[str, None] = None
 
     def as_str(self) -> str:
         citation_str = f"@{self.citation_type}{{{self.citation_key},"
@@ -64,29 +64,24 @@ def __str__(self) -> str:
         )
 
 
-class JittedFnCitation(AbstractCitation):
-    def __str__(self) -> str:
-        return "Citation not available for jitted objects."
-
-
 @dataclass
 class PhDThesisCitation(AbstractCitation):
-    school: str = None
-    institution: str = None
-    citation_type: str = "phdthesis"
+    school: Union[str, None] = None
+    institution: Union[str, None] = None
+    citation_type: CitationType = "phdthesis"
 
 
 @dataclass
 class PaperCitation(AbstractCitation):
-    booktitle: str = None
-    citation_type: str = "inproceedings"
+    booktitle: Union[str, None] = None
+    citation_type: CitationType = "inproceedings"
 
 
 @dataclass
 class BookCitation(AbstractCitation):
-    publisher: str = None
-    volume: str = None
-    citation_type: str = "book"
+    publisher: Union[str, None] = None
+    volume: Union[str, None] = None
+    citation_type: CitationType = "book"
 
 
 ####################
@@ -101,8 +96,8 @@ def cite(tree) -> AbstractCitation:
 # Default citation
 ####################
 @cite.register(PjitFunction)
-def _(tree):
-    return JittedFnCitation()
+def _(tree) -> None:
+    raise RuntimeError("Citation not available for jitted objects.")
 
 
 ####################
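
Note: with `JittedFnCitation` removed, citing a jitted object now raises instead of returning a placeholder. A hedged usage sketch, assuming `cite` is imported from `gpjax.citation` and that `jax.jit` returns a `PjitFunction` on the installed JAX version:

```python
import jax

from gpjax.citation import cite

jitted_fn = jax.jit(lambda x: x)  # assumed to dispatch as a PjitFunction

try:
    cite(jitted_fn)
except RuntimeError as err:
    print(err)  # Citation not available for jitted objects.
```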

gpjax/distributions.py
Lines changed: 17 additions & 14 deletions

@@ -26,6 +26,7 @@
 from cola.ops import (
     Dense,
     Identity,
+    LinearOperator,
 )
 from jax import vmap
 import jax.numpy as jnp
@@ -45,6 +46,8 @@
 
 tfd = tfp.distributions
 
+from cola.linalg.decompositions.decompositions import Cholesky
+
 
 def _check_loc_scale(loc: Optional[Any], scale: Optional[Any]) -> None:
     r"""Checks that the inputs are correct."""
@@ -60,9 +63,9 @@ def _check_loc_scale(loc: Optional[Any], scale: Optional[Any]) -> None:
             f"`scale.shape = {scale.shape}`."
         )
 
-    if scale is not None and not isinstance(scale, cola.LinearOperator):
+    if scale is not None and not isinstance(scale, LinearOperator):
         raise ValueError(
-            f"The `scale` must be a cola.LinearOperator but got {type(scale)}"
+            f"The `scale` must be a CoLA LinearOperator but got {type(scale)}"
         )
 
     if scale is not None and (scale.shape[-1] != scale.shape[-2]):
@@ -84,7 +87,7 @@ class GaussianDistribution(tfd.Distribution):
 
     Args:
         loc (Optional[Float[Array, " N"]]): The mean of the distribution. Defaults to None.
-        scale (Optional[cola.LinearOperator]): The scale matrix of the distribution. Defaults to None.
+        scale (Optional[LinearOperator]): The scale matrix of the distribution. Defaults to None.
 
     Returns
     -------
@@ -99,7 +102,7 @@ class GaussianDistribution(tfd.Distribution):
     def __init__(
         self,
        loc: Optional[Float[Array, " N"]] = None,
-        scale: Optional[cola.LinearOperator] = None,
+        scale: Optional[LinearOperator] = None,
     ) -> None:
         r"""Initialises the distribution."""
         _check_loc_scale(loc, scale)
@@ -155,9 +158,7 @@ def entropy(self) -> ScalarFloat:
         r"""Calculates the entropy of the distribution."""
         return 0.5 * (
             self.event_shape[0] * (1.0 + jnp.log(2.0 * jnp.pi))
-            + cola.logdet(
-                self.scale, method="dense"
-            )  # <--- Seems to be an issue with CoLA!
+            + cola.logdet(self.scale, Cholesky(), Cholesky())
         )
 
     def log_prob(
@@ -191,8 +192,8 @@ def log_prob(
         # compute the pdf, -1/2[ n log(2π) + log|Σ| + (y - µ)ᵀΣ⁻¹(y - µ) ]
         return -0.5 * (
             n * jnp.log(2.0 * jnp.pi)
-            + cola.logdet(sigma, method="dense")  # <--- Seems to be an issue with CoLA!
-            + diff.T @ cola.solve(sigma, diff)
+            + cola.logdet(sigma, Cholesky(), Cholesky())
+            + diff.T @ cola.solve(sigma, diff, Cholesky())
         )
 
     def _sample_n(self, key: KeyArray, n: int) -> Float[Array, "n N"]:
@@ -347,17 +348,19 @@ def _kl_divergence(q: GaussianDistribution, p: GaussianDistribution) -> ScalarFl
 
     # trace term, tr[Σp⁻¹ Σq] = tr[(LpLpᵀ)⁻¹(LqLqᵀ)] = tr[(Lp⁻¹Lq)(Lp⁻¹Lq)ᵀ] = (fr[LqLp⁻¹])²
     trace = _frobenius_norm_squared(
-        cola.solve(sqrt_p, sqrt_q.to_dense())
+        cola.solve(sqrt_p, sqrt_q.to_dense(), Cholesky())
     )  # TODO: Not most efficient, given the `to_dense()` call (e.g., consider diagonal p and q). Need to abstract solving linear operator against another linear operator.
 
     # Mahalanobis term, (μp - μq)ᵀ Σp⁻¹ (μp - μq) = tr [(μp - μq)ᵀ [LpLpᵀ]⁻¹ (μp - μq)] = (fr[Lp⁻¹(μp - μq)])²
-    mahalanobis = jnp.sum(
-        jnp.square(cola.solve(sqrt_p, diff))
-    )  # TODO: Need to improve this. Perhaps add a Mahalanobis method to ``LinearOperator``s.
+    mahalanobis = jnp.sum(jnp.square(cola.solve(sqrt_p, diff, Cholesky())))
 
     # KL[q(x)||p(x)] = [ [(μp - μq)ᵀ Σp⁻¹ (μp - μq)] - n - log|Σq| + log|Σp| + tr[Σp⁻¹ Σq] ] / 2
     return (
-        mahalanobis - n_dim - cola.logdet(sigma_q) + cola.logdet(sigma_p) + trace
+        mahalanobis
+        - n_dim
+        - cola.logdet(sigma_q, Cholesky(), Cholesky())
+        + cola.logdet(sigma_p, Cholesky(), Cholesky())
+        + trace
     ) / 2.0
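
Note: every `logdet` and `solve` call now passes explicit CoLA algorithm objects in place of the old `method="dense"` string, which also retires the two "Seems to be an issue with CoLA!" workarounds. A small sketch of the v0.0.5-style calls, assuming `cola.PSD` and `cola.ops.Dense` behave as in CoLA v0.0.5 (the matrix here is made up):

```python
import cola
from cola.linalg.decompositions.decompositions import Cholesky
import jax.numpy as jnp

# A tiny symmetric positive-definite matrix, wrapped as a PSD-annotated
# CoLA operator so Cholesky-based algorithms are applicable.
A = cola.PSD(cola.ops.Dense(jnp.array([[2.0, 0.5], [0.5, 1.0]])))
b = jnp.array([1.0, -1.0])

log_det = cola.logdet(A, Cholesky(), Cholesky())  # algorithm objects, not method=
x = cola.solve(A, b, Cholesky())                  # A⁻¹ b via Cholesky
```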

gpjax/gps.py
Lines changed: 26 additions & 23 deletions

@@ -13,7 +13,6 @@
 # limitations under the License.
 # ==============================================================================
 
-# from __future__ import annotations
 from abc import abstractmethod
 from dataclasses import (
     dataclass,
@@ -25,8 +24,10 @@
     Any,
     Callable,
     Optional,
+    Union,
 )
 import cola
+from cola.linalg.decompositions.decompositions import Cholesky
 from cola.ops import Dense
 import jax.numpy as jnp
 from jax.random import (
@@ -152,17 +153,17 @@ class Prior(AbstractPrior):
     ```
     """
 
-    @overload
-    def __mul__(self, other: Gaussian) -> "ConjugatePosterior":
-        ...
+    # @overload
+    # def __mul__(self, other: Gaussian) -> "ConjugatePosterior":
+    #     ...
 
-    @overload
-    def __mul__(self, other: NonGaussianLikelihood) -> "NonConjugatePosterior":
-        ...
+    # @overload
+    # def __mul__(self, other: NonGaussianLikelihood) -> "NonConjugatePosterior":
+    #     ...
 
-    @overload
-    def __mul__(self, other: AbstractLikelihood) -> "AbstractPosterior":
-        ...
+    # @overload
+    # def __mul__(self, other: AbstractLikelihood) -> "AbstractPosterior":
+    #     ...
 
     def __mul__(self, other):
         r"""Combine the prior with a likelihood to form a posterior distribution.
@@ -198,17 +199,17 @@ def __mul__(self, other):
         """
         return construct_posterior(prior=self, likelihood=other)
 
-    @overload
-    def __rmul__(self, other: Gaussian) -> "ConjugatePosterior":
-        ...
+    # @overload
+    # def __rmul__(self, other: Gaussian) -> "ConjugatePosterior":
+    #     ...
 
-    @overload
-    def __rmul__(self, other: NonGaussianLikelihood) -> "NonConjugatePosterior":
-        ...
+    # @overload
+    # def __rmul__(self, other: NonGaussianLikelihood) -> "NonConjugatePosterior":
+    #     ...
 
-    @overload
-    def __rmul__(self, other: AbstractLikelihood) -> "AbstractPosterior":
-        ...
+    # @overload
+    # def __rmul__(self, other: AbstractLikelihood) -> "AbstractPosterior":
+    #     ...
 
     def __rmul__(self, other):
         r"""Combine the prior with a likelihood to form a posterior distribution.
@@ -540,7 +541,7 @@ def predict(
         # Σ⁻¹ Kxt
         if mask is not None:
             Kxt = jnp.where(mask * jnp.ones((1, n_train), dtype=bool), 0.0, Kxt)
-        Sigma_inv_Kxt = cola.solve(Sigma, Kxt)
+        Sigma_inv_Kxt = cola.solve(Sigma, Kxt, Cholesky())
 
         # μt + Ktx (Kxx + Io²)⁻¹ (y - μx)
         mean = mean_t.flatten() + Sigma_inv_Kxt.T @ (y - mx).flatten()
@@ -618,7 +619,9 @@ def sample_approx(
         y = train_data.y - self.prior.mean_function(train_data.X)  # account for mean
         Phi = fourier_feature_fn(train_data.X)
         canonical_weights = cola.solve(
-            Sigma, y + eps - jnp.inner(Phi, fourier_weights)
+            Sigma,
+            y + eps - jnp.inner(Phi, fourier_weights),
+            Cholesky(),
         )  # [N, B]
 
         def sample_fn(test_inputs: Float[Array, "n D"]) -> Float[Array, "n B"]:
@@ -656,7 +659,7 @@ class NonConjugatePosterior(AbstractPosterior):
     from, or optimise an approximation to, the posterior distribution.
     """
 
-    latent: Float[Array, "N 1"] = param_field(None)
+    latent: Union[Float[Array, "N 1"], None] = param_field(None)
     key: KeyArray = static_field(PRNGKey(42))
 
     def __post_init__(self):
@@ -707,7 +710,7 @@ def predict(
         mean_t = mean_function(t)
 
         # Lx⁻¹ Kxt
-        Lx_inv_Kxt = cola.solve(Lx, Ktx.T)
+        Lx_inv_Kxt = cola.solve(Lx, Ktx.T, Cholesky())
 
         # Whitened function values, wx, corresponding to the inputs, x
         wx = self.latent
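
Note: the conjugate `predict` now routes each solve against Σ = Kxx + Iσ² through an explicit Cholesky algorithm. A self-contained sketch of that solve pattern with made-up matrices (this is not GPJax's actual `predict`):

```python
import cola
from cola.linalg.decompositions.decompositions import Cholesky
import jax.numpy as jnp

# Made-up stand-ins for the quantities in `predict`: two training points
# and one test point.
Kxx = jnp.array([[1.0, 0.3], [0.3, 1.0]])  # train/train kernel
Kxt = jnp.array([[0.5], [0.2]])            # train/test kernel
obs_noise = 0.1

# Σ = Kxx + Iσ² is PSD, so a Cholesky-backed solve is stable and cheap.
Sigma = cola.PSD(cola.ops.Dense(Kxx + obs_noise * jnp.eye(2)))
Sigma_inv_Kxt = cola.solve(Sigma, Kxt, Cholesky())  # Σ⁻¹ Kxt, as in the diff
```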

gpjax/integrators.py
Lines changed: 14 additions & 7 deletions

@@ -1,16 +1,23 @@
 from abc import abstractmethod
 from dataclasses import dataclass
-from typing import TYPE_CHECKING
+from typing import (
+    TypeVar,
+    Union,
+)
 
 from beartype.typing import Callable
 import jax.numpy as jnp
 from jaxtyping import Float
 import numpy as np
 
+import gpjax
 from gpjax.typing import Array
 
-if TYPE_CHECKING:
-    import gpjax.likelihoods
+Likelihood = TypeVar(
+    "Likelihood",
+    bound=Union["gpjax.likelihoods.AbstractLikelihood", None],  # noqa: F821
+)
+Gaussian = TypeVar("Gaussian", bound="gpjax.likelihoods.Gaussian")  # noqa: F821
 
 
 @dataclass
@@ -24,7 +31,7 @@ def integrate(
         y: Float[Array, "N D"],
         mean: Float[Array, "N D"],
         variance: Float[Array, "N D"],
-        likelihood: "gpjax.likelihoods.AbstractLikelihood" = None,
+        likelihood: Likelihood,
     ) -> Float[Array, " N"]:
         r"""Integrate a function with respect to a Gaussian distribution.
 
@@ -47,7 +54,7 @@ def __call__(
         y: Float[Array, "N D"],
         mean: Float[Array, "N D"],
         variance: Float[Array, "N D"],
-        likelihood: "gpjax.likelihoods.AbstractLikelihood" = None,
+        likelihood: Likelihood,
     ) -> Float[Array, " N"]:
         r"""Integrate a function with respect to a Gaussian distribution.
 
@@ -86,7 +93,7 @@ def integrate(
         y: Float[Array, "N D"],
         mean: Float[Array, "N D"],
         variance: Float[Array, "N D"],
-        likelihood: "gpjax.likelihoods.AbstractLikelihood" = None,
+        likelihood: Likelihood,
     ) -> Float[Array, " N"]:
         r"""Compute a quadrature integral.
 
@@ -127,7 +134,7 @@ def integrate(
         y: Float[Array, "N D"],
         mean: Float[Array, "N D"],
         variance: Float[Array, "N D"],
-        likelihood: "gpjax.likelihoods.Gaussian" = None,
+        likelihood: Gaussian,
     ) -> Float[Array, " N"]:
         r"""Compute a Gaussian integral.
 
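
Note: the `TYPE_CHECKING` guard is replaced by `TypeVar`s bound to forward-referenced likelihood classes, keeping the annotations checkable at runtime (e.g. by beartype) without an eager import of `gpjax.likelihoods`, and the `= None` defaults are dropped. A standalone sketch of the pattern; the class here is a placeholder, not GPJax's:

```python
from typing import TypeVar, Union


class AbstractLikelihood:  # placeholder for gpjax.likelihoods.AbstractLikelihood
    pass


# A TypeVar bound to "subclass of AbstractLikelihood, or None"; the string
# forward reference is what lets the real module avoid a circular import.
Likelihood = TypeVar("Likelihood", bound=Union["AbstractLikelihood", None])


def integrate(likelihood: Likelihood) -> Likelihood:
    return likelihood
```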
