Commit 3bbc8cb

Merge pull request #176 from JaxGaussianProcesses/remove_chex

Remove chex

2 parents: 8d52d7b + bb63fea

34 files changed: +743 −369 lines

.circleci/config.yml

Lines changed: 6 additions & 14 deletions
@@ -1,7 +1,7 @@
 version: 2.1
 
 orbs:
-  python: circleci/[email protected]
+  # python: circleci/[email protected]
   codecov: codecov/[email protected]
 
 commands:
@@ -51,7 +51,6 @@ commands:
      - run:
          name: Upload to PyPI
          command: twine upload dist/* -r << parameters.pkgname >> --verbose
-
  install_pandoc:
    description: "Install pandoc"
    parameters:
@@ -83,15 +82,12 @@ jobs:
    resource_class: large
    steps:
      - checkout
-      - restore_cache:
-          keys:
-            - pip-cache
      - run:
-          name: Update pip
-          command: pip install --upgrade pip
-      - python/install-packages:
-          pkg-manager: pip-dist
-          path-args: .[dev]
+          name: Install dependencies
+          command: |
+            pip install --upgrade pip
+            pip install -r requirements/dev.txt
+            pip install -e .
      - run:
          name: Run tests
          command: |
@@ -103,10 +99,6 @@
            curl -Os https://uploader.codecov.io/v0.1.0_4653/linux/codecov
            chmod +x codecov
            ./codecov -t ${CODECOV_TOKEN}
-      - save_cache:
-          key: pip-cache
-          paths:
-            - ~/.cache/pip
      - store_test_results:
          path: test-results
      - store_artifacts:

README.md

Lines changed: 1 addition & 1 deletion
@@ -123,7 +123,7 @@ parameter_state = gpx.initialise(posterior, key=key)
 Finally, we run an optimisation loop using the Adam optimiser via the `fit` callable.
 
 ```python
-inference_state = gpx.fit(mll, parameter_state, opt, n_iters=500)
+inference_state = gpx.fit(mll, parameter_state, opt, num_iters=500)
 ```
 
 ## 3. Making predictions
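
For reference, the example notebooks touched by this commit call `fit` with keyword arguments and then unpack the returned inference state. The sketch below simply restates that pattern with the renamed `num_iters` argument; it assumes `negative_mll`, `parameter_state`, and `optimiser` have already been built in earlier notebook cells, so it is illustrative rather than a standalone script.

```python
import gpjax as gpx

# Illustrative only: negative_mll, parameter_state and optimiser are assumed to
# have been constructed in earlier cells, exactly as in the notebook diffs below.
inference_state = gpx.fit(
    objective=negative_mll,
    parameter_state=parameter_state,
    optax_optim=optimiser,
    num_iters=500,
)

# Unpack the optimised parameters and the recorded loss history.
learned_params, training_history = inference_state.unpack()
```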

docs/README.md

Lines changed: 1 addition & 7 deletions
@@ -41,7 +41,6 @@ description and a code example. The docstring is concluded with a description
 of the objects attributes with corresponding types.
 
 ```python
-@dataclass
 class Prior(AbstractPrior):
     """A Gaussian process prior object. The GP is parameterised by a
     `mean <https://gpjax.readthedocs.io/en/latest/api.html#module-gpjax.mean_functions>`_
@@ -78,9 +77,4 @@
 ### Documentation syntax
 
 A helpful cheatsheet for writing restructured text can be found
-[here](https://github.com/ralsina/rst-cheatsheet/blob/master/rst-cheatsheet.rst). In addition to that, we adopt the following convention when documenting
-`dataclass` objects.
-
-* Class attributes should be specified using the `Attributes:` tag.
-* Method argument should be specified using the `Args:` tags.
-* All attributes and arguments should have types.
+[here](https://github.com/ralsina/rst-cheatsheet/blob/master/rst-cheatsheet.rst).

examples/README.md

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
+# Where to find the docs
+
+The GPJax documentation can be found here:
+https://gpjax.readthedocs.io/en/latest/
+
+# How to build the docs
+
+1. Install the requirements using `pip install -r docs/requirements.txt`
+2. Make sure `pandoc` is installed
+3. Run `make html`
+
+The corresponding HTML files can then be found in `docs/_build/html/`.
+
+# How to write code documentation
+
+Our documentation is written in ReStructuredText for Sphinx. This is a
+meta-language that is compiled into online documentation. For more details see
+[Sphinx's documentation](https://www.sphinx-doc.org/en/master/usage/restructuredtext/index.html).
+As a result, our docstrings adhere to a specific syntax that has to be kept in
+mind. Below we provide some guidelines.
+
+## How much information to put in a docstring
+
+A docstring should be informative. If in doubt, it is better to add more
+information to a docstring than less. Many users will skim documentation, so
+please ensure the opening sentence or two of a docstring contains the core
+information. Adding examples and mathematical descriptions to documentation is
+highly desirable.
+
+We are making an active effort within GPJax to improve our documentation. If you
+spot any areas where there is missing information within the existing
+documentation, then please either raise an issue or
+[create a pull request](https://gpjax.readthedocs.io/en/latest/contributing.html).
+
+## An example docstring
+
+An example docstring that adheres to the principles of GPJax is given below.
+The docstring contains a simple, snappy introduction with links to auxiliary
+components. More detail is then provided in the form of a mathematical
+description and a code example. The docstring is concluded with a description
+of the object's attributes with corresponding types.
+
+```python
+class Prior(AbstractPrior):
+    """A Gaussian process prior object. The GP is parameterised by a
+    `mean <https://gpjax.readthedocs.io/en/latest/api.html#module-gpjax.mean_functions>`_
+    and `kernel <https://gpjax.readthedocs.io/en/latest/api.html#module-gpjax.kernels>`_ function.
+
+    A Gaussian process prior parameterised by a mean function :math:`m(\\cdot)` and a kernel
+    function :math:`k(\\cdot, \\cdot)` is given by
+
+    .. math::
+
+        p(f(\\cdot)) = \\mathcal{GP}(m(\\cdot), k(\\cdot, \\cdot)).
+
+    To invoke a ``Prior`` distribution, only a kernel function is required. By default,
+    the mean function will be set to zero. In general, this assumption will be reasonable
+    assuming the data being modelled has been centred.
+
+    Example:
+        >>> import gpjax as gpx
+        >>>
+        >>> kernel = gpx.kernels.RBF()
+        >>> prior = gpx.Prior(kernel = kernel)
+
+    Attributes:
+        kernel (Kernel): The kernel function used to parameterise the prior.
+        mean_function (MeanFunction): The mean function used to parameterise the prior. Defaults to zero.
+        name (str): The name of the GP prior. Defaults to "GP prior".
+    """
+
+    kernel: Kernel
+    mean_function: Optional[AbstractMeanFunction] = Zero()
+    name: Optional[str] = "GP prior"
+```
+
+### Documentation syntax
+
+A helpful cheatsheet for writing restructured text can be found
+[here](https://github.com/ralsina/rst-cheatsheet/blob/master/rst-cheatsheet.rst). In addition to that, we adopt the following convention when documenting
+objects.
+
+* Class attributes should be specified using the `Attributes:` tag.
+* Method arguments should be specified using the `Args:` tag.
+* All attributes and arguments should have types.
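
To make the `Attributes:` / `Args:` convention above concrete, here is a small, self-contained sketch. The `ConstantMean` class is hypothetical (it is not part of GPJax) and exists only to show the docstring layout the guide describes.

```python
from typing import Optional


class ConstantMean:
    """A constant mean function, used purely to illustrate the docstring convention.

    Example:
        >>> mean = ConstantMean(value=1.0)
        >>> mean(2.5)
        1.0

    Attributes:
        value (float): The constant value returned for every input. Defaults to 0.0.
        name (Optional[str]): A human-readable name. Defaults to "Constant mean".
    """

    def __init__(self, value: float = 0.0, name: Optional[str] = "Constant mean") -> None:
        self.value = value
        self.name = name

    def __call__(self, x: float) -> float:
        """Evaluate the mean function at a single input.

        Args:
            x (float): The input location (unused, since the mean is constant).

        Returns:
            float: The constant mean value.
        """
        return self.value
```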

examples/barycentres.pct.py

Lines changed: 1 addition & 1 deletion
@@ -115,7 +115,7 @@ def fit_gp(x: jax.Array, y: jax.Array) -> dx.MultivariateNormalTri:
     objective=negative_mll,
     parameter_state=parameter_state,
     optax_optim=optimiser,
-    n_iters=1000,
+    num_iters=1000,
 )
 
 learned_params, training_history = inference_state.unpack()

examples/classification.pct.py

Lines changed: 1 addition & 1 deletion
@@ -91,7 +91,7 @@
     objective=negative_mll,
     parameter_state=parameter_state,
     optax_optim=optimiser,
-    n_iters=1000,
+    num_iters=1000,
 )
 
 map_estimate, training_history = inference_state.unpack()

examples/collapsed_vi.pct.py

Lines changed: 1 addition & 1 deletion
@@ -109,7 +109,7 @@
     objective=negative_elbo,
     parameter_state=parameter_state,
     optax_optim=optimiser,
-    n_iters=2000,
+    num_iters=2000,
 )
 
 learned_params, training_history = inference_state.unpack()

examples/graph_kernels.pct.py

Lines changed: 1 addition & 1 deletion
@@ -137,7 +137,7 @@
     objective=negative_mll,
     parameter_state=parameter_state,
     optax_optim=optimiser,
-    n_iters=1000,
+    num_iters=1000,
 )
 
 learned_params, training_history = inference_state.unpack()

examples/haiku.pct.py

Lines changed: 1 addition & 1 deletion
@@ -185,7 +185,7 @@ def forward(x):
     objective=negative_mll,
     parameter_state=parameter_state,
     optax_optim=optimiser,
-    n_iters=2500,
+    num_iters=2500,
 )
 
 learned_params, training_history = inference_state.unpack()

examples/kernels.pct.py

Lines changed: 2 additions & 23 deletions
@@ -228,28 +228,7 @@ def _initialise_params(self, key: jr.PRNGKey) -> dict:
 # domain is a circle, this is $2\pi$. Next we define the kernel's `__call__`
 # function which is a direct implementation of Equation (1). Finally, we define
 # the Kernel's parameter property which contains just one value $\tau$ that we
-# initialise to 4 in the kernel's `__post_init__`.
-#
-# #### Aside on dataclasses
-#
-# One can see in the above definition of a `Polar` kernel that we decorated the
-# class with a `@dataclass` command. Dataclasses are simply regular classs
-# objects in Python, however, much of the boilerplate code has been removed. For
-# example, without a `@dataclass` decorator, the instantiation of the above
-# `Polar` kernel would be done through
-# ```python
-# class Polar(jk.kernels.AbstractKernel):
-#     def __init__(self, period: float = 2*jnp.pi):
-#         super().__init__()
-#         self.period = period
-# ```
-# As objects become increasingly large and complex, the conciseness of a
-# dataclass becomes increasingly attractive. To ensure full compatability with
-# Jax, it is crucial that the dataclass decorator is imported from Chex, not
-# base Python's `dataclass` module. Functionally, the two objects are identical.
-# However, unlike regular Python dataclasses, it is possilbe to apply operations
-# such as `jit`, `vmap` and `grad` to the dataclasses given by Chex as they are
-# registrered PyTrees.
+# initialise to 4 in the kernel's `__init__`.
 #
 #
 # ### Custom Parameter Bijection
@@ -312,7 +291,7 @@ def _initialise_params(self, key: jr.PRNGKey) -> dict:
     objective=negative_mll,
     parameter_state=parameter_state,
     optax_optim=optimiser,
-    n_iters=1000,
+    num_iters=1000,
 )
 
 learned_params, training_history = inference_state.unpack()
