@@ -127,7 +127,7 @@ def _check_lengthscale_dims_compat(
127127 """
128128
129129 if isinstance (lengthscale , nnx .Variable ):
130- return _check_lengthscale_dims_compat_old (lengthscale .value , n_dims )
130+ return _check_lengthscale_dims_compat (lengthscale .value , n_dims )
131131
132132 lengthscale = jnp .asarray (lengthscale )
133133 ls_shape = jnp .shape (lengthscale )
@@ -146,35 +146,6 @@ def _check_lengthscale_dims_compat(
146146 return n_dims
147147
148148
149- def _check_lengthscale_dims_compat_old (
150- lengthscale : tp .Union [LengthscaleCompatible , nnx .Variable [Lengthscale ]],
151- n_dims : tp .Union [int , None ],
152- ):
153- r"""Check that the lengthscale is compatible with n_dims.
154-
155- If possible, infer the number of input dimensions from the lengthscale.
156- """
157-
158- if isinstance (lengthscale , nnx .Variable ):
159- return _check_lengthscale_dims_compat_old (lengthscale .value , n_dims )
160-
161- lengthscale = jnp .asarray (lengthscale )
162- ls_shape = jnp .shape (lengthscale )
163-
164- if ls_shape == ():
165- return lengthscale , n_dims
166- elif ls_shape != () and n_dims is None :
167- return lengthscale , ls_shape [0 ]
168- elif ls_shape != () and n_dims is not None :
169- if ls_shape != (n_dims ,):
170- raise ValueError (
171- "Expected `lengthscale` to be compatible with the number "
172- f"of input dimensions. Got `lengthscale` with shape { ls_shape } , "
173- f"but the number of input dimensions is { n_dims } ."
174- )
175- return lengthscale , n_dims
176-
177-
178149def _check_lengthscale (lengthscale : tp .Any ):
179150 """Check that the lengthscale is a valid value."""
180151
0 commit comments