Skip to content

Commit 0e62e26

Browse files
authored
Support latest Flax (#554)
* Sort inputs * Bump flax pin and minimum python version * Bump uv.lock * Explicitly annotate flax parameters * Adjust test (list -> nnx.List) * Bump version
1 parent e6283ed commit 0e62e26

File tree

11 files changed

+326
-764
lines changed

11 files changed

+326
-764
lines changed

.github/workflows/integration.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
matrix:
1414
# Select the Python versions to test against
1515
os: ["ubuntu-latest", "macos-latest"]
16-
python-version: ["3.10", "3.11", "3.12", "3.13"]
16+
python-version: ["3.11", "3.12", "3.13"]
1717
fail-fast: true
1818
steps:
1919
- name: Check out the code

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ jobs:
1111
strategy:
1212
matrix:
1313
os: ["ubuntu-latest", "macos-latest"]
14-
python-version: ["3.10", "3.11", "3.12", "3.13"]
14+
python-version: ["3.11", "3.12", "3.13"]
1515
fail-fast: true
1616
runs-on: ${{ matrix.os }}
1717
steps:

gpjax/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
__description__ = "Gaussian processes in JAX and Flax"
4141
__url__ = "https://github.com/JaxGaussianProcesses/GPJax"
4242
__contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
43-
__version__ = "0.12.2"
43+
__version__ = "0.13.0"
4444

4545
__all__ = [
4646
"gps",

gpjax/gps.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
from abc import abstractmethod
1717

1818
import beartype.typing as tp
19+
from flax import nnx
1920
import jax.numpy as jnp
2021
import jax.random as jr
21-
from flax import nnx
2222
from jaxtyping import (
2323
Float,
2424
Num,

gpjax/kernels/approximations/rff.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Compute Random Fourier Feature (RFF) kernel approximations."""
22

33
import beartype.typing as tp
4+
from flax import nnx
45
import jax.random as jr
56
from jaxtyping import Float
67

@@ -54,7 +55,7 @@ def __init__(
5455
self._check_valid_base_kernel(base_kernel)
5556
self.base_kernel = base_kernel
5657
self.num_basis_fns = num_basis_fns
57-
self.frequencies = frequencies
58+
self.frequencies = nnx.data(frequencies)
5859
self.compute_engine = compute_engine
5960

6061
if self.frequencies is None:

gpjax/kernels/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def __init__(
253253
compute_engine: AbstractKernelComputation = DenseKernelComputation(),
254254
):
255255
# Add kernels to a list, flattening out instances of this class therein, as in GPFlow kernels.
256-
kernels_list: list[AbstractKernel] = []
256+
kernels_list: list[AbstractKernel] = nnx.List([])
257257
for kernel in kernels:
258258
if not isinstance(kernel, AbstractKernel):
259259
raise TypeError("can only combine Kernel instances") # pragma: no cover

gpjax/kernels/stationary/base.py

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def _check_lengthscale_dims_compat(
127127
"""
128128

129129
if isinstance(lengthscale, nnx.Variable):
130-
return _check_lengthscale_dims_compat_old(lengthscale.value, n_dims)
130+
return _check_lengthscale_dims_compat(lengthscale.value, n_dims)
131131

132132
lengthscale = jnp.asarray(lengthscale)
133133
ls_shape = jnp.shape(lengthscale)
@@ -146,35 +146,6 @@ def _check_lengthscale_dims_compat(
146146
return n_dims
147147

148148

149-
def _check_lengthscale_dims_compat_old(
150-
lengthscale: tp.Union[LengthscaleCompatible, nnx.Variable[Lengthscale]],
151-
n_dims: tp.Union[int, None],
152-
):
153-
r"""Check that the lengthscale is compatible with n_dims.
154-
155-
If possible, infer the number of input dimensions from the lengthscale.
156-
"""
157-
158-
if isinstance(lengthscale, nnx.Variable):
159-
return _check_lengthscale_dims_compat_old(lengthscale.value, n_dims)
160-
161-
lengthscale = jnp.asarray(lengthscale)
162-
ls_shape = jnp.shape(lengthscale)
163-
164-
if ls_shape == ():
165-
return lengthscale, n_dims
166-
elif ls_shape != () and n_dims is None:
167-
return lengthscale, ls_shape[0]
168-
elif ls_shape != () and n_dims is not None:
169-
if ls_shape != (n_dims,):
170-
raise ValueError(
171-
"Expected `lengthscale` to be compatible with the number "
172-
f"of input dimensions. Got `lengthscale` with shape {ls_shape}, "
173-
f"but the number of input dimensions is {n_dims}."
174-
)
175-
return lengthscale, n_dims
176-
177-
178149
def _check_lengthscale(lengthscale: tp.Any):
179150
"""Check that the lengthscale is a valid value."""
180151

gpjax/mean_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def __init__(
176176
super().__init__(**kwargs)
177177

178178
# Add means to a list, flattening out instances of this class therein, as in GPFlow kernels.
179-
items_list: list[AbstractMeanFunction] = []
179+
items_list: list[AbstractMeanFunction] = nnx.List([])
180180

181181
for item in means:
182182
if not isinstance(item, AbstractMeanFunction):

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ name = "gpjax"
77
dynamic = ["version"]
88
description = 'Gaussian processes in JAX.'
99
readme = "README.md"
10-
requires-python = ">=3.10"
10+
requires-python = ">=3.11"
1111
license = { text = "MIT" }
1212
keywords = ["gaussian-processes jax machine-learning bayesian"]
1313
authors = [{ name = "Thomas Pinder", email = "[email protected]" }]
@@ -29,7 +29,7 @@ dependencies = [
2929
"jaxtyping>0.2.10",
3030
"tqdm>4.66.2",
3131
"beartype>0.16.1",
32-
"flax>=0.10.0",
32+
"flax>=0.12.0",
3333
"numpy>=2.0.0",
3434
]
3535

tests/test_kernels/test_base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515

16+
from flax import nnx
1617
from jax import config
1718
import jax.numpy as jnp
1819
from jaxtyping import (
@@ -113,11 +114,11 @@ def test_combination_kernel(
113114
combination_kernel = combination_type(kernels=kernels)
114115

115116
# Check params are a list of dictionaries
116-
assert combination_kernel.kernels == kernels
117+
assert combination_kernel.kernels == nnx.List(kernels)
117118

118119
# Check combination kernel set
119120
assert len(combination_kernel.kernels) == n_kerns
120-
assert isinstance(combination_kernel.kernels, list)
121+
assert isinstance(combination_kernel.kernels, nnx.List)
121122
assert isinstance(combination_kernel.kernels[0], AbstractKernel)
122123

123124
# Compute gram matrix

0 commit comments

Comments
 (0)