
Commit 34d744f

Merge pull request #99 from thomaspinder/Fix-fit_batches-and-add-history
Remove TensorFlow dependency, Fix fit batches and add history
2 parents: 99ef16f + 9b56623

File tree

8 files changed, +511 -59 lines

README.md

Lines changed: 4 additions & 4 deletions
@@ -34,10 +34,10 @@ We have recently set up a Slack channel where we hope to facilitate discussions

  - [**Conjugate Inference**](https://gpjax.readthedocs.io/en/latest/nbs/regression.html)
  - [**Classification with MCMC**](https://gpjax.readthedocs.io/en/latest/nbs/classification.html)
- - [**Sparse Variational Inference**](https://gpjax.readthedocs.io/en/latest/nbs/sparse_regression.html)
+ - [**Sparse Variational Inference**](https://gpjax.readthedocs.io/en/latest/nbs/uncollapsed_vi.html)
  - [**BlackJax Integration**](https://gpjax.readthedocs.io/en/latest/nbs/classification.html)
  - [**Laplace Approximations**](https://gpjax.readthedocs.io/en/latest/nbs/classification.html#Laplace-approximation)
- - [**TensorFlow Probability Integration**](https://gpjax.readthedocs.io/en/latest/nbs/tfp_intergation.html)
+ - [**TensorFlow Probability Integration**](https://gpjax.readthedocs.io/en/latest/nbs/tfp_integration.html)
  - [**Inference on Non-Euclidean Spaces**](https://gpjax.readthedocs.io/en/latest/nbs/kernels.html#Custom-Kernel)
  - [**Inference on Graphs**](https://gpjax.readthedocs.io/en/latest/nbs/graph_kernels.html)
  - [**Learning Gaussian Process Barycentres**](https://gpjax.readthedocs.io/en/latest/nbs/graph_kernels.html)

@@ -139,7 +139,7 @@ pip install gpjax

  ### Development version

- To install the latest, possibly unstable, version, the following steps should be followed. It is by no means compulsory, but we do advise that you do all of the below inside a virtual environment.
+ To install the latest (possibly unstable) version, the following steps should be followed. It is by no means compulsory, but we do advise that you do all of the below inside a virtual environment.

  ```bash
  git clone https://github.com/thomaspinder/GPJax.git

@@ -156,7 +156,7 @@ python -m pytest tests/

  ## Citing GPJax

- If you use GPJax in your research, please cite our [JOSS paper](https://joss.theoj.org/papers/10.21105/joss.04455#). A sample Bibtex file is
+ If you use GPJax in your research, please cite our [JOSS paper](https://joss.theoj.org/papers/10.21105/joss.04455#). Sample Bibtex is given below:
  ```
  @article{Pinder2022,
    doi = {10.21105/joss.04455},

docs/_static/GP.svg

Lines changed: 434 additions & 0 deletions
(SVG image file; contents not rendered in the diff view)

docs/index.rst

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ Welcome to GPJax!

  GPJax is a didactic Gaussian process library that supports GPU acceleration and just-in-time compilation. We seek to provide a flexible API as close as possible to how the underlying mathematics is written on paper to enable researchers to rapidly prototype and develop new ideas.


- .. image:: ./_static/GP.pdf
+ .. image:: ./_static/GP.svg
     :width: 800
     :alt: Gaussian process posterior.
     :align: center

docs/nbs/uncollapsed_vi.ipynb

Lines changed: 4 additions & 4 deletions
@@ -255,7 +255,7 @@

  "    train_data = D, \n",
  "    optax_optim = optimiser,\n",
  "    n_iters=4000,\n",
- "    seed = 42,\n",
+ "    key = jr.PRNGKey(42),\n",
  "    batch_size= 128\n",
  ")\n",
  "learned_params, training_history = inference_state.unpack()\n",

@@ -322,7 +322,7 @@

  "custom_cell_magics": "kql"
  },
  "kernelspec": {
- "display_name": "Python 3.9.7 ('gpjax')",
+ "display_name": "Python 3.10.0 ('base')",
  "language": "python",
  "name": "python3"
  },

@@ -336,11 +336,11 @@

  "name": "python",
  "nbconvert_exporter": "python",
  "pygments_lexer": "ipython3",
- "version": "3.9.7"
+ "version": "3.10.0"
  },
  "vscode": {
  "interpreter": {
- "hash": "920091140e6b97de16b405af485d142952a229f5dad61a888f46227f5acb94cf"
+ "hash": "3d597f4c481aa0f25dceb95d2a0067e73c0966dcbd003d741d821a7208527ecf"
  }
  }
  },
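For readers following along, the notebook cell above now passes a JAX PRNG key where an integer seed was previously accepted. A minimal sketch of the new usage is given below; `negative_elbo`, `params`, `trainables`, and the dataset `D` are placeholders standing in for objects built earlier in the tutorial, and only the argument names visible in the diffs on this page are taken from the source.

```python
import jax.random as jr
import optax

from gpjax.abstractions import fit_batches

# `negative_elbo`, `params`, `trainables` and the dataset `D` are assumed to
# have been constructed in earlier notebook cells; they are placeholders here.
optimiser = optax.adam(learning_rate=0.01)

inference_state = fit_batches(
    negative_elbo,
    params,
    trainables,
    train_data=D,
    optax_optim=optimiser,
    n_iters=4000,
    key=jr.PRNGKey(42),  # a PRNG key now replaces the old integer seed
    batch_size=128,
)
learned_params, training_history = inference_state.unpack()
```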

gpjax/abstractions.py

Lines changed: 38 additions & 25 deletions
@@ -4,15 +4,14 @@

  import jax.numpy as jnp
  import jax.random as jr
  import optax
- from chex import PRNGKey, dataclass
+ from chex import dataclass
  from jax import lax
  from jax.experimental import host_callback
  from jaxtyping import f64
  from tqdm.auto import tqdm

  from .parameters import trainable_params
- from .types import Dataset
- from .utils import convert_seed
+ from .types import Dataset, PRNGKeyType


  @dataclass(frozen=True)

@@ -82,11 +81,14 @@ def _progress_bar_scan(func):

      """Decorator that adds a progress bar to `body_fun` used in `lax.scan`."""

      def wrapper_progress_bar(carry, x):
-         i = x
+         if type(x) is tuple:
+             iter_num, *_ = x
+         else:
+             iter_num = x
          result = func(carry, x)
          *_, loss_val = result
-         _update_progress_bar(loss_val, i)
-         return close_tqdm(result, i)
+         _update_progress_bar(loss_val, iter_num)
+         return close_tqdm(result, iter_num)

      return wrapper_progress_bar

@@ -119,16 +121,18 @@ def loss(params):

          params = trainable_params(params, trainables)
          return objective(params)

+     iter_nums = jnp.arange(n_iters)
+
      @progress_bar_scan(n_iters, log_rate)
-     def step(carry, i):
+     def step(carry, iter_num):
          params, opt_state = carry
          loss_val, loss_gradient = jax.value_and_grad(loss)(params)
          updates, opt_state = optax_optim.update(loss_gradient, opt_state, params)
          params = optax.apply_updates(params, updates)
          carry = params, opt_state
          return carry, loss_val

-     (params, _), history = jax.lax.scan(step, (params, opt_state), jnp.arange(n_iters))
+     (params, _), history = jax.lax.scan(step, (params, opt_state), iter_nums)
      inf_state = InferenceState(params=params, history=history)
      return inf_state

@@ -139,7 +143,7 @@ def fit_batches(

      trainables: tp.Dict,
      train_data: Dataset,
      optax_optim,
-     seed: tp.Union[int, PRNGKey],
+     key: PRNGKeyType,
      batch_size: int,
      n_iters: tp.Optional[int] = 100,
      log_rate: tp.Optional[int] = 10,

@@ -152,7 +156,7 @@

          trainables (dict): Boolean dictionary of same structure as 'params' that determines which parameters should be trained.
          train_data (Dataset): The training dataset.
          optax_optim (GradientTransformation): The Optax optimiser that is to be used for learning a parameter set.
-         seed (int): The random seed for the mini-batch sampling.
+         key (PRNGKeyType): The PRNG key for the mini-batch sampling.
          batch_size(int): The batch_size.
          n_iters (int, optional): The number of optimisation steps to run. Defaults to 100.
          log_rate (int, optional): How frequently the objective function's value should be printed. Defaults to 10.

@@ -162,33 +166,42 @@

      opt_state = optax_optim.init(params)

-     prng = convert_seed(seed)
-
-     x, y, n = train_data.X, train_data.y, train_data.n
-
      def loss(params, batch):
          params = trainable_params(params, trainables)
          return objective(params, batch)

-     @progress_bar_scan(n_iters, log_rate)
-     def step(carry, _):
-         params, opt_state, prng = carry
+     keys = jax.random.split(key, n_iters)
+     iter_nums = jnp.arange(n_iters)

-         indicies = jr.choice(prng, n, (batch_size,), replace=True)
+     @progress_bar_scan(n_iters, log_rate)
+     def step(carry, iter_num__and__key):
+         iter_num, key = iter_num__and__key
+         params, opt_state = carry

-         batch = Dataset(X=x[indicies], y=y[indicies])
+         batch = get_batch(train_data, batch_size, key)

          loss_val, loss_gradient = jax.value_and_grad(loss)(params, batch)
          updates, opt_state = optax_optim.update(loss_gradient, opt_state, params)
          params = optax.apply_updates(params, updates)

-         prng, _ = jr.split(prng)
-
-         carry = params, opt_state, prng
+         carry = params, opt_state
          return carry, loss_val

-     (params, _, _), history = jax.lax.scan(
-         step, (params, opt_state, prng), jnp.arange(n_iters)
-     )
+     (params, _), history = jax.lax.scan(step, (params, opt_state), (iter_nums, keys))
      inf_state = InferenceState(params=params, history=history)
      return inf_state
+
+
+ def get_batch(train_data: Dataset, batch_size: int, key: PRNGKeyType) -> Dataset:
+     """Batch the data into mini-batches.
+     Args:
+         train_data (Dataset): The training dataset.
+         batch_size (int): The batch size.
+     Returns:
+         Dataset: The batched dataset.
+     """
+     x, y, n = train_data.X, train_data.y, train_data.n
+
+     indicies = jr.choice(key, n, (batch_size,), replace=True)
+
+     return Dataset(X=x[indicies], y=y[indicies])
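The heart of this change is that mini-batching is now handled functionally: the caller's key is split into one subkey per optimisation step, `lax.scan` is carried over the pair of iteration numbers and keys, and each step draws its batch with `jr.choice`, as in the new `get_batch`. Below is a self-contained sketch of that pattern outside of GPJax; the toy data, linear model, and learning rate are illustrative assumptions, not taken from the source.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

# Stand-ins for a dataset and a per-batch loss; names here are illustrative only.
X = jnp.linspace(-3.0, 3.0, 50).reshape(-1, 1)
y = jnp.sin(X)
n, batch_size, n_iters = X.shape[0], 8, 100


def batch_loss(w, xb, yb):
    # A toy linear-model loss standing in for the GP objective.
    return jnp.mean((xb @ w - yb) ** 2)


key = jr.PRNGKey(42)
keys = jr.split(key, n_iters)  # one subkey per optimisation step
iter_nums = jnp.arange(n_iters)


def step(carry, iter_num_and_key):
    _, step_key = iter_num_and_key
    w = carry
    # Sample a mini-batch with replacement, mirroring get_batch's jr.choice call.
    idx = jr.choice(step_key, n, (batch_size,), replace=True)
    loss_val, grad = jax.value_and_grad(batch_loss)(w, X[idx], y[idx])
    w = w - 0.01 * grad  # plain SGD standing in for the Optax update
    return w, loss_val


w0 = jnp.zeros((1, 1))
w_final, history = jax.lax.scan(step, w0, (iter_nums, keys))
print(history.shape)  # (100,): the per-step loss history returned by the scan
```

Because every subkey is generated up front, the scan body stays pure and the whole loop remains jit-compatible, while the returned `history` gives the loss at every step.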

gpjax/utils.py

Lines changed: 0 additions & 11 deletions
@@ -79,14 +79,3 @@ def array_to_dict(parameter_array) -> tp.Dict:

          return jax.tree_util.tree_unflatten(flattened_pytree[1], parameter_array)

      return dict_to_array, array_to_dict
-
-
- def convert_seed(seed: tp.Union[int, PRNGKey]) -> PRNGKey:
-     """Ensure that seeds type."""
-
-     if isinstance(seed, int):
-         rng = jr.PRNGKey(seed)
-     else:  # key is of type PRNGKey
-         rng = seed
-
-     return rng
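With `convert_seed` removed (along with its test further below), callers are expected to construct a `jax.random.PRNGKey` themselves. If downstream code still wants to accept either an integer seed or a ready-made key, a small user-side shim along these lines could stand in for the removed helper; `as_prng_key` is a hypothetical name and not part of GPJax.

```python
import typing as tp

import jax.numpy as jnp
import jax.random as jr


def as_prng_key(seed: tp.Union[int, jnp.ndarray]) -> jnp.ndarray:
    """Hypothetical user-side replacement for the removed convert_seed:
    accept an integer seed or pass an existing PRNG key through unchanged."""
    return jr.PRNGKey(seed) if isinstance(seed, int) else seed


key = as_prng_key(42)                # equivalent to jr.PRNGKey(42)
same_key = as_prng_key(jr.PRNGKey(42))  # existing keys are returned as-is
```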

tests/test_abstractions.py

Lines changed: 30 additions & 6 deletions
@@ -5,7 +5,7 @@

  import gpjax as gpx
  from gpjax import RBF, Dataset, Gaussian, Prior, initialise, transform
- from gpjax.abstractions import InferenceState, fit, fit_batches
+ from gpjax.abstractions import InferenceState, fit, fit_batches, get_batch


  @pytest.mark.parametrize("n_iters", [10])

@@ -67,16 +67,40 @@ def test_batch_fitting(n_iters, nb, ndata):

      D = Dataset(X=x, y=y)

      optimiser = optax.adam(learning_rate=0.1)
-     seed = 42
-     print("-" * 80)
-     print(params)
-     print("-" * 80)
+     key = jr.PRNGKey(42)
      inference_state = fit_batches(
-         objective, params, trainable_status, D, optimiser, seed, nb, n_iters
+         objective, params, trainable_status, D, optimiser, key, nb, n_iters
      )
      optimised_params, history = inference_state.params, inference_state.history
      optimised_params = transform(optimised_params, constrainer)
      assert isinstance(inference_state, InferenceState)
      assert isinstance(optimised_params, dict)
      assert isinstance(history, jnp.ndarray)
      assert history.shape[0] == n_iters
+
+
+ @pytest.mark.parametrize("batch_size", [1, 2, 50])
+ @pytest.mark.parametrize("ndim", [1, 2, 3])
+ @pytest.mark.parametrize("ndata", [50])
+ @pytest.mark.parametrize("key", [jr.PRNGKey(123)])
+ def test_get_batch(ndata, ndim, batch_size, key):
+     x = jnp.sort(
+         jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(ndata, ndim)), axis=0
+     )
+     y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1
+     D = Dataset(X=x, y=y)
+
+     B = get_batch(D, batch_size, key)
+
+     assert B.n == batch_size
+     assert B.X.shape[1:] == x.shape[1:]
+     assert B.y.shape[1:] == y.shape[1:]
+
+     # test no caching of batches:
+     key, subkey = jr.split(key)
+     Bnew = get_batch(D, batch_size, subkey)
+     assert Bnew.n == batch_size
+     assert Bnew.X.shape[1:] == x.shape[1:]
+     assert Bnew.y.shape[1:] == y.shape[1:]
+     assert (Bnew.X != B.X).all()
+     assert (Bnew.y != B.y).all()

tests/test_utilities.py

Lines changed: 0 additions & 8 deletions
@@ -1,13 +1,10 @@

  import jax.numpy as jnp
- import jax.random as jr
  import pytest
- from chex import PRNGKey

  from gpjax.utils import (
      I,
      as_constant,
      concat_dictionaries,
-     convert_seed,
      dict_array_coercion,
      merge_dictionaries,
      sort_dictionary,

@@ -68,8 +65,3 @@ def test_array_coercion(d):

      assert array_to_dict(dict_to_array(params)) == params
      assert isinstance(dict_to_array(params), list)
      assert isinstance(array_to_dict(dict_to_array(params)), dict)
-
-
- @pytest.mark.parametrize("seed", [1, 2, jr.PRNGKey(42)])
- def convert_seed(seed):
-     assert isinstance(convert_seed(seed), PRNGKey)
