Skip to content

Commit fe8bde4

Browse files
Fix failing test
1 parent cdb0f03 commit fe8bde4

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

tests/test_mean_functions.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
# Enable Float64 for more stable matrix inversions.
2-
from jax.config import config
2+
from jax import config
33

44
config.update("jax_enable_x64", True)
55

66

77
import jax
8-
from jax import jit
98
import jax.numpy as jnp
109
import jax.random as jr
1110
from jaxtyping import (
@@ -72,10 +71,7 @@ def test_zero_mean_remains_zero() -> None:
7271
posterior = prior * likelihood
7372

7473
negative_mll = gpx.objectives.ConjugateMLL(negative=True)
75-
negative_mll(posterior, train_data=D)
76-
negative_mll = jit(negative_mll)
77-
78-
opt_posterior, history = gpx.fit(
74+
opt_posterior, _ = gpx.fit(
7975
model=posterior,
8076
objective=negative_mll,
8177
train_data=D,

0 commit comments

Comments
 (0)