Skip to content

Commit 5a10eff

Browse files
committed
Replace jaxopt with optax L-BFGS and remove scipy.stats from tests
- Replace deprecated jaxopt.ScipyMinimize with optax.lbfgs (JAX-native, JIT-compatible L-BFGS optimizer)
- Replace scipy.stats.spearmanr in tests with a simple rank-correlation helper using jax.numpy
- Update dependency: jaxopt -> optax
1 parent 390bad5 commit 5a10eff

5 files changed

Lines changed: 50 additions & 17 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
- new convenience methods on `FunctionEstimator`: `leverage(X)`, `empirical_variance(X, y)`, `get_obs_variance(X)`
88
- `obs_variance` weights are included in predictor serialization (`to_json`/`from_json`)
99
- `sigma` now accepts per-feature vectors of shape `(p,)` or `(1, p)` for multi-output GPs, giving each output column its own noise level
10+
- replace deprecated `jaxopt` dependency with `optax` for L-BFGS optimization
1011
- fix `requires-python` from `>=3.6` to `>=3.10`
1112

1213
# v1.6.1

mellon/compute_ls_time.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,6 @@
22
from jax.numpy import exp, unique, corrcoef, zeros, abs, stack
33
from jax.numpy import sum as arraysum
44
from jax.numpy.linalg import norm
5-
from jaxopt import ScipyMinimize
65
from .density_estimator import DensityEstimator
76
from .validation import validate_time_x
87

@@ -95,8 +94,10 @@ def ls_loss(log_ls):
9594
covs = cov_func_curry(ls)(delta_t, zeros((1, 1))).reshape((n_times, n_times))
9695
return norm(covs - corrs)
9796

98-
opt = ScipyMinimize(fun=ls_loss, method="L-BFGS-B", jit=False).run(0.0)
99-
ls = exp(opt.params).item()
97+
from .inference import minimize_lbfgsb
98+
99+
result = minimize_lbfgsb(ls_loss, 0.0, jit=False)
100+
ls = exp(result.pre_transformation).item()
100101

101102
if return_data:
102103
return ls, densities, predictors, unique_times

mellon/inference.py

Lines changed: 33 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -7,8 +7,8 @@
77
from jax.scipy.special import gammaln
88
import jax.scipy.stats.norm as norm
99
import jax
10+
import optax
1011
from jax.example_libraries.optimizers import adam
11-
from jaxopt import ScipyMinimize
1212
from .conditional import (
1313
FullConditional,
1414
ExpFullConditional,
@@ -269,22 +269,50 @@ def step(step, opt_state):
269269
return results
270270

271271

272-
def minimize_lbfgsb(loss_func, initial_value, jit=DEFAULT_JIT):
272+
def minimize_lbfgsb(loss_func, initial_value, jit=DEFAULT_JIT, maxiter=500, tol=1e-8):
273273
R"""
274-
Minimizes function with a starting guess of initial_value.
274+
Minimizes function using L-BFGS via optax.
275275
276276
:param loss_func: Loss function to minimize.
277277
:type loss_func: function
278278
:param initial_value: Initial guess.
279279
:type initial_value: array-like
280+
:param jit: Whether to JIT-compile the optimization step.
281+
:type jit: bool
282+
:param maxiter: Maximum number of iterations.
283+
:type maxiter: int
284+
:param tol: Gradient norm tolerance for convergence.
285+
:type tol: float
280286
:return: Results - A named tuple containing pre_transformation, opt_state,
281287
loss: The optimized parameters, final state of the optimizer, and the
282288
final loss value,
283289
:rtype: array-like, array-like, Object
284290
"""
285-
opt = ScipyMinimize(fun=loss_func, method="L-BFGS-B", jit=jit).run(initial_value)
291+
solver = optax.lbfgs()
292+
293+
def step(x, opt_state):
294+
value, grad = jax.value_and_grad(loss_func)(x)
295+
updates, new_state = solver.update(
296+
grad, opt_state, x,
297+
value=value, grad=grad, value_fn=loss_func,
298+
)
299+
new_x = optax.apply_updates(x, updates)
300+
return new_x, new_state, value, grad
301+
302+
if jit:
303+
step = jax.jit(step)
304+
305+
x = jax.numpy.asarray(initial_value)
306+
opt_state = solver.init(x)
307+
loss_val = loss_func(x)
308+
309+
for _ in range(maxiter):
310+
x, opt_state, loss_val, grad = step(x, opt_state)
311+
if jax.numpy.linalg.norm(grad) < tol:
312+
break
313+
286314
Results = namedtuple("Results", "pre_transformation opt_state loss")
287-
results = Results(opt.params, opt.state, opt.state.fun_val.item())
315+
results = Results(x, opt_state, float(loss_val))
288316
return results
289317

290318

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ dependencies = [
2222
# flexible foundations to avoid broken resolutions.
2323
# See: https://github.com/astral-sh/uv/issues/5161
2424
"jax",
25-
"jaxopt",
25+
"optax",
2626
"scikit-learn",
2727
"pynndescent",
2828
]
@@ -80,4 +80,4 @@ python_version = "3.9"
8080
warn_return_any = true
8181
warn_unused_configs = true
8282
disallow_untyped_defs = false
83-
disallow_incomplete_defs = false
83+
disallow_incomplete_defs = false

tests/test_leverage.py

Lines changed: 10 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,14 @@
44
import mellon
55

66

7+
def _spearman_correlation(a, b):
8+
"""Simple Spearman rank correlation without scipy."""
9+
a, b = jnp.asarray(a).ravel(), jnp.asarray(b).ravel()
10+
rank_a = jnp.argsort(jnp.argsort(a)).astype(float)
11+
rank_b = jnp.argsort(jnp.argsort(b)).astype(float)
12+
return jnp.corrcoef(rank_a, rank_b)[0, 1]
13+
14+
715
@pytest.fixture
816
def setup_data():
917
n = 50
@@ -49,10 +57,7 @@ def test_sparse_gp_leverage_correlates_with_full(setup_data):
4957
est_sparse.fit(X, y)
5058
h_sparse = est_sparse.predict.leverage(X, sigma=sigma)
5159

52-
# Spearman correlation via ranks
53-
from scipy.stats import spearmanr
54-
55-
corr, _ = spearmanr(h_full, h_sparse)
60+
corr = _spearman_correlation(h_full, h_sparse)
5661
assert corr > 0.8, f"Spearman correlation {corr} too low between full and sparse leverage."
5762

5863

@@ -190,9 +195,7 @@ def test_obs_variance_correlates_with_true_noise():
190195

191196
var = est.predict.obs_variance(X)
192197

193-
from scipy.stats import spearmanr
194-
195-
corr, _ = spearmanr(true_noise_std**2, var)
198+
corr = _spearman_correlation(true_noise_std**2, var)
196199
assert corr > 0.3, (
197200
f"obs_variance should correlate with true noise variance, got Spearman={corr}"
198201
)

0 commit comments

Comments
 (0)