@@ -30,8 +30,7 @@ def cross_covariance(
3030 """
3131 z1 = self .compute_features (x )
3232 z2 = self .compute_features (y )
33- z1 /= self .kernel .num_basis_fns
34- return self .kernel .base_kernel .variance * jnp .matmul (z1 , z2 .T )
33+ return self .scaling * jnp .matmul (z1 , z2 .T )
3534
3635 def gram (self , inputs : Float [Array , "N D" ]) -> DenseLinearOperator :
3736 r"""Compute an approximate Gram matrix.
@@ -47,9 +46,7 @@ def gram(self, inputs: Float[Array, "N D"]) -> DenseLinearOperator:
4746 $`N \times N`$ Gram matrix.
4847 """
4948 z1 = self .compute_features (inputs )
50- matrix = jnp .matmul (z1 , z1 .T ) # shape: (n_samples, n_samples)
51- matrix /= self .kernel .num_basis_fns
52- return DenseLinearOperator (self .kernel .base_kernel .variance * matrix )
49+ return DenseLinearOperator (self .scaling * jnp .matmul (z1 , z1 .T ))
5350
5451 def compute_features (self , x : Float [Array , "N D" ]) -> Float [Array , "N L" ]:
5552 r"""Compute the features for the inputs.
@@ -66,3 +63,7 @@ def compute_features(self, x: Float[Array, "N D"]) -> Float[Array, "N L"]:
6663 z = jnp .matmul (x , (frequencies / scaling_factor ).T )
6764 z = jnp .concatenate ([jnp .cos (z ), jnp .sin (z )], axis = - 1 )
6865 return z
66+
67+ @property
68+ def scaling (self ):
69+ return self .kernel .base_kernel .variance / self .kernel .num_basis_fns
0 commit comments