
Commit acf7d2f

Format and increment
1 parent 24fbde8 commit acf7d2f

21 files changed: +177, -151 lines

docs/conf.py

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@
 author = "Thomas Pinder"

 # The full version, including alpha/beta/rc tags
-release = "0.2.0"
+release = "0.3"


 # -- General configuration ---------------------------------------------------

gpjax/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -5,4 +5,4 @@
 from .predict import mean, variance
 from .sampling import random_variable, sample

-__version__ = "0.3.0"
+__version__ = "0.3.1"

gpjax/kernels.py

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ def __call__(self, x: Array, y: Array) -> Array:

 @dispatch(RBF)
 def initialise(kernel: RBF):
-    return {"lengthscale": jnp.array([1.0]*kernel.ndims), "variance": jnp.array([1.0])}
+    return {"lengthscale": jnp.array([1.0] * kernel.ndims), "variance": jnp.array([1.0])}


 def squared_distance(x: Array, y: Array):
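
For context on what this reformatted initialiser produces, here is a minimal usage sketch. It follows the test suite in this commit, which constructs RBF() with no arguments; importing initialise directly from gpjax.kernels, where this dispatch is defined, is an assumption about the public surface.

import jax.numpy as jnp

from gpjax.kernels import RBF, initialise  # initialise is dispatched on RBF in this module

kernel = RBF()
params = initialise(kernel)

# One lengthscale per input dimension (kernel.ndims) and a single variance, all set to 1.0.
print(params["lengthscale"], params["variance"])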

gpjax/mean_functions.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ class Zero(MeanFunction):

     def __call__(self, x: Array) -> Array:
         out_shape = (x.shape[0], self.output_dim)
-        return jnp.zeros(shape = out_shape)
+        return jnp.zeros(shape=out_shape)


 @dispatch(Zero)

gpjax/objectives/mlls.py

Lines changed: 4 additions & 1 deletion
@@ -39,6 +39,7 @@ def mll(params: dict, x: Array, y: Array, priors: dict = None):
         log_prior_density = evaluate_prior(params, priors)
         constant = jnp.array(-1.0) if negative else jnp.array(1.0)
         return constant * (random_variable.log_prob(y.squeeze()).mean() + log_prior_density)
+
     return mll


@@ -49,7 +50,9 @@ def marginal_ll(
     negative: bool = False,
     jitter: float = 1e-6,
 ) -> Callable:
-    def mll(params: dict, x: Array, y: Array, priors: dict = {'latent': tfd.Normal(loc=0., scale=1.)}):
+    def mll(
+        params: dict, x: Array, y: Array, priors: dict = {"latent": tfd.Normal(loc=0.0, scale=1.0)}
+    ):
         params = untransform(params, transformation)
         n = x.shape[0]
         link = link_function(gp.likelihood)
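
The only behavioural detail worth noting here is that the non-conjugate mll keeps its default prior of a unit normal on the latent values. A minimal sketch of calling it without an explicit priors argument, mirroring test_non_conjugate from this commit (the toy data is illustrative):

import jax.numpy as jnp

from gpjax import Prior
from gpjax.kernels import RBF
from gpjax.likelihoods import Bernoulli
from gpjax.objectives import marginal_ll
from gpjax.parameters import SoftplusTransformation, initialise, transform

n = 20
x = jnp.linspace(-1.0, 1.0, n).reshape(-1, 1)
y = jnp.sin(x)

# Non-conjugate model: Bernoulli likelihood over a GP prior.
posterior = Prior(kernel=RBF()) * Bernoulli()
mll = marginal_ll(posterior, negative=True)
params = transform(params=initialise(posterior, n), transformation=SoftplusTransformation)

# No priors argument supplied, so the default {"latent": tfd.Normal(loc=0.0, scale=1.0)} applies.
value = mll(params, x, y)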

gpjax/parameters/priors.py

Lines changed: 13 additions & 10 deletions
@@ -1,11 +1,12 @@
+import warnings
+
 import jax.numpy as jnp
 from jax.interpreters.ad import JVPTracer
 from jax.interpreters.partial_eval import DynamicJaxprTracer
 from multipledispatch import dispatch
 from tensorflow_probability.substrates.jax import distributions as tfd
-from ..gps import NonConjugatePosterior
-import warnings

+from ..gps import NonConjugatePosterior
 from ..types import Array, NoneType


@@ -16,24 +17,26 @@ def log_density(param: jnp.DeviceArray, density: tfd.Distribution) -> Array:

 @dispatch(dict, NoneType)
 def evaluate_prior(params: dict, priors: dict) -> Array:
-    return jnp.array(0.)
+    return jnp.array(0.0)


 @dispatch(dict, dict)
 def evaluate_prior(params: dict, priors: dict) -> Array:
     lpd = jnp.array(0)
     for param, val in priors.items():
-        lpd+=jnp.sum(log_density(params[param], priors[param]))
+        lpd += jnp.sum(log_density(params[param], priors[param]))
     return lpd


 @dispatch(NonConjugatePosterior, dict)
 def prior_checks(gp: NonConjugatePosterior, priors: dict) -> dict:
-    if 'latent' in priors.keys():
-        latent_prior = priors['latent']
-        if latent_prior.name != 'Normal':
-            warnings.warn(f'A {latent_prior.name} distribution prior has been placed on the latent function. It is strongly afvised that a unit-Gaussian prior is used.')
+    if "latent" in priors.keys():
+        latent_prior = priors["latent"]
+        if latent_prior.name != "Normal":
+            warnings.warn(
+                f"A {latent_prior.name} distribution prior has been placed on the latent function. It is strongly afvised that a unit-Gaussian prior is used."
+            )
         return priors
     else:
-        priors['latent'] = tfd.Normal(loc=0., scale=1.)
-        return priors
+        priors["latent"] = tfd.Normal(loc=0.0, scale=1.0)
+        return priors
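
For reference, the two evaluate_prior dispatches behave as follows: with None priors the log prior density is zero, and with a dict of priors the log-densities of the named parameters are summed. A short sketch, importing from the gpjax.parameters.priors module where these functions are defined; the Gamma prior on the lengthscale is purely illustrative:

import jax.numpy as jnp
from tensorflow_probability.substrates.jax import distributions as tfd

from gpjax.parameters.priors import evaluate_prior

params = {"lengthscale": jnp.array([1.0]), "variance": jnp.array([1.0])}

# (dict, NoneType) dispatch: no priors supplied, so nothing is contributed.
assert evaluate_prior(params, None) == jnp.array(0.0)

# (dict, dict) dispatch: sum the log-densities of the parameters named in the priors dict.
priors = {"lengthscale": tfd.Gamma(concentration=1.0, rate=1.0)}
log_prior_density = evaluate_prior(params, priors)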

gpjax/types.py

Lines changed: 1 addition & 1 deletion
@@ -6,4 +6,4 @@

 # Array = Union[jnp.ndarray, ShardedDeviceArray, jnp.DeviceArray] # Cannot currently dispatch on a Union type
 # Data = Tuple[Array, Array]
-NoneType = type(None)
+NoneType = type(None)

gpjax/utils.py

Lines changed: 11 additions & 7 deletions
@@ -1,5 +1,6 @@
-import jax.numpy as jnp
 from typing import Tuple
+
+import jax.numpy as jnp
 from multipledispatch import dispatch

 from .types import Array
@@ -50,11 +51,13 @@ def standardise(x: jnp.DeviceArray) -> Tuple[jnp.DeviceArray, jnp.DeviceArray, j
     """
     xmean = jnp.mean(x, axis=0)
     xstd = jnp.std(x, axis=0)
-    return (x-xmean)/xstd, xmean, xstd
+    return (x - xmean) / xstd, xmean, xstd


 @dispatch(jnp.DeviceArray, jnp.DeviceArray, jnp.DeviceArray)
-def standardise(x: jnp.DeviceArray, xmean: jnp.DeviceArray, xstd: jnp.DeviceArray) -> jnp.DeviceArray:
+def standardise(
+    x: jnp.DeviceArray, xmean: jnp.DeviceArray, xstd: jnp.DeviceArray
+) -> jnp.DeviceArray:
     """
     Standardise a given matrix with respect to a given mean and standard deviation. This is primarily designed for
     standardising a test set of data with respect to the training data.
@@ -64,11 +67,12 @@ def standardise(x: jnp.DeviceArray, xmean: jnp.DeviceArray, xstd: jnp.DeviceArra
     :param xstd: A precomputed standard deviation vector
     :return: A matrix of standardised values
     """
-    return (x-xmean)/xstd
-
+    return (x - xmean) / xstd


-def unstandardise(x: jnp.DeviceArray, xmean: jnp.DeviceArray, xstd: jnp.DeviceArray) -> jnp.DeviceArray:
+def unstandardise(
+    x: jnp.DeviceArray, xmean: jnp.DeviceArray, xstd: jnp.DeviceArray
+) -> jnp.DeviceArray:
     """
     Unstandardise a given matrix with respect to a previously computed mean and standard deviation. This is designed
     for remapping a matrix back onto its original scale.
@@ -78,4 +82,4 @@ def unstandardise(x: jnp.DeviceArray, xmean: jnp.DeviceArray, xstd: jnp.DeviceAr
     :param xstd: A standard deviation vector.
     :return: A matrix of unstandardised values.
     """
-    return (x*xstd) + xmean
+    return (x * xstd) + xmean
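
The behaviour of the reformatted helpers is unchanged: single-argument standardise returns the scaled matrix together with the column means and standard deviations, the three-argument dispatch reuses precomputed statistics (e.g. for a test set), and unstandardise maps values back to the original scale. A small round-trip sketch with arbitrary toy data:

import jax.numpy as jnp

from gpjax.utils import standardise, unstandardise

x_train = jnp.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])
x_test = jnp.array([[1.5, 15.0]])

# Column-wise statistics computed from the training data.
x_train_std, xmean, xstd = standardise(x_train)

# The same statistics applied to held-out data.
x_test_std = standardise(x_test, xmean, xstd)

# Map standardised values back onto the original scale.
assert jnp.allclose(unstandardise(x_train_std, xmean, xstd), x_train)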

setup.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ def parse_requirements_file(filename):

 setup(
     name="GPJax",
-    version="0.3.0",
+    version="0.3.1",
     author="Thomas Pinder",
     author_email="[email protected]",
     packages=find_packages(".", exclude=["tests"]),

tests/objectives/test_mlls.py

Lines changed: 15 additions & 13 deletions
@@ -1,36 +1,38 @@
-from gpjax.objectives import marginal_ll
-from gpjax import Prior
-from gpjax.likelihoods import Bernoulli, Gaussian
-from gpjax.kernels import RBF
-from gpjax.parameters import transform, SoftplusTransformation, initialise
+from typing import Callable
+
 import jax.numpy as jnp
 import jax.random as jr
 import pytest
-from typing import Callable
 from tensorflow_probability.substrates.jax import distributions as tfd

+from gpjax import Prior
+from gpjax.kernels import RBF
+from gpjax.likelihoods import Bernoulli, Gaussian
+from gpjax.objectives import marginal_ll
+from gpjax.parameters import SoftplusTransformation, initialise, transform
+

 def test_conjugate():
-    posterior = Prior(kernel = RBF()) * Gaussian()
+    posterior = Prior(kernel=RBF()) * Gaussian()
     mll = marginal_ll(posterior)
     assert isinstance(mll, Callable)
     neg_mll = marginal_ll(posterior, negative=True)
-    x = jnp.linspace(-1., 1., 20).reshape(-1, 1)
+    x = jnp.linspace(-1.0, 1.0, 20).reshape(-1, 1)
     y = jnp.sin(x)
     params = transform(params=initialise(posterior), transformation=SoftplusTransformation)
-    assert neg_mll(params, x, y) == jnp.array(-1.)*mll(params, x, y)
+    assert neg_mll(params, x, y) == jnp.array(-1.0) * mll(params, x, y)


 def test_non_conjugate():
-    posterior = Prior(kernel = RBF()) * Bernoulli()
+    posterior = Prior(kernel=RBF()) * Bernoulli()
     mll = marginal_ll(posterior)
     assert isinstance(mll, Callable)
     neg_mll = marginal_ll(posterior, negative=True)
     n = 20
-    x = jnp.linspace(-1., 1., n).reshape(-1, 1)
+    x = jnp.linspace(-1.0, 1.0, n).reshape(-1, 1)
     y = jnp.sin(x)
     params = transform(params=initialise(posterior, n), transformation=SoftplusTransformation)
-    assert neg_mll(params, x, y) == jnp.array(-1.)*mll(params, x, y)
+    assert neg_mll(params, x, y) == jnp.array(-1.0) * mll(params, x, y)


 def test_prior_mll():
@@ -54,4 +56,4 @@ def test_prior_mll():
     mll_eval_priors = mll(params, x, y, priors)

     assert pytest.approx(mll_eval) == jnp.array(-115.72332969)
-    assert pytest.approx(mll_eval_priors) == jnp.array(-118.97202259)
+    assert pytest.approx(mll_eval_priors) == jnp.array(-118.97202259)
