Skip to content

Commit 35dabd7

Browse files
authored
Thomaspinder/replace cola (#529)
* Deprecate cola * Fix flaky test * Upper bound flax * Upper bound flax * Test doctest * Fix flaky test * Migrate to UV and fix parameter tags
1 parent 3885cd8 commit 35dabd7

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+2164
-576
lines changed

.github/pull_request_template.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
## Checklist
22

3-
- [ ] I've formatted the new code by running `hatch run dev:format` before committing.
3+
- [ ] I've formatted the new code by running `uv run poe format` before committing.
44
- [ ] I've added tests for new code.
55
- [ ] I've added docstrings for the new code.
66

.github/workflows/build_docs.yml

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ jobs:
2222
strategy:
2323
matrix:
2424
os: ["ubuntu-latest"]
25-
python-version: ["3.10"]
25+
python-version: ["3.11"]
2626

2727
steps:
2828
# Grap the latest commit from the branch
@@ -47,14 +47,18 @@ jobs:
4747
run: |
4848
npm install katex
4949
50-
# Install Hatch
51-
- name: Install Hatch
52-
uses: pypa/hatch@install
50+
# Install uv
51+
- name: Install uv
52+
uses: astral-sh/setup-uv@v3
53+
with:
54+
version: "latest"
5355

5456
- name: Build the documentation with MKDocs
5557
run: |
5658
conda install pandoc
57-
hatch run docs:build
59+
uv sync --extra docs
60+
uv run python docs/scripts/gen_examples.py --execute
61+
uv run mkdocs build
5862
5963
- name: Deploy Page 🚀
6064
uses: JamesIves/[email protected]

.github/workflows/integration.yml

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
matrix:
1414
# Select the Python versions to test against
1515
os: ["ubuntu-latest", "macos-latest"]
16-
python-version: ["3.10", "3.11"]
16+
python-version: ["3.10", "3.11", "3.12", "3.13"]
1717
fail-fast: true
1818
steps:
1919
- name: Check out the code
@@ -25,10 +25,14 @@ jobs:
2525
with:
2626
python-version: ${{ matrix.python-version }}
2727

28-
# Install Hatch
29-
- name: Install Hatch
30-
uses: pypa/hatch@install
28+
# Install uv
29+
- name: Install uv
30+
uses: astral-sh/setup-uv@v3
31+
with:
32+
version: "latest"
3133

3234
# Run the unit tests and build the coverage report
3335
- name: Run Integration Tests
34-
run: hatch run docs:integration
36+
run: |
37+
uv sync --extra docs
38+
uv run python tests/integration_tests.py

.github/workflows/stale_prs.yml

Lines changed: 0 additions & 45 deletions
This file was deleted.

.github/workflows/test_docs.yml

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
strategy:
1818
matrix:
1919
os: ["ubuntu-latest"]
20-
python-version: ["3.10"]
20+
python-version: ["3.11"]
2121

2222
steps:
2323
# Grap the latest commit from the branch
@@ -33,11 +33,15 @@ jobs:
3333
auto-update-conda: true
3434
python-version: ${{ matrix.python-version }}
3535

36-
# Install Hatch
37-
- name: Install Hatch
38-
uses: pypa/hatch@install
36+
# Install uv
37+
- name: Install uv
38+
uses: astral-sh/setup-uv@v3
39+
with:
40+
version: "latest"
3941

4042
- name: Build the documentation with MKDocs
4143
run: |
4244
conda install pandoc
43-
hatch run docs:build
45+
uv sync --extra docs
46+
uv run python docs/scripts/gen_examples.py --execute
47+
uv run mkdocs build

.github/workflows/tests.yml

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,12 @@ on:
88
jobs:
99
unit-tests:
1010
name: Run Tests
11-
runs-on: ubuntu-latest
1211
strategy:
1312
matrix:
14-
# Select the Python versions to test against
1513
os: ["ubuntu-latest", "macos-latest"]
16-
python-version: ["3.10", "3.11"]
14+
python-version: ["3.10", "3.11", "3.12", "3.13"]
1715
fail-fast: true
16+
runs-on: ${{ matrix.os }}
1817
steps:
1918
- name: Check out the code
2019
uses: actions/[email protected]
@@ -25,18 +24,22 @@ jobs:
2524
with:
2625
python-version: ${{ matrix.python-version }}
2726

28-
# Install Hatch
29-
- name: Install Hatch
30-
uses: pypa/hatch@install
27+
# Install uv
28+
- name: Install uv
29+
uses: astral-sh/setup-uv@v3
30+
with:
31+
version: "latest"
3132

3233
# Install the dependencies
3334
- name: Check docstrings
3435
run: |
35-
hatch run dev:docstrings
36+
uv sync --extra dev
37+
uv run xdoctest ./gpjax
3638
3739
# Run the unit tests and build the coverage report
3840
- name: Run Tests
39-
run: hatch run dev:coverage
41+
run: uv run pytest . -v --cov=./gpjax --cov-report=xml:./coverage.xml
42+
4043

4144
- name: Upload code coverage
4245
uses: codecov/codecov-action@v3

docs/contributing.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,16 +72,16 @@ you through every detail!
7272
Always use a `feature` branch. It's good practice to avoid
7373
work on the ``main`` branch of any repository.
7474

75-
4. We use [Hatch](https://hatch.pypa.io/latest/) for packaging and dependency management. Project requirements are in ``pyproject.toml``. To install GPJax into a Hatch virtual environment, run:
75+
4. We use [uv](https://docs.astral.sh/uv/) for packaging and dependency management. Project requirements are in ``pyproject.toml``. To install GPJax with uv, run:
7676

7777
```bash
78-
$ hatch env create
78+
$ uv sync --extra dev
7979
```
8080

8181
At this point we recommend you check your installation passes the supplied unit tests:
8282

8383
```bash
84-
$ hatch run dev:all-tests
84+
$ uv run poe all-tests
8585
```
8686

8787
5. Add changed files using `git add` and then `git commit` files to record your
@@ -142,7 +142,7 @@ request, we recommend you check the following:
142142
accepted. Test coverage can be checked with:
143143

144144
```bash
145-
$ hatch run dev:coverage
145+
$ uv run poe coverage
146146
```
147147

148148
Navigate to the newly created folder `htmlcov` and open `index.html` to view

docs/installation.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ hardware acceleration support as detailed in the
3131
```bash
3232
git clone https://github.com/thomaspinder/GPJax.git
3333
cd GPJax
34-
hatch shell create
34+
uv sync --extra dev
3535
```
3636

3737
!!! tip
@@ -45,5 +45,5 @@ hardware acceleration support as detailed in the
4545
and recommend you check your installation passes the supplied unit tests:
4646

4747
```bash
48-
hatch run dev:all-tests
48+
uv run poe all-tests
4949
```

examples/barycentres.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
# extension: .py
99
# format_name: percent
1010
# format_version: '1.3'
11-
# jupytext_version: 1.16.7
11+
# jupytext_version: 1.17.3
1212
# kernelspec:
1313
# display_name: gpjax
1414
# language: python

examples/classification.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
# extension: .py
99
# format_name: percent
1010
# format_version: '1.3'
11-
# jupytext_version: 1.16.7
11+
# jupytext_version: 1.11.2
1212
# kernelspec:
13-
# display_name: gpjax
13+
# display_name: .venv
1414
# language: python
1515
# name: python3
1616
# ---
@@ -22,7 +22,6 @@
2222
# with non-Gaussian likelihoods via maximum a posteriori (MAP). We focus on a classification task here.
2323

2424
# %%
25-
import cola
2625
from flax import nnx
2726
import jax
2827

@@ -41,7 +40,7 @@
4140
import optax as ox
4241

4342
from examples.utils import use_mpl_style
44-
from gpjax.lower_cholesky import lower_cholesky
43+
from gpjax.linalg import lower_cholesky, PSD, solve
4544

4645
config.update("jax_enable_x64", True)
4746

@@ -219,7 +218,7 @@
219218
# Compute (latent) function value map estimates at training points:
220219
Kxx = opt_posterior.prior.kernel.gram(x)
221220
Kxx += identity_matrix(D.n) * jitter
222-
Kxx = cola.PSD(Kxx)
221+
Kxx = PSD(Kxx)
223222
Lx = lower_cholesky(Kxx)
224223
f_hat = Lx @ opt_posterior.latent.value
225224

@@ -267,10 +266,10 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> npd.MultivariateNorma
267266
Kxt = opt_posterior.prior.kernel.cross_covariance(x, test_inputs)
268267
Kxx = opt_posterior.prior.kernel.gram(x)
269268
Kxx += identity_matrix(D.n) * jitter
270-
Kxx = cola.PSD(Kxx)
269+
Kxx = PSD(Kxx)
271270

272271
# Kxx⁻¹ Kxt
273-
Kxx_inv_Kxt = cola.solve(Kxx, Kxt)
272+
Kxx_inv_Kxt = solve(Kxx, Kxt)
274273

275274
# Ktx Kxx⁻¹[ H⁻¹ ] Kxx⁻¹ Kxt
276275
laplace_cov_term = jnp.matmul(jnp.matmul(Kxx_inv_Kxt.T, H_inv), Kxx_inv_Kxt)

0 commit comments

Comments
 (0)