Commit 77aa81c

Merge categorical kernel in
Categorical kernels provide a neatly parametrized kernel for variables that take on categorical values.
2 parents dd695f5 + 0dfb739 commit 77aa81c
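
For orientation, the parametrization referred to above combines a per-category standard deviation with a correlation Cholesky factor. Below is a minimal sketch of the construction used by the new kernel's explicit_gram property; the example values are illustrative, not taken from this commit:

import jax.numpy as jnp

# Hypothetical parameters for a three-category input space (illustration only).
stddev = jnp.array([1.0, 2.0, 0.5])  # one standard deviation per category
corr_chol = jnp.eye(3)               # lower-triangular correlation Cholesky factor

# Scale the Cholesky rows by the standard deviations, then form the PSD gram matrix,
# mirroring CatKernel.explicit_gram: entry [i, j] is the covariance k(category_i, category_j).
L = stddev.reshape(-1, 1) * corr_chol
gram = L @ L.T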

File tree: 4 files changed (+212, -3 lines)

gpjax/kernels/__init__.py (2 additions, 1 deletion)

@@ -27,7 +27,7 @@
     DiagonalKernelComputation,
     EigenKernelComputation,
 )
-from gpjax.kernels.non_euclidean import GraphKernel
+from gpjax.kernels.non_euclidean import GraphKernel, CatKernel
 from gpjax.kernels.nonstationary import (
     ArcCosine,
     Linear,
@@ -49,6 +49,7 @@
     "ArcCosine",
     "RBF",
     "GraphKernel",
+    "CatKernel",
     "Matern12",
     "Matern32",
     "Matern52",

gpjax/kernels/non_euclidean/__init__.py (2 additions, 1 deletion)

@@ -14,5 +14,6 @@
 # ==============================================================================

 from gpjax.kernels.non_euclidean.graph import GraphKernel
+from gpjax.kernels.non_euclidean.categorical import CatKernel

-__all__ = ["GraphKernel"]
+__all__ = ["GraphKernel", "CatKernel"]

gpjax/kernels/non_euclidean/categorical.py (new file, 133 additions)

# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


from dataclasses import dataclass
from typing import NamedTuple, Union
import jax.numpy as jnp
from jaxtyping import Float, Int
import tensorflow_probability.substrates.jax as tfp

from gpjax.base import (
    param_field,
    static_field,
)
from gpjax.kernels.base import AbstractKernel

from gpjax.typing import (
    Array,
    ScalarInt,
)

tfb = tfp.bijectors

CatKernelParams = NamedTuple(
    "CatKernelParams",
    [("stddev", Float[Array, "N 1"]), ("cholesky_lower", Float[Array, " N*(N-1)//2"])],
)


@dataclass
class CatKernel(AbstractKernel):
    r"""The categorical kernel is defined for a fixed number of values of a categorical input.

    It stores a standard deviation for each input value (i.e. the diagonal of the gram matrix) and a lower Cholesky factor for the correlations.
    When called, it returns the corresponding value from the gram matrix.

    Args:
        stddev (Float[Array, "N"]): The standard deviation parameters, one for each input space value.
        cholesky_lower (Float[Array, "N N"]): The parameters for the Cholesky factor of the gram matrix.
        inspace_vals (list): The values in the input space this CatKernel works for. Stored for order reference, making clear which index is used for each input space value.
        name (str): The name of the kernel.
        input_1hot (bool): If True, the kernel expects to be called with a one-hot encoding of the input space values. If False, it expects the indices of the input space values.

    Raises:
        ValueError: If the number of stddev parameters does not match the number of input space values.
    """

    stddev: Float[Array, " N"] = param_field(jnp.ones((2,)), bijector=tfb.Softplus())
    cholesky_lower: Float[Array, "N N"] = param_field(
        jnp.eye(2), bijector=tfb.CorrelationCholesky()
    )
    inspace_vals: list = static_field(None)
    name: str = "Categorical Kernel"
    input_1hot: bool = static_field(False)

    def __post_init__(self):
        if self.inspace_vals is not None and len(self.inspace_vals) != len(self.stddev):
            raise ValueError(
                f"The number of stddev parameters ({len(self.stddev)}) has to match the number of input space values ({len(self.inspace_vals)}), unless inspace_vals is None."
            )

    @property
    def explicit_gram(self) -> Float[Array, "N N"]:
        """Access the PSD gram matrix resulting from the parameters.

        Returns:
            Float[Array, "N N"]: The gram matrix.
        """
        L = self.stddev.reshape(-1, 1) * self.cholesky_lower
        return L @ L.T

    def __call__(  # TODO not consistent with general kernel interface
        self,
        x: Union[ScalarInt, Int[Array, " N"]],
        y: Union[ScalarInt, Int[Array, " N"]],
    ):
        r"""Compute the (co)variance between a pair of dictionary indices.

        Args:
            x (Union[ScalarInt, Int[Array, "N"]]): The index of the first dictionary entry, or its one-hot encoding.
            y (Union[ScalarInt, Int[Array, "N"]]): The index of the second dictionary entry, or its one-hot encoding.

        Returns:
            ScalarFloat: The value of $k(v_i, v_j)$.
        """
        try:
            x = x.squeeze()
            y = y.squeeze()
        except AttributeError:
            pass
        if self.input_1hot:
            return self.explicit_gram[jnp.outer(x, y) == 1]
        else:
            return self.explicit_gram[x, y]

    @staticmethod
    def num_cholesky_lower_params(num_inspace_vals: ScalarInt) -> ScalarInt:
        """Compute the number of parameters required to store the lower triangular Cholesky factor of the gram matrix.

        Args:
            num_inspace_vals (ScalarInt): The number of values in the input space.

        Returns:
            ScalarInt: The number of parameters required to store the lower triangle of the Cholesky factor of the gram matrix.
        """
        return num_inspace_vals * (num_inspace_vals - 1) // 2

    @staticmethod
    def gram_to_stddev_cholesky_lower(gram: Float[Array, "N N"]) -> CatKernelParams:
        """Compute the standard deviation and lower triangular Cholesky factor of the gram matrix.

        Args:
            gram (Float[Array, "N N"]): The gram matrix.

        Returns:
            tuple[Float[Array, "N"], Float[Array, "N N"]]: The standard deviation and lower triangular Cholesky factor of the gram matrix, where the latter is scaled to result in unit variances.
        """
        stddev = jnp.sqrt(jnp.diag(gram))
        L = jnp.linalg.cholesky(gram) / stddev.reshape(-1, 1)
        return CatKernelParams(stddev, L)
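
The following usage sketch is based only on the API introduced above; the category values and kernel parameters are illustrative:

import jax.numpy as jnp
from gpjax.kernels import CatKernel

# A kernel over three categories with unit variances and no cross-category correlation.
kernel = CatKernel(
    stddev=jnp.ones((3,)),
    cholesky_lower=jnp.eye(3),
    inspace_vals=[0.0, 1.0, 2.0],
)

kernel(0, 0)          # variance of category 0 -> 1.0
kernel(0, 1)          # covariance between categories 0 and 1 -> 0.0
kernel.explicit_gram  # the full 3x3 PSD gram matrix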

tests/test_kernels/test_non_euclidean.py (75 additions, 1 deletion)

@@ -14,8 +14,9 @@
 import jax.numpy as jnp
 import networkx as nx

-from gpjax.kernels.non_euclidean import GraphKernel
+from gpjax.kernels.non_euclidean import GraphKernel, CatKernel
 from gpjax.linops import identity
+import jax.random as jr

 # # Enable Float64 for more stable matrix inversions.
 config.update("jax_enable_x64", True)
@@ -46,3 +47,76 @@ def test_graph_kernel():
     Kxx += identity(n_verticies) * 1e-6
     eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense())
     assert all(eigen_values > 0)
+
+
+def test_cat_kernel():
+    x = jr.normal(jr.PRNGKey(123), (5000, 3))
+    gram = jnp.cov(x.T)
+    params = CatKernel.gram_to_stddev_cholesky_lower(gram)
+    dk = CatKernel(
+        inspace_vals=list(range(len(gram))),
+        stddev=params.stddev,
+        cholesky_lower=params.cholesky_lower,
+    )
+    assert jnp.allclose(dk.explicit_gram, gram)
+
+    sdev = jnp.ones((2,))
+    cholesky_lower = jnp.eye(2)
+    inspace_vals = [0.0, 1.0]
+
+    # Initialize CatKernel object
+    dict_kernel = CatKernel(
+        stddev=sdev, cholesky_lower=cholesky_lower, inspace_vals=inspace_vals
+    )
+
+    assert dict_kernel.stddev.shape == sdev.shape
+    assert jnp.allclose(dict_kernel.stddev, sdev)
+    assert jnp.allclose(dict_kernel.cholesky_lower, cholesky_lower)
+    assert dict_kernel.inspace_vals == inspace_vals
+
+
+def test_cat_kernel_gram_to_stddev_cholesky_lower():
+    gram = jnp.array([[1.0, 0.5], [0.5, 1.0]])
+    sdev_expected = jnp.array([1.0, 1.0])
+    cholesky_lower_expected = jnp.array([[1.0, 0.0], [0.5, 0.8660254]])
+
+    # Compute sdev and cholesky_lower from gram
+    sdev, cholesky_lower = CatKernel.gram_to_stddev_cholesky_lower(gram)
+
+    assert jnp.allclose(sdev, sdev_expected)
+    assert jnp.allclose(cholesky_lower, cholesky_lower_expected)
+
+
+def test_cat_kernel_call():
+    sdev = jnp.ones((2,))
+    cholesky_lower = jnp.eye(2)
+    inspace_vals = [0.0, 1.0]
+
+    # Initialize CatKernel object
+    dict_kernel = CatKernel(
+        stddev=sdev, cholesky_lower=cholesky_lower, inspace_vals=inspace_vals
+    )
+
+    # Compute kernel value for pair of inputs
+    kernel_value = dict_kernel.__call__(0, 1)
+
+    assert jnp.allclose(kernel_value, 0.0)  # since cholesky_lower is identity matrix
+
+
+def test_cat_kernel_explicit_gram():
+    sdev = jnp.ones((2,))
+    cholesky_lower = jnp.eye(2)
+    inspace_vals = [0.0, 1.0]
+
+    # Initialize CatKernel object
+    dict_kernel = CatKernel(
+        stddev=sdev, cholesky_lower=cholesky_lower, inspace_vals=inspace_vals
+    )
+
+    # Compute explicit gram matrix
+    explicit_gram = dict_kernel.explicit_gram
+
+    assert explicit_gram.shape == (2, 2)
+    assert jnp.allclose(
+        explicit_gram, jnp.eye(2)
+    )  # since sdev are ones and cholesky_lower is identity matrix
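
One code path not exercised by the tests above is input_1hot=True, where the kernel is called with one-hot encodings instead of indices. A minimal sketch of that mode, assuming the same two-category setup as the tests:

import jax.numpy as jnp
from gpjax.kernels import CatKernel

one_hot_kernel = CatKernel(
    stddev=jnp.ones((2,)),
    cholesky_lower=jnp.eye(2),
    inspace_vals=[0.0, 1.0],
    input_1hot=True,
)

# One-hot encodings of categories 0 and 1; __call__ selects gram[0, 1] via an outer-product mask.
x = jnp.array([1, 0])
y = jnp.array([0, 1])
value = one_hot_kernel(x, y)  # identity cholesky_lower -> zero covariance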
