Skip to content

Commit 9ba68a4

Browse files
frazanedaniel-doddthomaspinderThomas-Christie
authored
Flax/nnx backend (#440)
* add flax v0.8.0 to deps, temporarily from github main branch * main gps objects as nnx modules * integrators as nnx dataclasses and some static typing refactoring * likelihoods as nnx dataclasses modules and some static typing refactoring * small refactoring * mean functions as nnx dataclasses modules and some refactoring * bugfix * objectives as nnx dataclasses modules * variational families with nnx * kernels base with nnx * wip stationary kernels * wip nonstationary kernels * wip non euclidean kernels * computations with nnx * rff with nnx * bugfix * stationary kernels as normal classes * nonstationary kernels as normal classes * noneuclidean kernels as normal classes * rff as standard class + stationary kernel abstract class for static typing * started work on parameters * more objects as normal classes * gps as normal classes * integrators as normal classes * dataset is not a pytree * removed superfluous inits * register dataset as pytree * use parameters here and there * set active_dims default to 1 * start working on tests * active_dims defaults to None * rewrite objectives as functions Co-authored-by: Daniel Dodd <[email protected]> * black + isort * remove objective from cite * fix dataset repr * pass tests for variational families * active_dims defaults to None * use generic Objective type * small fixes * make 'active_dims' required parameter, fix static typing and beartype for parameters, rewrite and pass tests for stationary kernels * pass tests/test_kernels/test_computation.py * rewrite tests for nonstationary kernels + pass tests * adapt to nnx's explicit variables + miscellaneous fixes * rewrite of objectives as simple functions, [WIP] started rewriting tests * rewrite and pass tests for objectives * rewrite fit function * remove gpjax.base module * remove base module tests * rewrite and pass tests for fit * finish kernels and pass all tests * pass all tests except decision making * pass all tests 🚀 * update and run classification notebook (python cells) * pass doctests * pass integration tests, more checks to parameters * linting and formatting * update barycentres and classification examples * update project files * update ruff and make it happy * lint + format all doc examples * [skip ci] change how dimensions are specified for kernels, update kernel tests * [skip ci] api reference looks pretty now, implemented template pattern, improved docstrings * [skip ci] wip - fixing math rendering in documentation - almost there * Update notebooks. (#447) * Update yacht.py * Update likelihoods_guide.py * Revert "Update likelihoods_guide.py" This reverts commit 5f51cfe. * Update oceanmodelling.py * Update likelihoods.py (#446) * Update likelihoods.py * Update likelihoods.py * Update likelihoods.py * Adding tagged parameters and updated notebooks * Update likelihoods.py (#446) * Update likelihoods.py * Update likelihoods.py * Update likelihoods.py * Update notebooks * Fix linting * Fix missing dep. * Fix integration test * Readd docs deps * Fix docstrings * Update lockfile * Update parameter refs * Fix broken tests * Remove PyTrees doc * Failing split order * NNX update * add flax v0.8.0 to deps, temporarily from github main branch * main gps objects as nnx modules * integrators as nnx dataclasses and some static typing refactoring * likelihoods as nnx dataclasses modules and some static typing refactoring * small refactoring * mean functions as nnx dataclasses modules and some refactoring * bugfix * objectives as nnx dataclasses modules * variational families with nnx * kernels base with nnx * wip stationary kernels * wip nonstationary kernels * wip non euclidean kernels * computations with nnx * rff with nnx * bugfix * stationary kernels as normal classes * nonstationary kernels as normal classes * noneuclidean kernels as normal classes * rff as standard class + stationary kernel abstract class for static typing * started work on parameters * more objects as normal classes * gps as normal classes * integrators as normal classes * dataset is not a pytree * removed superfluous inits * register dataset as pytree * use parameters here and there * set active_dims default to 1 * start working on tests * active_dims defaults to None * rewrite objectives as functions Co-authored-by: Daniel Dodd <[email protected]> * black + isort * remove objective from cite * fix dataset repr * pass tests for variational families * active_dims defaults to None * use generic Objective type * small fixes * make 'active_dims' required parameter, fix static typing and beartype for parameters, rewrite and pass tests for stationary kernels * pass tests/test_kernels/test_computation.py * rewrite tests for nonstationary kernels + pass tests * adapt to nnx's explicit variables + miscellaneous fixes * rewrite of objectives as simple functions, [WIP] started rewriting tests * rewrite and pass tests for objectives * rewrite fit function * remove gpjax.base module * remove base module tests * rewrite and pass tests for fit * finish kernels and pass all tests * pass all tests except decision making * pass all tests 🚀 * update and run classification notebook (python cells) * pass doctests * pass integration tests, more checks to parameters * linting and formatting * update barycentres and classification examples * update project files * update ruff and make it happy * lint + format all doc examples * [skip ci] change how dimensions are specified for kernels, update kernel tests * [skip ci] api reference looks pretty now, implemented template pattern, improved docstrings * [skip ci] wip - fixing math rendering in documentation - almost there * Update notebooks. (#447) * Update yacht.py * Update likelihoods_guide.py * Revert "Update likelihoods_guide.py" This reverts commit 5f51cfe. * Update oceanmodelling.py * Update likelihoods.py (#446) * Update likelihoods.py * Update likelihoods.py * Update likelihoods.py * Update notebooks * Adding tagged parameters and updated notebooks * Fix linting * Fix missing dep. * Fix integration test * Readd docs deps * Fix docstrings * Update lockfile * Update parameter refs * Fix broken tests * Remove PyTrees doc * Failing split order * NNX update * rename static dir * move examples dir in top level * add _examples generated dir to gitignore * update pyproject deps * update mkdocs config * add examples generation script * adapt relative paths in md files * Update Ruff and incorporate changes * update github workflow for building doc, without executing notebookf for now * Add backend doc * Add backend doc * Add backend doc * Add replace to transform * Merge with main * Update parameters docstring * Respond to comments * Fix e2e tests * Fix mplstyle refs * bump deps * Update poetry * Update poetry * Fix shutil * Drop flax base * add scikit-learn dependency for docs * bugfix: change directory before running jupytext * use local mpl style file * do not use MCMC for classification (it is *very* slow) * [skip-ci] update github workflows for docs * Fix split * Fix split * Fix split * Fix xdoctest * Fix doc * Add serial build * Update parameters transform and backend doc * Update parameters transform and backend doc * Bump Python --------- Signed-off-by: Thomas Pinder <[email protected]> Co-authored-by: Daniel Dodd <[email protected]> Co-authored-by: Daniel Dodd <[email protected]> Co-authored-by: Thomas Pinder <[email protected]> Co-authored-by: Thomas-Christie <[email protected]>
1 parent 7ae0adf commit 9ba68a4

File tree

139 files changed

+6152
-9642
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

139 files changed

+6152
-9642
lines changed

.github/workflows/build_docs.yml

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,22 +47,16 @@ jobs:
4747
- name: Install and configure Poetry
4848
uses: snok/install-poetry@v1
4949
with:
50-
version: 1.2.2
50+
version: 1.5.1
5151
virtualenvs-create: false
5252
virtualenvs-in-project: false
5353
installer-parallel: true
5454

55-
- name: Install LaTex
56-
run: |
57-
sudo apt-get update
58-
sudo apt-get install texlive-fonts-recommended texlive-fonts-extra texlive-latex-extra dvipng cm-super
59-
6055
- name: Build the documentation with MKDocs
6156
run: |
62-
cp docs/examples/gpjax.mplstyle .
6357
poetry install --all-extras --with docs
6458
conda install pandoc
65-
poetry run mkdocs build
59+
poetry run python docs/scripts/gen_examples.py --execute && poetry run mkdocs build
6660
6761
- name: Deploy Page 🚀
6862
uses: JamesIves/[email protected]

.github/workflows/integration.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ jobs:
2929
- name: Install Poetry
3030
uses: snok/[email protected]
3131
with:
32-
version: 1.4.0
32+
version: 1.5.1
3333

3434
# Configure Poetry to use the virtual environment in the project
3535
- name: Setup Poetry
@@ -39,7 +39,7 @@ jobs:
3939
# Install the dependencies
4040
- name: Install Package
4141
run: |
42-
poetry install --all-extras --with docs
42+
poetry install --with docs
4343
4444
# Run the unit tests and build the coverage report
4545
- name: Run Integration Tests

.github/workflows/test_docs.yml

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -33,32 +33,17 @@ jobs:
3333
auto-update-conda: true
3434
python-version: ${{ matrix.python-version }}
3535

36-
# Install katex for math support
37-
- name: Install NPM
38-
uses: actions/setup-node@v3
39-
with:
40-
node-version: 16
41-
- name: Install KaTeX
42-
run: |
43-
npm install katex
44-
45-
- name: Install LaTex
46-
run: |
47-
sudo apt-get update
48-
sudo apt-get install texlive-fonts-recommended texlive-fonts-extra texlive-latex-extra dvipng cm-super
49-
5036
# Install Poetry and build the documentation
5137
- name: Install and configure Poetry
5238
uses: snok/install-poetry@v1
5339
with:
54-
version: 1.2.2
40+
version: 1.5.1
5541
virtualenvs-create: false
5642
virtualenvs-in-project: false
5743
installer-parallel: true
5844

5945
- name: Build the documentation with MKDocs
6046
run: |
61-
cp docs/examples/gpjax.mplstyle .
6247
poetry install --all-extras --with docs
6348
conda install pandoc
64-
poetry run mkdocs build
49+
poetry run python docs/scripts/gen_examples.py --execute && poetry run mkdocs build

.github/workflows/tests.yml

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,13 @@ jobs:
2626
python-version: ${{ matrix.python-version }}
2727

2828
# Install Poetry
29-
- name: Install Poetry
30-
uses: snok/install-poetry@v1.3.3
29+
- name: Install and configure Poetry
30+
uses: snok/install-poetry@v1
3131
with:
32-
version: 1.4.0
32+
version: 1.5.1
33+
virtualenvs-create: false
34+
virtualenvs-in-project: false
35+
installer-parallel: true
3336

3437
# Configure Poetry to use the virtual environment in the project
3538
- name: Setup Poetry
@@ -39,7 +42,7 @@ jobs:
3942
# Install the dependencies
4043
- name: Install Package
4144
run: |
42-
poetry install --with tests
45+
poetry install --with dev
4346
4447
- name: Check docstrings
4548
run: |

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,4 @@ package-lock.json
152152
node_modules/
153153

154154
docs/api
155+
docs/_examples

.pre-commit-config.yaml

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,14 @@ repos:
4646
language: system
4747
types: [python]
4848
exclude: examples/
49-
- repo: https://github.com/econchick/interrogate
50-
rev: 1.5.0
51-
hooks:
52-
- id: interrogate
53-
args:
54-
[
55-
"gpjax",
56-
"--config",
57-
"pyproject.toml",
58-
]
59-
pass_filenames: false
49+
# - repo: https://github.com/econchick/interrogate
50+
# rev: 1.5.0
51+
# hooks:
52+
# - id: interrogate
53+
# args:
54+
# [
55+
# "gpjax",
56+
# "--config",
57+
# "pyproject.toml",
58+
# ]
59+
# pass_filenames: false

README.md

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,9 @@ helped to shape GPJax into the package it is today.
7272
## Notebook examples
7373

7474
> - [**Conjugate Inference**](https://docs.jaxgaussianprocesses.com/examples/regression/)
75-
> - [**Classification with MCMC**](https://docs.jaxgaussianprocesses.com/examples/classification/)
75+
> - [**Classification**](https://docs.jaxgaussianprocesses.com/examples/classification/)
7676
> - [**Sparse Variational Inference**](https://docs.jaxgaussianprocesses.com/examples/collapsed_vi/)
7777
> - [**Stochastic Variational Inference**](https://docs.jaxgaussianprocesses.com/examples/uncollapsed_vi/)
78-
> - [**BlackJax Integration**](https://docs.jaxgaussianprocesses.com/examples/classification/#mcmc-inference)
7978
> - [**Laplace Approximation**](https://docs.jaxgaussianprocesses.com/examples/classification/#laplace-approximation)
8079
> - [**Inference on Non-Euclidean Spaces**](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/#custom-kernel)
8180
> - [**Inference on Graphs**](https://docs.jaxgaussianprocesses.com/examples/graph_kernels/)
@@ -146,13 +145,10 @@ posterior = prior * likelihood
146145
# Define an optimiser
147146
optimiser = ox.adam(learning_rate=1e-2)
148147

149-
# Define the marginal log-likelihood
150-
negative_mll = jit(gpx.objectives.ConjugateMLL(negative=True))
151-
152148
# Obtain Type 2 MLEs of the hyperparameters
153149
opt_posterior, history = gpx.fit(
154150
model=posterior,
155-
objective=negative_mll,
151+
objective=gpx.objectives.conjugate_mll,
156152
train_data=D,
157153
optim=optimiser,
158154
num_iters=500,

benchmarks/__init__.py

Whitespace-only changes.

benchmarks/asv.conf.json

Lines changed: 0 additions & 25 deletions
This file was deleted.

benchmarks/kernels.py

Lines changed: 0 additions & 99 deletions
This file was deleted.

0 commit comments

Comments
 (0)