|
| 1 | +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# ============================================================================== |
| 15 | + |
| 16 | + |
| 17 | +from dataclasses import dataclass |
| 18 | +from typing import NamedTuple, Union |
| 19 | +import jax.numpy as jnp |
| 20 | +from jaxtyping import Float, Int |
| 21 | +import tensorflow_probability.substrates.jax as tfp |
| 22 | + |
| 23 | +from gpjax.base import ( |
| 24 | + param_field, |
| 25 | + static_field, |
| 26 | +) |
| 27 | +from gpjax.kernels.base import AbstractKernel |
| 28 | + |
| 29 | +from gpjax.typing import ( |
| 30 | + Array, |
| 31 | + ScalarInt, |
| 32 | +) |
| 33 | + |
| 34 | +tfb = tfp.bijectors |
| 35 | + |
| 36 | +CatKernelParams = NamedTuple( |
| 37 | + "CatKernelParams", |
| 38 | + [("stddev", Float[Array, "N 1"]), ("cholesky_lower", Float[Array, " N*(N-1)//2"])], |
| 39 | +) |
| 40 | + |
| 41 | + |
| 42 | +@dataclass |
| 43 | +class CatKernel(AbstractKernel): |
| 44 | + r"""The categorical kernel is defined for a fixed number of values of categorical input. |
| 45 | +
|
| 46 | + It stores a standard dev for each input value (i.e. the diagonal of the gram), and a lower cholesky factor for correlations. |
| 47 | + It returns the corresponding values from an the gram matrix when called. |
| 48 | +
|
| 49 | + Args: |
| 50 | + stddev (Float[Array, "N"]): The standard deviation parameters, one for each input space value. |
| 51 | + cholesky_lower (Float[Array, "N*(N-1)//2 N"]): The parameters for the Cholesky factor of the gram matrix. |
| 52 | + inspace_vals (list): The values in the input space this CatKernel works for. Stored for order reference, making clear the indices used for each input space value. |
| 53 | + name (str): The name of the kernel. |
| 54 | + input_1hot (bool): If True, the kernel expect to be called with a 1-hot encoding of the input space values. If False, it expects the indices of the input space values. |
| 55 | +
|
| 56 | + Raises: |
| 57 | + ValueError: If the number of diagonal variance parameters does not match the number of input space values. |
| 58 | + """ |
| 59 | + |
| 60 | + stddev: Float[Array, " N"] = param_field(jnp.ones((2,)), bijector=tfb.Softplus()) |
| 61 | + cholesky_lower: Float[Array, "N N"] = param_field( |
| 62 | + jnp.eye(2), bijector=tfb.CorrelationCholesky() |
| 63 | + ) |
| 64 | + inspace_vals: list = static_field(None) |
| 65 | + name: str = "Categorical Kernel" |
| 66 | + input_1hot: bool = static_field(False) |
| 67 | + |
| 68 | + def __post_init__(self): |
| 69 | + if self.inspace_vals is not None and len(self.inspace_vals) != len(self.stddev): |
| 70 | + raise ValueError( |
| 71 | + f"The number of stddev parameters ({len(self.stddev)}) has to match the number of input space values ({len(self.inspace_vals)}), unless inspace_vals is None." |
| 72 | + ) |
| 73 | + |
| 74 | + @property |
| 75 | + def explicit_gram(self) -> Float[Array, "N N"]: |
| 76 | + """Access the PSD gram matrix resulting from the parameters. |
| 77 | +
|
| 78 | + Returns: |
| 79 | + Float[Array, "N N"]: The gram matrix. |
| 80 | + """ |
| 81 | + L = self.stddev.reshape(-1, 1) * self.cholesky_lower |
| 82 | + return L @ L.T |
| 83 | + |
| 84 | + def __call__( # TODO not consistent with general kernel interface |
| 85 | + self, |
| 86 | + x: Union[ScalarInt, Int[Array, " N"]], |
| 87 | + y: Union[ScalarInt, Int[Array, " N"]], |
| 88 | + ): |
| 89 | + r"""Compute the (co)variance between a pair of dictionary indices. |
| 90 | +
|
| 91 | + Args: |
| 92 | + x (Union[ScalarInt, Int[Array, "N"]]): The index of the first dictionary entry, or its one-hot encoding. |
| 93 | + y (Union[ScalarInt, Int[Array, "N"]]): The index of the second dictionary entry, or its one-hot encoding. |
| 94 | +
|
| 95 | + Returns |
| 96 | + ------- |
| 97 | + ScalarFloat: The value of $k(v_i, v_j)$. |
| 98 | + """ |
| 99 | + try: |
| 100 | + x = x.squeeze() |
| 101 | + y = y.squeeze() |
| 102 | + except AttributeError: |
| 103 | + pass |
| 104 | + if self.input_1hot: |
| 105 | + return self.explicit_gram[jnp.outer(x, y) == 1] |
| 106 | + else: |
| 107 | + return self.explicit_gram[x, y] |
| 108 | + |
| 109 | + @staticmethod |
| 110 | + def num_cholesky_lower_params(num_inspace_vals: ScalarInt) -> ScalarInt: |
| 111 | + """Compute the number of parameters required to store the lower triangular Cholesky factor of the gram matrix. |
| 112 | +
|
| 113 | + Args: |
| 114 | + num_inspace_vals (ScalarInt): The number of values in the input space. |
| 115 | +
|
| 116 | + Returns: |
| 117 | + ScalarInt: The number of parameters required to store the lower triangle of the Cholesky factor of the gram matrix. |
| 118 | + """ |
| 119 | + return num_inspace_vals * (num_inspace_vals - 1) // 2 |
| 120 | + |
| 121 | + @staticmethod |
| 122 | + def gram_to_stddev_cholesky_lower(gram: Float[Array, "N N"]) -> CatKernelParams: |
| 123 | + """Compute the standard deviation and lower triangular Cholesky factor of the gram matrix. |
| 124 | +
|
| 125 | + Args: |
| 126 | + gram (Float[Array, "N N"]): The gram matrix. |
| 127 | +
|
| 128 | + Returns: |
| 129 | + tuple[Float[Array, "N"], Float[Array, "N N"]]: The standard deviation and lower triangular Cholesky factor of the gram matrix, where the latter is scaled to result in unit variances. |
| 130 | + """ |
| 131 | + stddev = jnp.sqrt(jnp.diag(gram)) |
| 132 | + L = jnp.linalg.cholesky(gram) / stddev.reshape(-1, 1) |
| 133 | + return CatKernelParams(stddev, L) |
0 commit comments