We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent ffb57d8 commit cdb0f03Copy full SHA for cdb0f03
tests/test_mean_functions.py
@@ -4,6 +4,7 @@
4
config.update("jax_enable_x64", True)
5
6
7
+import jax
8
from jax import jit
9
import jax.numpy as jnp
10
import jax.random as jr
@@ -85,3 +86,14 @@ def test_zero_mean_remains_zero() -> None:
85
86
)
87
88
assert opt_posterior.prior.mean_function.constant == 0.0
89
+
90
91
+def test_zero_mean_pytree_no_leaves():
92
+ zero_mean = Zero()
93
+ leaves = jax.tree_util.tree_leaves(zero_mean)
94
+ assert len(leaves) == 0
95
96
97
+def test_initialising_zero_mean_with_constant_raises_error():
98
+ with pytest.raises(TypeError):
99
+ Zero(constant=jnp.array([1.0]))
0 commit comments