Skip to content

Commit dbaa5ce

Browse files
authored
Merge pull request #154 from JaxGaussianProcesses/incorporate_jax_linear_operator
Incorporate JaxLinOp with GPJax
2 parents 23e30a8 + b41f7e4 commit dbaa5ce

12 files changed

+604
-863
lines changed

examples/natgrads.pct.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import jax.random as jr
2626
import matplotlib.pyplot as plt
2727
import optax as ox
28-
from jax import jit, lax
2928
from jax.config import config
3029

3130
import gpjax as gpx
@@ -97,7 +96,7 @@
9796
n_iters=5000,
9897
batch_size=256,
9998
key=jr.PRNGKey(42),
100-
moment_optim=ox.sgd(0.1),
99+
moment_optim=ox.sgd(0.01),
101100
hyper_optim=ox.adam(1e-3),
102101
)
103102

0 commit comments

Comments
 (0)