@@ -13,7 +13,6 @@
 # limitations under the License.
 # ==============================================================================

-# from __future__ import annotations
 from abc import abstractmethod
 from dataclasses import (
     dataclass,
@@ -25,8 +24,10 @@
     Any,
     Callable,
     Optional,
+    Union,
 )
 import cola
+from cola.linalg.decompositions.decompositions import Cholesky
 from cola.ops import Dense
 import jax.numpy as jnp
 from jax.random import (
@@ -152,17 +153,17 @@ class Prior(AbstractPrior):
     ```
     """

-    @overload
-    def __mul__(self, other: Gaussian) -> "ConjugatePosterior":
-        ...
+    # @overload
+    # def __mul__(self, other: Gaussian) -> "ConjugatePosterior":
+    #     ...

-    @overload
-    def __mul__(self, other: NonGaussianLikelihood) -> "NonConjugatePosterior":
-        ...
+    # @overload
+    # def __mul__(self, other: NonGaussianLikelihood) -> "NonConjugatePosterior":
+    #     ...

-    @overload
-    def __mul__(self, other: AbstractLikelihood) -> "AbstractPosterior":
-        ...
+    # @overload
+    # def __mul__(self, other: AbstractLikelihood) -> "AbstractPosterior":
+    #     ...

     def __mul__(self, other):
         r"""Combine the prior with a likelihood to form a posterior distribution.
@@ -198,17 +199,17 @@ def __mul__(self, other):
         """
         return construct_posterior(prior=self, likelihood=other)

-    @overload
-    def __rmul__(self, other: Gaussian) -> "ConjugatePosterior":
-        ...
+    # @overload
+    # def __rmul__(self, other: Gaussian) -> "ConjugatePosterior":
+    #     ...

-    @overload
-    def __rmul__(self, other: NonGaussianLikelihood) -> "NonConjugatePosterior":
-        ...
+    # @overload
+    # def __rmul__(self, other: NonGaussianLikelihood) -> "NonConjugatePosterior":
+    #     ...

-    @overload
-    def __rmul__(self, other: AbstractLikelihood) -> "AbstractPosterior":
-        ...
+    # @overload
+    # def __rmul__(self, other: AbstractLikelihood) -> "AbstractPosterior":
+    #     ...

     def __rmul__(self, other):
         r"""Combine the prior with a likelihood to form a posterior distribution.
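
The `__mul__`/`__rmul__` pair above is what powers GPJax's `prior * likelihood` syntax, with `construct_posterior` dispatching on the likelihood type. A minimal usage sketch, assuming GPJax's public constructors (`gpx.gps.Prior`, `gpx.kernels.RBF`, `gpx.likelihoods.Gaussian`; exact module paths may vary between releases):

```python
import gpjax as gpx

# A GP prior is specified by a mean function and a kernel.
prior = gpx.gps.Prior(
    mean_function=gpx.mean_functions.Zero(),
    kernel=gpx.kernels.RBF(),
)

# A Gaussian likelihood makes the resulting posterior conjugate.
likelihood = gpx.likelihoods.Gaussian(num_datapoints=100)

posterior = prior * likelihood  # dispatches via Prior.__mul__
posterior = likelihood * prior  # equivalent, via Prior.__rmul__
```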
@@ -540,7 +541,7 @@ def predict(
         # Σ⁻¹ Kxt
         if mask is not None:
             Kxt = jnp.where(mask * jnp.ones((1, n_train), dtype=bool), 0.0, Kxt)
-        Sigma_inv_Kxt = cola.solve(Sigma, Kxt)
+        Sigma_inv_Kxt = cola.solve(Sigma, Kxt, Cholesky())

         # μt + Ktx (Kxx + Io²)⁻¹ (y - μx)
         mean = mean_t.flatten() + Sigma_inv_Kxt.T @ (y - mx).flatten()
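
This hunk passes an explicit `Cholesky()` algorithm to `cola.solve` instead of leaving CoLA to auto-select a solver for the PSD system. A self-contained sketch of the call pattern on a placeholder system (the matrix and right-hand side are illustrative, not from the codebase; `cola.PSD` is assumed to be CoLA's positive-semi-definite annotation):

```python
import cola
from cola.linalg.decompositions.decompositions import Cholesky
from cola.ops import Dense
import jax.numpy as jnp

# A small symmetric positive-definite system, annotated as PSD so that
# a Cholesky factorisation is a valid solve strategy.
A = cola.PSD(Dense(jnp.array([[4.0, 1.0], [1.0, 3.0]])))
b = jnp.array([1.0, 2.0])

# Factorise A once via Cholesky, then solve A x = b.
x = cola.solve(A, b, Cholesky())
```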
@@ -618,7 +619,9 @@ def sample_approx(
         y = train_data.y - self.prior.mean_function(train_data.X)  # account for mean
         Phi = fourier_feature_fn(train_data.X)
         canonical_weights = cola.solve(
-            Sigma, y + eps - jnp.inner(Phi, fourier_weights)
+            Sigma,
+            y + eps - jnp.inner(Phi, fourier_weights),
+            Cholesky(),
         )  # [N, B]

         def sample_fn(test_inputs: Float[Array, "n D"]) -> Float[Array, "n B"]:
@@ -656,7 +659,7 @@ class NonConjugatePosterior(AbstractPosterior):
     from, or optimise an approximation to, the posterior distribution.
     """

-    latent: Float[Array, "N 1"] = param_field(None)
+    latent: Union[Float[Array, "N 1"], None] = param_field(None)
     key: KeyArray = static_field(PRNGKey(42))

     def __post_init__(self):
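
Widening the annotation to `Union[Float[Array, "N 1"], None]` lets the field's `None` default satisfy the declared type, since a bare jaxtyping annotation does not admit `None`. A hedged sketch of the pattern, with a plain dataclass standing in for GPJax's `param_field` machinery:

```python
from dataclasses import dataclass
from typing import Union

from jaxtyping import Array, Float


@dataclass
class LatentState:  # hypothetical stand-in for NonConjugatePosterior's field
    # Union[..., None] admits the None default; Float[Array, "N 1"] alone
    # would reject it under runtime type checking.
    latent: Union[Float[Array, "N 1"], None] = None
```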
@@ -707,7 +710,7 @@ def predict(
         mean_t = mean_function(t)

         # Lx⁻¹ Kxt
-        Lx_inv_Kxt = cola.solve(Lx, Ktx.T)
+        Lx_inv_Kxt = cola.solve(Lx, Ktx.T, Cholesky())

         # Whitened function values, wx, corresponding to the inputs, x
         wx = self.latent