Constraint-Aware Descent Routine Executor (CADRE): JAX-native constrained optimization.
CADRE provides a unified interface to multiple JAX-compatible optimization backends, with first-class support for box-constrained problems via an active-set method (ADABK family).
It is the minimizer used in the Furax-CS package for CMB (cosmic microwave background) component separation.
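For box constraints `lower <= x <= upper`, an active-set method holds fixed every coordinate that sits on a bound with a gradient pushing it outward, and updates only the remaining free coordinates. The sketch below illustrates that idea with plain `jax.numpy`. It is not CADRE's internal code, and `active_mask` is a hypothetical helper name.

```python
import jax.numpy as jnp


def active_mask(x, grad, lower, upper):
    """Hypothetical helper: True where a box constraint is active.

    A coordinate is active when it lies on a bound and the descent
    direction (-grad) points out of the box.
    """
    at_lower = x <= lower
    at_upper = x >= upper
    return (at_lower & (grad > 0)) | (at_upper & (grad < 0))


x = jnp.array([0.0, 2.0, 5.0])
grad = jnp.array([1.0, -1.0, -3.0])
lower = jnp.zeros(3)
upper = jnp.full(3, 5.0)

mask = active_mask(x, grad, lower, upper)  # [True, False, True]
# Zero the gradient on active coordinates, take a step, then project back
# onto the box; only the free coordinate moves.
projected_grad = jnp.where(mask, 0.0, grad)
x_new = jnp.clip(x - 0.1 * projected_grad, lower, upper)  # ~[0.0, 2.1, 5.0]
```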
```bash
pip install jax-cadre
```

With optional scipy solvers (`scipy_tnc`, `scipy_cobyqa`):

```bash
pip install "jax-cadre[scipy]"
```

Minimal example: fit a vector to `target` inside the box `[0, 5]`. The `target` keyword passed to `minimize` is handed through to `loss`.

```python
from cadre import minimize
import jax.numpy as jnp

def loss(params, target):
    return jnp.sum((params - target) ** 2)

target = jnp.array([1.0, 2.0, 3.0])
lower = jnp.zeros(3)
upper = jnp.ones(3) * 5.0

params, state = minimize(
    loss,
    init_params=jnp.zeros(3),
    solver_name="ADABK0",  # or "optax_lbfgs"
    lower_bound=lower,
    upper_bound=upper,
    target=target,
)
print(f"Optimal params: {params}")
```

| Solver | Description |
|---|---|
| `ADABK0` | Active-set + AdaBelief, releases 1 constraint per step. Best for noisy landscapes. |
| `ADABK{N}` | Active-set + AdaBelief, releases up to N×10 % of constraints per step. |
| `optax_lbfgs` | L-BFGS with zoom line search. Best for smooth landscapes. |
| `adam`, `adabelief`, `adaw`, `sgd` | First-order optax solvers with optional projection. |
| `optimistix_bfgs/lbfgs/ncg_*` | Optimistix solvers. |
| `scipy_tnc`, `scipy_cobyqa` | SciPy solvers via jaxopt (requires `jax-cadre[scipy]`). |
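The ADABK variants in the table differ only in how many active constraints they may release per step. Below is a minimal sketch of one plausible release rule. It assumes, beyond what the table states, that "N×10 %" is taken of the currently active constraints, that at least one release is always allowed, and that candidates are ranked by gradient magnitude. `release_count` and `release_constraints` are hypothetical names, not CADRE API.

```python
import jax.numpy as jnp


def release_count(n_active: int, n: int) -> int:
    """Assumed ADABK{N} budget: N*10 % of active constraints, at least 1.

    Integer arithmetic avoids float rounding in N * 0.1 * n_active.
    """
    if n_active == 0:
        return 0
    return max(1, (n * n_active) // 10)


def release_constraints(active, x, grad, lower, n):
    """Return an updated active mask with up to release_count(...) entries freed.

    Candidates are active coordinates whose descent direction (-grad)
    points back into the box; the largest |grad| are released first.
    Active coordinates not at the lower bound are assumed to be at the upper one.
    """
    at_lower = x <= lower
    # At a lower bound, grad < 0 moves x inward; at an upper bound, grad > 0 does.
    wants_in = jnp.where(at_lower, grad < 0, grad > 0)
    candidates = active & wants_in
    k = release_count(int(active.sum()), n)
    score = jnp.where(candidates, jnp.abs(grad), -jnp.inf)
    top = jnp.argsort(-score)[:k]  # indices of the k highest scores
    release = jnp.zeros_like(active).at[top].set(True) & candidates
    return active & ~release


x = jnp.array([0.0, 0.0, 5.0, 5.0, 2.0])
grad = jnp.array([-3.0, 1.0, 2.0, -1.0, 0.5])
lower = jnp.zeros(5)
active = jnp.array([True, True, True, True, False])

new0 = release_constraints(active, x, grad, lower, n=0)  # ADABK0-like: frees index 0
new5 = release_constraints(active, x, grad, lower, n=5)  # ADABK5-like: frees 0 and 2
```

With `n=0` only the strongest candidate is freed, matching the one-constraint-per-step behaviour of the `ADABK0` row; a larger `N` lets several bounds become inactive in a single step.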
Full solver documentation and ADABK internals: see the docs.
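To drive a solver step by step, get it from `get_solver`; its `init` follows the optimistix minimiser signature:

```python
import jax
import jax.numpy as jnp
import optimistix as optx
from cadre import get_solver

init_params = jnp.zeros(3)  # `loss` and `target` as in the example above
f_struct = jax.ShapeDtypeStruct((), jnp.float32)  # shape/dtype of the scalar loss
solver, _ = get_solver("ADABK0", rtol=1e-6, atol=1e-6)
state = solver.init(loss, init_params, target, {}, f_struct, None, frozenset())
# ... step manually
```

License: MIT (see LICENSE).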