Skip to content

Commit 587b1f7

Browse files
authored
Graph kernel consistency with other kernels (#560)
* toggle dark theme * graph variational families * variational family not working and example * graph svgp fix * more updates * dtype problem * kernel checks in variational family * removed graph gaussian * formatting issues * ran uv run poe format to fix * passing tests and formatting weirdness * subclassing and tests * segregated graph tests that still fail * graph tests, changes to parent variational class signatures and graph kernel to allow SVGP shapes and dtypes * docstring for new class * remove ensure_2d function * tolerant dimensions * graph kernel consistent with other kernels * meaningful name and circular import issue * version bump and typechecking magic
1 parent 1dce552 commit 587b1f7

File tree

4 files changed

+38
-21
lines changed

4 files changed

+38
-21
lines changed

gpjax/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
__description__ = "Gaussian processes in JAX and Flax"
4141
__url__ = "https://github.com/JaxGaussianProcesses/GPJax"
4242
__contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
43-
__version__ = "0.13.0"
43+
__version__ = "0.13.1"
4444

4545
__all__ = [
4646
"gps",

gpjax/kernels/computations/eigen.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616

1717
import beartype.typing as tp
18-
import jax.numpy as jnp
1918
from jaxtyping import (
2019
Float,
2120
Num,
@@ -39,17 +38,4 @@ class EigenKernelComputation(AbstractKernelComputation):
3938
def _cross_covariance(
4039
self, kernel: Kernel, x: Num[Array, "N D"], y: Num[Array, "M D"]
4140
) -> Float[Array, "N M"]:
42-
# Transform the eigenvalues of the graph Laplacian according to the
43-
# RBF kernel's SPDE form.
44-
S = jnp.power(
45-
kernel.eigenvalues
46-
+ 2
47-
* kernel.smoothness.value
48-
/ kernel.lengthscale.value
49-
/ kernel.lengthscale.value,
50-
-kernel.smoothness.value,
51-
)
52-
S = jnp.multiply(S, kernel.num_vertex / jnp.sum(S))
53-
# Scale the transform eigenvalues by the kernel variance
54-
S = jnp.multiply(S, kernel.variance.value)
55-
return kernel(x, y, S=S)
41+
return kernel(x, y)

gpjax/kernels/non_euclidean/graph.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@
2525
AbstractKernelComputation,
2626
EigenKernelComputation,
2727
)
28-
from gpjax.kernels.non_euclidean.utils import jax_gather_nd
28+
from gpjax.kernels.non_euclidean.utils import (
29+
calculate_heat_semigroup,
30+
jax_gather_nd,
31+
)
2932
from gpjax.kernels.stationary.base import StationaryKernel
3033
from gpjax.parameters import (
3134
Parameter,
@@ -98,14 +101,12 @@ def __init__(
98101

99102
super().__init__(active_dims, lengthscale, variance, n_dims, compute_engine)
100103

101-
def __call__( # TODO not consistent with general kernel interface
104+
def __call__(
102105
self,
103106
x: Int[Array, "N 1"],
104107
y: Int[Array, "M 1"],
105-
*,
106-
S,
107-
**kwargs,
108108
):
109+
S = calculate_heat_semigroup(self)
109110
Kxx = (jax_gather_nd(self.eigenvectors, x) * S.squeeze()) @ jnp.transpose(
110111
jax_gather_nd(self.eigenvectors, y)
111112
) # shape (n,n)

gpjax/kernels/non_euclidean/utils.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,20 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515

16+
from __future__ import annotations
17+
18+
import beartype.typing as tp
19+
import jax.numpy as jnp
1620
from jaxtyping import (
1721
Float,
1822
Int,
1923
)
2024

2125
from gpjax.typing import Array
2226

27+
if tp.TYPE_CHECKING:
28+
from gpjax.kernels.non_euclidean.graph import GraphKernel
29+
2330

2431
def jax_gather_nd(
2532
params: Float[Array, " N *rest"], indices: Int[Array, " M 1"]
@@ -41,3 +48,26 @@ def jax_gather_nd(
4148
"""
4249
tuple_indices = tuple(indices[..., i] for i in range(indices.shape[-1]))
4350
return params[tuple_indices]
51+
52+
53+
def calculate_heat_semigroup(kernel: GraphKernel) -> Float[Array, "N M"]:
54+
r"""Returns the rescaled heat semigroup, S
55+
56+
Args:
57+
kernel: instance of the graph kernel
58+
59+
Returns:
60+
S
61+
"""
62+
S = jnp.power(
63+
kernel.eigenvalues
64+
+ 2
65+
* kernel.smoothness.value
66+
/ kernel.lengthscale.value
67+
/ kernel.lengthscale.value,
68+
-kernel.smoothness.value,
69+
)
70+
S = jnp.multiply(S, kernel.num_vertex / jnp.sum(S))
71+
# Scale the transform eigenvalues by the kernel variance
72+
S = jnp.multiply(S, kernel.variance.value)
73+
return S

0 commit comments

Comments
 (0)