Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
matrix:
# Select the Python versions to test against
os: ["ubuntu-latest", "macos-latest"]
python-version: ["3.10", "3.11", "3.12", "3.13"]
python-version: ["3.11", "3.12", "3.13"]
fail-fast: true
steps:
- name: Check out the code
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
strategy:
matrix:
os: ["ubuntu-latest", "macos-latest"]
python-version: ["3.10", "3.11", "3.12", "3.13"]
python-version: ["3.11", "3.12", "3.13"]
fail-fast: true
runs-on: ${{ matrix.os }}
steps:
Expand Down
2 changes: 1 addition & 1 deletion gpjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
__description__ = "Gaussian processes in JAX and Flax"
__url__ = "https://github.com/JaxGaussianProcesses/GPJax"
__contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
__version__ = "0.12.2"
__version__ = "0.13.0"

__all__ = [
"gps",
Expand Down
2 changes: 1 addition & 1 deletion gpjax/gps.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
from abc import abstractmethod

import beartype.typing as tp
from flax import nnx
import jax.numpy as jnp
import jax.random as jr
from flax import nnx
from jaxtyping import (
Float,
Num,
Expand Down
3 changes: 2 additions & 1 deletion gpjax/kernels/approximations/rff.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Compute Random Fourier Feature (RFF) kernel approximations."""

import beartype.typing as tp
from flax import nnx
import jax.random as jr
from jaxtyping import Float

Expand Down Expand Up @@ -54,7 +55,7 @@ def __init__(
self._check_valid_base_kernel(base_kernel)
self.base_kernel = base_kernel
self.num_basis_fns = num_basis_fns
self.frequencies = frequencies
self.frequencies = nnx.data(frequencies)
self.compute_engine = compute_engine

if self.frequencies is None:
Expand Down
2 changes: 1 addition & 1 deletion gpjax/kernels/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def __init__(
compute_engine: AbstractKernelComputation = DenseKernelComputation(),
):
# Add kernels to a list, flattening out instances of this class therein, as in GPFlow kernels.
kernels_list: list[AbstractKernel] = []
kernels_list: list[AbstractKernel] = nnx.List([])
for kernel in kernels:
if not isinstance(kernel, AbstractKernel):
raise TypeError("can only combine Kernel instances") # pragma: no cover
Expand Down
31 changes: 1 addition & 30 deletions gpjax/kernels/stationary/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def _check_lengthscale_dims_compat(
"""

if isinstance(lengthscale, nnx.Variable):
return _check_lengthscale_dims_compat_old(lengthscale.value, n_dims)
return _check_lengthscale_dims_compat(lengthscale.value, n_dims)

lengthscale = jnp.asarray(lengthscale)
ls_shape = jnp.shape(lengthscale)
Expand All @@ -146,35 +146,6 @@ def _check_lengthscale_dims_compat(
return n_dims


def _check_lengthscale_dims_compat_old(
lengthscale: tp.Union[LengthscaleCompatible, nnx.Variable[Lengthscale]],
n_dims: tp.Union[int, None],
):
r"""Check that the lengthscale is compatible with n_dims.

If possible, infer the number of input dimensions from the lengthscale.
"""

if isinstance(lengthscale, nnx.Variable):
return _check_lengthscale_dims_compat_old(lengthscale.value, n_dims)

lengthscale = jnp.asarray(lengthscale)
ls_shape = jnp.shape(lengthscale)

if ls_shape == ():
return lengthscale, n_dims
elif ls_shape != () and n_dims is None:
return lengthscale, ls_shape[0]
elif ls_shape != () and n_dims is not None:
if ls_shape != (n_dims,):
raise ValueError(
"Expected `lengthscale` to be compatible with the number "
f"of input dimensions. Got `lengthscale` with shape {ls_shape}, "
f"but the number of input dimensions is {n_dims}."
)
return lengthscale, n_dims


def _check_lengthscale(lengthscale: tp.Any):
"""Check that the lengthscale is a valid value."""

Expand Down
2 changes: 1 addition & 1 deletion gpjax/mean_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def __init__(
super().__init__(**kwargs)

# Add means to a list, flattening out instances of this class therein, as in GPFlow kernels.
items_list: list[AbstractMeanFunction] = []
items_list: list[AbstractMeanFunction] = nnx.List([])

for item in means:
if not isinstance(item, AbstractMeanFunction):
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "gpjax"
dynamic = ["version"]
description = 'Gaussian processes in JAX.'
readme = "README.md"
requires-python = ">=3.10"
requires-python = ">=3.11"
license = { text = "MIT" }
keywords = ["gaussian-processes jax machine-learning bayesian"]
authors = [{ name = "Thomas Pinder", email = "[email protected]" }]
Expand All @@ -29,7 +29,7 @@ dependencies = [
"jaxtyping>0.2.10",
"tqdm>4.66.2",
"beartype>0.16.1",
"flax>=0.10.0",
"flax>=0.12.0",
"numpy>=2.0.0",
]

Expand Down
5 changes: 3 additions & 2 deletions tests/test_kernels/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================

from flax import nnx
from jax import config
import jax.numpy as jnp
from jaxtyping import (
Expand Down Expand Up @@ -113,11 +114,11 @@ def test_combination_kernel(
combination_kernel = combination_type(kernels=kernels)

# Check params are a list of dictionaries
assert combination_kernel.kernels == kernels
assert combination_kernel.kernels == nnx.List(kernels)

# Check combination kernel set
assert len(combination_kernel.kernels) == n_kerns
assert isinstance(combination_kernel.kernels, list)
assert isinstance(combination_kernel.kernels, nnx.List)
assert isinstance(combination_kernel.kernels[0], AbstractKernel)

# Compute gram matrix
Expand Down
Loading
Loading