|
7 | 7 | from jax.scipy.special import gammaln |
8 | 8 | import jax.scipy.stats.norm as norm |
9 | 9 | import jax |
| 10 | +import optax |
10 | 11 | from jax.example_libraries.optimizers import adam |
11 | | -from jaxopt import ScipyMinimize |
12 | 12 | from .conditional import ( |
13 | 13 | FullConditional, |
14 | 14 | ExpFullConditional, |
@@ -269,22 +269,50 @@ def step(step, opt_state): |
269 | 269 | return results |
270 | 270 |
|
271 | 271 |
|
def minimize_lbfgsb(loss_func, initial_value, jit=DEFAULT_JIT, maxiter=500, tol=1e-8):
    R"""
    Minimizes function with L-BFGS (via optax), starting from initial_value.

    The optimization is unconstrained: despite the historical ``lbfgsb``
    name, no box bounds are applied.

    :param loss_func: Loss function to minimize. Must be differentiable by
        JAX and return a scalar.
    :type loss_func: function
    :param initial_value: Initial guess.
    :type initial_value: array-like
    :param jit: Whether to JIT-compile the optimization step.
    :type jit: bool
    :param maxiter: Maximum number of L-BFGS iterations.
    :type maxiter: int
    :param tol: Convergence tolerance on the L2 norm of the gradient.
    :type tol: float
    :return: Results - A named tuple ``(pre_transformation, opt_state, loss)``:
        the optimized parameters, the optimizer state associated with those
        parameters, and the loss evaluated at those parameters.
    :rtype: array-like, Object, float
    """
    solver = optax.lbfgs()
    # The line search run inside ``solver.update`` stores the value and
    # gradient at the accepted point in the optimizer state; this wrapper
    # reuses them instead of recomputing, and falls back to
    # ``jax.value_and_grad`` when nothing valid is stored (first iteration).
    value_and_grad = optax.value_and_grad_from_state(loss_func)

    def step(x, opt_state):
        # ``value`` and ``grad`` describe the *current* iterate ``x``, not
        # the proposed ``new_x``.
        value, grad = value_and_grad(x, state=opt_state)
        updates, new_state = solver.update(
            grad, opt_state, x,
            value=value, grad=grad, value_fn=loss_func,
        )
        new_x = optax.apply_updates(x, updates)
        return new_x, new_state, value, jax.numpy.linalg.norm(grad)

    if jit:
        step = jax.jit(step)

    x = jax.numpy.asarray(initial_value)
    opt_state = solver.init(x)

    for _ in range(maxiter):
        new_x, new_state, value, grad_norm = step(x, opt_state)
        if grad_norm < tol:
            # The current iterate already satisfies the tolerance: keep it
            # (and its matching state) so that the reported loss and the
            # convergence test both refer to the returned parameters.
            loss_val = value
            break
        x, opt_state = new_x, new_state
    else:
        # Iteration budget exhausted (or maxiter == 0): report the loss at
        # the parameters that are actually returned.
        loss_val = loss_func(x)

    Results = namedtuple("Results", "pre_transformation opt_state loss")
    results = Results(x, opt_state, float(loss_val))
    return results
289 | 317 |
|
290 | 318 |
|
|
0 commit comments