
Commit 8ac2427

bump jaxtyping
1 parent 3b60094 commit 8ac2427

File tree: 14 files changed, +106 -74 lines

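
Note on the change itself: with the bumped jaxtyping, annotations move from the dtype-keyed f64["N D"] form to Float[Array, "N D"], as every hunk below shows. A minimal, self-contained sketch of the new style (the helper function and values are illustrative, not part of the commit):

import jax.numpy as jnp
from jaxtyping import Array, Float


def pairwise_sq_dist(
    x: Float[Array, "N D"], y: Float[Array, "M D"]
) -> Float[Array, "N M"]:
    # Broadcast an (N, D) and an (M, D) array into an (N, M) matrix of squared distances.
    return jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)


x = jnp.ones((5, 2))
y = jnp.zeros((3, 2))
print(pairwise_sq_dist(x, y).shape)  # (5, 3)

The annotations are declarative here; they only become runtime shape checks if a checker such as beartype or typeguard is enabled.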

examples/classification.ipynb
Lines changed: 2 additions & 2 deletions

@@ -27,7 +27,7 @@
     "import distrax as dx\n",
     "from gpjax.utils import I\n",
     "import jax.scipy as jsp\n",
-    "from jaxtyping import f64\n",
+    "from jaxtyping import Float, Array\n",
     "\n",
     "key = jr.PRNGKey(123)"
   ]
@@ -287,7 +287,7 @@
     "from gpjax.kernels import gram, cross_covariance\n",
     "\n",
     "\n",
-    "def predict(laplace_at_data: dx.Distribution, train_data: Dataset, test_inputs: f64[\"N D\"], jitter: int = 1e-6) -> dx.Distribution:\n",
+    "def predict(laplace_at_data: dx.Distribution, train_data: Dataset, test_inputs: Float[Array, \"N D\"], jitter: int = 1e-6) -> dx.Distribution:\n",
     "    \"\"\"Compute the predictive distribution of the Laplace approximation at novel inputs.\n",
     "\n",
     "    Args:\n",

examples/kernels.ipynb
Lines changed: 4 additions & 4 deletions

@@ -17,6 +17,8 @@
     "metadata": {},
     "outputs": [],
     "source": [
+    "import gpjax as gpx\n",
+    "\n",
     "import jax.numpy as jnp\n",
     "import jax.random as jr\n",
     "import matplotlib.pyplot as plt\n",
@@ -25,8 +27,7 @@
     "import jax\n",
     "from optax import adam\n",
     "import distrax as dx\n",
-    "\n",
-    "import gpjax as gpx\n",
+    "from jaxtyping import Float, Array\n",
     "\n",
     "key = jr.PRNGKey(123)"
   ]
@@ -261,7 +262,6 @@
     "outputs": [],
     "source": [
     "from chex import dataclass\n",
-    "from jaxtyping import f64\n",
     "\n",
     "\n",
     "def angular_distance(x, y, c):\n",
@@ -275,7 +275,7 @@
     "    def __post_init__(self):\n",
     "        self.c = self.period / 2.0  # in [0, \\pi]\n",
     "\n",
-    "    def __call__(self, x: f64[\"1 D\"], y: f64[\"1 D\"], params: dict) -> f64[\"1\"]:\n",
+    "    def __call__(self, x: Float[Array, \"1 D\"], y: Float[Array, \"1 D\"], params: dict) -> Float[Array, \"1\"]:\n",
     "        tau = params[\"tau\"]\n",
     "        t = angular_distance(x, y, self.c)\n",
     "        K = (1 + tau * t / self.c) * jnp.clip(1 - t / self.c, 0, jnp.inf) ** tau\n",

gpjax/abstractions.py
Lines changed: 4 additions & 4 deletions

@@ -7,7 +7,7 @@
 from chex import dataclass
 from jax import lax
 from jax.experimental import host_callback
-from jaxtyping import f64
+from jaxtyping import Array, Float
 from tqdm.auto import tqdm
 
 from .parameters import trainable_params
@@ -17,7 +17,7 @@
 @dataclass(frozen=True)
 class InferenceState:
     params: tp.Dict
-    history: f64["n_iters"]
+    history: Float[Array, "n_iters"]
 
     def unpack(self):
         return self.params, self.history
@@ -113,7 +113,7 @@ def fit(
         n_iters (int, optional): The number of optimisation steps to run. Defaults to 100.
         log_rate (int, optional): How frequently the objective function's value should be printed. Defaults to 10.
     Returns:
-        tp.Tuple[tp.Dict, f64["n_iters"]]: A tuple comprising optimised parameters and training history respectively.
+        InferenceState: An InferenceState object comprising the optimised parameters and training history respectively.
     """
     opt_state = optax_optim.init(params)
 
@@ -161,7 +161,7 @@ def fit_batches(
         n_iters (int, optional): The number of optimisation steps to run. Defaults to 100.
         log_rate (int, optional): How frequently the objective function's value should be printed. Defaults to 10.
     Returns:
-        tp.Tuple[tp.Dict, f64["n_iters"]]: A tuple comprising optimised parameters and training history respectively.
+        InferenceState: An InferenceState object comprising the optimised parameters and training history respectively.
     """
 
     opt_state = optax_optim.init(params)
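
Both fit and fit_batches now document an InferenceState return value rather than a (params, history) tuple. A short sketch of how the dataclass defined in this file is consumed downstream (the parameter dictionary and history values are made up for illustration):

import typing as tp

import jax.numpy as jnp
from chex import dataclass
from jaxtyping import Array, Float


@dataclass(frozen=True)
class InferenceState:
    params: tp.Dict
    history: Float[Array, "n_iters"]  # objective value recorded at each optimisation step

    def unpack(self):
        return self.params, self.history


state = InferenceState(params={"lengthscale": jnp.array(1.0)}, history=jnp.zeros(100))
learned_params, training_history = state.unpack()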

gpjax/gps.py
Lines changed: 11 additions & 9 deletions

@@ -7,7 +7,7 @@
 import jax.random as jr
 import jax.scipy as jsp
 from chex import dataclass
-from jaxtyping import f64
+from jaxtyping import Array, Float
 
 from .config import get_defaults
 from .kernels import Kernel, cross_covariance, gram
@@ -74,15 +74,17 @@ def __rmul__(self, other: AbstractLikelihood):
         """Reimplement the multiplication operator to allow for order-invariant product of a likelihood and a prior i.e., likelihood * prior."""
         return self.__mul__(other)
 
-    def predict(self, params: dict) -> tp.Callable[[f64["N D"]], dx.Distribution]:
+    def predict(
+        self, params: dict
+    ) -> tp.Callable[[Float[Array, "N D"]], dx.Distribution]:
         """Compute the GP's prior mean and variance.
         Args:
             params (dict): The specific set of parameters for which the mean function should be defined for.
         Returns:
             tp.Callable[[Array], Array]: A mean function that accepts an input array for where the mean function should be evaluated at. The mean function's value at these points is then returned.
         """
 
-        def predict_fn(test_inputs: f64["N D"]) -> dx.Distribution:
+        def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution:
             t = test_inputs
             n_test = t.shape[0]
             μt = self.mean_function(t, params["mean_function"])
@@ -139,7 +141,7 @@ class ConjugatePosterior(AbstractPosterior):
 
     def predict(
         self, train_data: Dataset, params: dict
-    ) -> tp.Callable[[f64["N D"]], dx.Distribution]:
+    ) -> tp.Callable[[Float[Array, "N D"]], dx.Distribution]:
         """Conditional on a set of training data, compute the GP's posterior predictive distribution for a given set of parameters. The returned function can be evaluated at a set of test inputs to compute the corresponding predictive density.
 
         Args:
@@ -166,7 +168,7 @@ def predict(
         # w = L⁻¹ (y - μx)
         w = jsp.linalg.solve_triangular(L, y - μx, lower=True)
 
-        def predict(test_inputs: f64["N D"]) -> dx.Distribution:
+        def predict(test_inputs: Float[Array, "N D"]) -> dx.Distribution:
             t = test_inputs
             n_test = t.shape[0]
             μt = self.prior.mean_function(t, params["mean_function"])
@@ -195,7 +197,7 @@ def marginal_log_likelihood(
         transformations: Dict,
         priors: dict = None,
         negative: bool = False,
-    ) -> tp.Callable[[dict], f64["1"]]:
+    ) -> tp.Callable[[dict], Float[Array, "1"]]:
         """Compute the marginal log-likelihood function of the Gaussian process. The returned function can then be used for gradient based optimisation of the model's parameters or for model comparison. The implementation given here enables exact estimation of the Gaussian process' latent function values.
 
         Args:
@@ -261,7 +263,7 @@ def _initialise_params(self, key: jnp.DeviceArray) -> Dict:
 
     def predict(
         self, train_data: Dataset, params: dict
-    ) -> tp.Callable[[f64["N D"]], dx.Distribution]:
+    ) -> tp.Callable[[Float[Array, "N D"]], dx.Distribution]:
         """Conditional on a set of training data, compute the GP's posterior predictive distribution for a given set of parameters. The returned function can be evaluated at a set of test inputs to compute the corresponding predictive density. Note, to gain predictions on the scale of the original data, the returned distribution will need to be transformed through the likelihood function's inverse link function.
 
         Args:
@@ -277,7 +279,7 @@ def predict(
         Kxx += I(n) * self.jitter
         Lx = jnp.linalg.cholesky(Kxx)
 
-        def predict_fn(test_inputs: f64["N D"]) -> dx.Distribution:
+        def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution:
             t = test_inputs
             n_test = t.shape[0]
             Ktx = cross_covariance(self.prior.kernel, t, x, params["kernel"])
@@ -306,7 +308,7 @@ def marginal_log_likelihood(
         transformations: Dict,
         priors: dict = None,
         negative: bool = False,
-    ) -> tp.Callable[[dict], f64["1"]]:
+    ) -> tp.Callable[[dict], Float[Array, "1"]]:
         """Compute the marginal log-likelihood function of the Gaussian process. The returned function can then be used for gradient based optimisation of the model's parameters or for model comparison. The implementation given here is general and will work for any likelihood support by GPJax.
 
         Args:
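
Across gps.py the predict methods keep their closure pattern; only the annotations change. A stripped-down sketch of that pattern, with a toy prior standing in for GPJax's mean function and kernel (the variance parameter and distribution choice are illustrative):

import typing as tp

import distrax as dx
import jax.numpy as jnp
from jaxtyping import Array, Float


def predict(params: dict) -> tp.Callable[[Float[Array, "N D"]], dx.Distribution]:
    # The outer function closes over the parameters; the inner function is typed over test inputs.
    def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution:
        t = test_inputs
        n_test = t.shape[0]
        mu = jnp.zeros(n_test)
        cov = params["variance"] * jnp.eye(n_test)
        return dx.MultivariateNormalFullCovariance(mu, cov)

    return predict_fn


rv = predict({"variance": 1.0})(jnp.linspace(-1.0, 1.0, 10).reshape(-1, 1))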

gpjax/kernels.py
Lines changed: 41 additions & 17 deletions

@@ -4,7 +4,7 @@
 import jax.numpy as jnp
 from chex import dataclass
 from jax import vmap
-from jaxtyping import f64
+from jaxtyping import Array, Float
 
 
 ##########################################
@@ -23,7 +23,9 @@ def __post_init__(self):
         self.ndims = 1 if not self.active_dims else len(self.active_dims)
 
     @abc.abstractmethod
-    def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]:
+    def __call__(
+        self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: dict
+    ) -> Float[Array, "1"]:
         """Evaluate the kernel on a pair of inputs.
         Args:
             x (jnp.DeviceArray): The left hand argument of the kernel function's call.
@@ -34,7 +36,7 @@ def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]:
         """
         raise NotImplementedError
 
-    def slice_input(self, x: f64["N D"]) -> f64["N Q"]:
+    def slice_input(self, x: Float[Array, "N D"]) -> Float[Array, "N Q"]:
         """Select the relevant columns of the supplied matrix to be used within the kernel's evaluation.
         Args:
             x (Array): The matrix or vector that is to be sliced.
@@ -101,7 +103,9 @@ def _initialise_params(self, key: jnp.DeviceArray) -> Dict:
         """A template dictionary of the kernel's parameter set."""
         return [kernel._initialise_params(key) for kernel in self.kernel_set]
 
-    def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]:
+    def __call__(
+        self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: dict
+    ) -> Float[Array, "1"]:
         return self.combination_fn(
             jnp.stack([k(x, y, p) for k, p in zip(self.kernel_set, params)])
         )
@@ -135,7 +139,9 @@ class RBF(Kernel):
     def __post_init__(self):
         self.ndims = 1 if not self.active_dims else len(self.active_dims)
 
-    def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]:
+    def __call__(
+        self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: dict
+    ) -> Float[Array, "1"]:
         """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\ell` and variance :math:`\sigma`
 
         .. math::
@@ -170,7 +176,9 @@ class Matern12(Kernel):
     def __post_init__(self):
         self.ndims = 1 if not self.active_dims else len(self.active_dims)
 
-    def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]:
+    def __call__(
+        self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: dict
+    ) -> Float[Array, "1"]:
         """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\ell` and variance :math:`\sigma`
 
         .. math::
@@ -204,7 +212,9 @@ class Matern32(Kernel):
     def __post_init__(self):
         self.ndims = 1 if not self.active_dims else len(self.active_dims)
 
-    def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]:
+    def __call__(
+        self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: dict
+    ) -> Float[Array, "1"]:
         """Evaluate the kernel on a pair of inputs :math:`(x, y)` with lengthscale parameter :math:`\ell` and variance :math:`\sigma`
 
         .. math::
@@ -244,7 +254,9 @@ class Matern52(Kernel):
     def __post_init__(self):
         self.ndims = 1 if not self.active_dims else len(self.active_dims)
 
-    def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]:
+    def __call__(
+        self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: dict
+    ) -> Float[Array, "1"]:
         """Evaluate the kernel on a pair of inputs :math:`(x, y)` with lengthscale parameter :math:`\ell` and variance :math:`\sigma`
 
         .. math::
@@ -286,7 +298,9 @@
     def __post_init__(self):
         self.ndims = 1 if not self.active_dims else len(self.active_dims)
         self.name = f"Polynomial Degree: {self.degree}"
 
-    def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]:
+    def __call__(
+        self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: dict
+    ) -> Float[Array, "1"]:
         """Evaluate the kernel on a pair of inputs :math:`(x, y)` with shift parameter :math:`\alpha` and variance :math:`\sigma` through
 
         .. math::
@@ -317,7 +331,7 @@ def _initialise_params(self, key: jnp.DeviceArray) -> Dict:
 ##########################################
 @dataclass
 class _EigenKernel:
-    laplacian: f64["N N"]
+    laplacian: Float[Array, "N N"]
 
 
 @dataclass
@@ -330,7 +344,9 @@ def __post_init__(self):
         self.evals = evals.reshape(-1, 1)
         self.num_vertex = self.laplacian.shape[0]
 
-    def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]:
+    def __call__(
+        self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: dict
+    ) -> Float[Array, "1"]:
         """Evaluate the graph kernel on a pair of vertices v_i, v_j.
 
         Args:
@@ -361,17 +377,23 @@ def _initialise_params(self, key: jnp.DeviceArray) -> Dict:
         }
 
 
-def squared_distance(x: f64["1 D"], y: f64["1 D"]) -> f64["1"]:
+def squared_distance(
+    x: Float[Array, "1 D"], y: Float[Array, "1 D"]
+) -> Float[Array, "1"]:
     """Compute the squared distance between a pair of inputs."""
     return jnp.sum((x - y) ** 2)
 
 
-def euclidean_distance(x: f64["1 D"], y: f64["1 D"]) -> f64["1"]:
+def euclidean_distance(
+    x: Float[Array, "1 D"], y: Float[Array, "1 D"]
+) -> Float[Array, "1"]:
     """Compute the l1 norm between a pair of inputs."""
     return jnp.sqrt(jnp.maximum(jnp.sum((x - y) ** 2), 1e-36))
 
 
-def gram(kernel: Kernel, inputs: f64["N D"], params: dict) -> f64["N N"]:
+def gram(
+    kernel: Kernel, inputs: Float[Array, "N D"], params: dict
+) -> Float[Array, "N N"]:
     """For a given kernel, compute the :math:`n \times n` gram matrix on an input matrix of shape :math:`n \times d` for :math:`d\geq 1`.
 
     Args:
@@ -386,8 +408,8 @@ def gram(kernel: Kernel, inputs: f64["N D"], params: dict) -> f64["N N"]:
 
 
 def cross_covariance(
-    kernel: Kernel, x: f64["N D"], y: f64["M D"], params: dict
-) -> f64["N M"]:
+    kernel: Kernel, x: Float[Array, "N D"], y: Float[Array, "M D"], params: dict
+) -> Float[Array, "N M"]:
     """For a given kernel, compute the :math:`m \times n` gram matrix on an a pair of input matrices with shape :math:`m \times d` and :math:`n \times d` for :math:`d\geq 1`.
 
     Args:
@@ -402,7 +424,9 @@ def cross_covariance(
     return vmap(lambda x1: vmap(lambda y1: kernel(x1, y1, params))(y))(x)
 
 
-def diagonal(kernel: Kernel, inputs: f64["N D"], params: dict) -> f64["N N"]:
+def diagonal(
+    kernel: Kernel, inputs: Float[Array, "N D"], params: dict
+) -> Float[Array, "N N"]:
     """For a given kernel, compute the elementwise diagonal of the :math:`n \times n` gram matrix on an input matrix of shape :math:`n \times d` for :math:`d\geq 1`.
     Args:
         kernel (Kernel): The kernel for which the variance vector should be computed for.
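
gram, cross_covariance and diagonal keep their nested-vmap construction; the hunks above only widen the annotations. A self-contained sketch of the cross-covariance pattern, with a plain squared-exponential function standing in for a GPJax Kernel object (the parameter names are illustrative):

import jax.numpy as jnp
from jax import vmap
from jaxtyping import Array, Float


def rbf(x: Float[Array, "D"], y: Float[Array, "D"], params: dict) -> Float[Array, ""]:
    # Squared-exponential kernel evaluated on a single pair of input rows.
    tau = jnp.sum(((x - y) / params["lengthscale"]) ** 2)
    return params["variance"] * jnp.exp(-0.5 * tau)


def cross_covariance(
    kernel, x: Float[Array, "N D"], y: Float[Array, "M D"], params: dict
) -> Float[Array, "N M"]:
    # Same nested-vmap pattern as gpjax.kernels: map over rows of x, then over rows of y.
    return vmap(lambda x1: vmap(lambda y1: kernel(x1, y1, params))(y))(x)


params = {"lengthscale": jnp.array(1.0), "variance": jnp.array(1.0)}
x = jnp.linspace(-1.0, 1.0, 4).reshape(-1, 1)
K = cross_covariance(rbf, x, x, params)  # a (4, 4) Gram matrix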

gpjax/likelihoods.py
Lines changed: 4 additions & 2 deletions

@@ -5,7 +5,7 @@
 import jax.numpy as jnp
 import jax.scipy as jsp
 from chex import dataclass
-from jaxtyping import f64
+from jaxtyping import Array, Float
 
 from .utils import I
 
@@ -107,7 +107,9 @@ def predictive_moment_fn(self) -> Callable:
             Callable: A callable object that accepts a mean and variance term from which the predictive random variable is computed.
         """
 
-        def moment_fn(mean: f64["N D"], variance: f64["N D"], params: Dict):
+        def moment_fn(
+            mean: Float[Array, "N D"], variance: Float[Array, "N D"], params: Dict
+        ):
             rv = self.link_function(mean / jnp.sqrt(1 + variance), params)
             return rv
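
The moment function in this hunk squashes the latent mean through the likelihood's link function, with the latent variance inflating the denominator. A standalone sketch with an inverse-probit link written out explicitly (jsp.stats.norm.cdf stands in for self.link_function here, which is an assumption about the surrounding likelihood):

import jax.numpy as jnp
import jax.scipy as jsp
from jaxtyping import Array, Float


def moment_fn(
    mean: Float[Array, "N D"], variance: Float[Array, "N D"]
) -> Float[Array, "N D"]:
    # Probit-style moment matching: the predictive mean of Phi(f) under a Gaussian latent f.
    return jsp.stats.norm.cdf(mean / jnp.sqrt(1 + variance))


p = moment_fn(jnp.zeros((5, 1)), jnp.ones((5, 1)))  # every entry is 0.5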

gpjax/mean_functions.py
Lines changed: 4 additions & 4 deletions

@@ -3,7 +3,7 @@
 
 import jax.numpy as jnp
 from chex import dataclass
-from jaxtyping import f64
+from jaxtyping import Array, Float
 
 
 @dataclass(repr=False)
@@ -14,7 +14,7 @@ class AbstractMeanFunction:
     name: Optional[str] = "Mean function"
 
     @abc.abstractmethod
-    def __call__(self, x: f64["N D"]) -> f64["N Q"]:
+    def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N Q"]:
         """Evaluate the mean function at the given points. This method is required for all subclasses.
 
         Args:
@@ -44,7 +44,7 @@ class Zero(AbstractMeanFunction):
     output_dim: Optional[int] = 1
     name: Optional[str] = "Zero mean function"
 
-    def __call__(self, x: f64["N D"], params: dict) -> f64["N Q"]:
+    def __call__(self, x: Float[Array, "N D"], params: dict) -> Float[Array, "N Q"]:
         """Evaluate the mean function at the given points.
 
         Args:
@@ -72,7 +72,7 @@ class Constant(AbstractMeanFunction):
     output_dim: Optional[int] = 1
     name: Optional[str] = "Constant mean function"
 
-    def __call__(self, x: f64["N D"], params: Dict) -> f64["N Q"]:
+    def __call__(self, x: Float[Array, "N D"], params: Dict) -> Float[Array, "N Q"]:
         """Evaluate the mean function at the given points.
 
         Args:
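
All three mean-function signatures follow the same contract: an (N, D) input maps to an (N, Q) output. A minimal sketch of a constant mean in the new style, written without the AbstractMeanFunction base class (the parameter name is illustrative):

import jax.numpy as jnp
from jaxtyping import Array, Float


def constant_mean(x: Float[Array, "N D"], params: dict) -> Float[Array, "N Q"]:
    # Broadcast one learned constant across all N inputs, giving output dimension Q = 1.
    return params["constant"] * jnp.ones((x.shape[0], 1))


mu = constant_mean(jnp.zeros((10, 2)), {"constant": jnp.array(3.0)})  # shape (10, 1)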
