We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent cdb0f03 commit fe8bde4Copy full SHA for fe8bde4
tests/test_mean_functions.py
@@ -1,11 +1,10 @@
1
# Enable Float64 for more stable matrix inversions.
2
-from jax.config import config
+from jax import config
3
4
config.update("jax_enable_x64", True)
5
6
7
import jax
8
-from jax import jit
9
import jax.numpy as jnp
10
import jax.random as jr
11
from jaxtyping import (
@@ -72,10 +71,7 @@ def test_zero_mean_remains_zero() -> None:
72
71
posterior = prior * likelihood
73
74
negative_mll = gpx.objectives.ConjugateMLL(negative=True)
75
- negative_mll(posterior, train_data=D)
76
- negative_mll = jit(negative_mll)
77
-
78
- opt_posterior, history = gpx.fit(
+ opt_posterior, _ = gpx.fit(
79
model=posterior,
80
objective=negative_mll,
81
train_data=D,
0 commit comments