Skip to content

Commit cdb0f03

Browse files
Add additional tests for zero mean
1 parent ffb57d8 commit cdb0f03

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

tests/test_mean_functions.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
config.update("jax_enable_x64", True)
55

66

7+
import jax
78
from jax import jit
89
import jax.numpy as jnp
910
import jax.random as jr
@@ -85,3 +86,14 @@ def test_zero_mean_remains_zero() -> None:
8586
)
8687

8788
assert opt_posterior.prior.mean_function.constant == 0.0
89+
90+
91+
def test_zero_mean_pytree_no_leaves():
92+
zero_mean = Zero()
93+
leaves = jax.tree_util.tree_leaves(zero_mean)
94+
assert len(leaves) == 0
95+
96+
97+
def test_initialising_zero_mean_with_constant_raises_error():
98+
with pytest.raises(TypeError):
99+
Zero(constant=jnp.array([1.0]))

0 commit comments

Comments
 (0)