
Commit e98050e

fix feature map (#326)
1 parent b7cc06c commit e98050e

1 file changed (+6, -5 lines)


gpjax/kernels/computations/basis_functions.py

Lines changed: 6 additions & 5 deletions
@@ -30,8 +30,7 @@ def cross_covariance(
         """
         z1 = self.compute_features(x)
         z2 = self.compute_features(y)
-        z1 /= self.kernel.num_basis_fns
-        return self.kernel.base_kernel.variance * jnp.matmul(z1, z2.T)
+        return self.scaling * jnp.matmul(z1, z2.T)
 
     def gram(self, inputs: Float[Array, "N D"]) -> DenseLinearOperator:
         r"""Compute an approximate Gram matrix.
@@ -47,9 +46,7 @@ def gram(self, inputs: Float[Array, "N D"]) -> DenseLinearOperator:
         $`N \times N`$ Gram matrix.
         """
         z1 = self.compute_features(inputs)
-        matrix = jnp.matmul(z1, z1.T)  # shape: (n_samples, n_samples)
-        matrix /= self.kernel.num_basis_fns
-        return DenseLinearOperator(self.kernel.base_kernel.variance * matrix)
+        return DenseLinearOperator(self.scaling * jnp.matmul(z1, z1.T))
 
     def compute_features(self, x: Float[Array, "N D"]) -> Float[Array, "N L"]:
         r"""Compute the features for the inputs.
@@ -66,3 +63,7 @@ def compute_features(self, x: Float[Array, "N D"]) -> Float[Array, "N L"]:
         z = jnp.matmul(x, (frequencies / scaling_factor).T)
         z = jnp.concatenate([jnp.cos(z), jnp.sin(z)], axis=-1)
         return z
+
+    @property
+    def scaling(self):
+        return self.kernel.base_kernel.variance / self.kernel.num_basis_fns
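For context, below is a minimal standalone sketch of the scaling convention this commit factors out into a `scaling` property: with `num_basis_fns` sampled frequencies and concatenated cos/sin features, the random-Fourier-feature estimate of an RBF kernel is `(variance / num_basis_fns) * phi(x) @ phi(y).T`. The toy class is an illustrative assumption, not the gpjax API; it only mirrors the names appearing in the diff (`num_basis_fns`, `frequencies`, `compute_features`, `scaling`).

# Minimal sketch of the scaling convention used above; not the gpjax API.
import jax.numpy as jnp
import jax.random as jr


class ToyRFF:
    """Toy random-Fourier-feature approximation of an RBF kernel (stand-in, for illustration)."""

    def __init__(self, key, num_basis_fns=64, variance=1.0, lengthscale=1.0, dim=1):
        self.num_basis_fns = num_basis_fns
        self.variance = variance
        self.lengthscale = lengthscale
        # The RBF spectral density is Gaussian, so frequencies are sampled i.i.d. normal.
        self.frequencies = jr.normal(key, (num_basis_fns, dim))

    @property
    def scaling(self):
        # Single place for the variance / num-basis-functions factor,
        # mirroring the `scaling` property added in this commit.
        return self.variance / self.num_basis_fns

    def compute_features(self, x):
        z = jnp.matmul(x, (self.frequencies / self.lengthscale).T)
        return jnp.concatenate([jnp.cos(z), jnp.sin(z)], axis=-1)

    def cross_covariance(self, x, y):
        z1 = self.compute_features(x)
        z2 = self.compute_features(y)
        return self.scaling * jnp.matmul(z1, z2.T)


key = jr.PRNGKey(0)
x = jnp.linspace(-3.0, 3.0, 50)[:, None]
rff = ToyRFF(key, num_basis_fns=2048)
approx = rff.cross_covariance(x, x)
exact = jnp.exp(-0.5 * (x - x.T) ** 2)  # exact RBF kernel, unit variance and lengthscale
print(jnp.max(jnp.abs(approx - exact)))  # approximation error shrinks as num_basis_fns grows

Defining `scaling` once keeps the cross-covariance and Gram computations in agreement, rather than repeating the variance and basis-function division in both methods.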
