
CMBSciPol/CADRE


CADRE

Constraint-Aware Descent Routine Executor: JAX-native constrained optimization.


CADRE provides a unified interface to multiple JAX-compatible optimization backends, with first-class support for box-constrained problems via an active-set method (ADABK family).
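The active-set idea for box constraints can be illustrated with a minimal pure-JAX sketch. It is not CADRE's implementation; the helper names `active_set` and `masked_projected_step` and the tolerance `tol` are hypothetical. Here a bound is treated as active when the variable sits on it and the negative gradient would push it out of the box. Active variables are frozen, free variables take a gradient step, and the result is projected back onto the box with `jnp.clip`.

```python
import jax.numpy as jnp

def active_set(x, grad, lower, upper, tol=1e-8):
    """Boolean mask of bounds that block a plain gradient-descent step.

    A variable is active when it sits on a bound and -grad points out of the box.
    """
    at_lower = x <= lower + tol
    at_upper = x >= upper - tol
    return (at_lower & (grad > 0)) | (at_upper & (grad < 0))

def masked_projected_step(x, grad, lower, upper, lr):
    """Gradient step on the free variables only, then projection onto the box."""
    free = ~active_set(x, grad, lower, upper)
    step = jnp.where(free, -lr * grad, 0.0)
    return jnp.clip(x + step, lower, upper)

x = jnp.array([0.0, 2.0, 5.0])
g = jnp.array([1.0, 1.0, -1.0])
lower, upper = jnp.zeros(3), jnp.full(3, 5.0)
print(active_set(x, g, lower, upper))                  # [ True False  True]
print(masked_projected_step(x, g, lower, upper, 0.5))  # [0.  1.5 5. ]
```

Only the middle coordinate moves: the first sits on its lower bound with a positive gradient and the third sits on its upper bound with a negative gradient, so both stay clamped.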

This is the minimizer used in the Furax-CS package for CMB component separation.

Installation

```bash
pip install jax-cadre
```

To also install the optional SciPy solvers (`scipy_tnc`, `scipy_cobyqa`):

```bash
pip install "jax-cadre[scipy]"
```

Quick start

```python
from cadre import minimize
import jax.numpy as jnp

def loss(params, target):
    return jnp.sum((params - target) ** 2)

target = jnp.array([1.0, 2.0, 3.0])
lower  = jnp.zeros(3)
upper  = jnp.ones(3) * 5.0

params, state = minimize(
    loss,
    init_params=jnp.zeros(3),
    solver_name="ADABK0",   # or "optax_lbfgs"
    lower_bound=lower,
    upper_bound=upper,
    target=target,
)

print(f"Optimal params: {params}")
```
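The quick-start loss is a separable quadratic, so its box-constrained minimizer has a closed form: each coordinate of `target` clipped to `[lower, upper]`. Because `target = [1, 2, 3]` lies inside the box, the expected result is `target` itself. The sketch below is pure JAX and independent of CADRE; `reference_solution` and `projected_gradient_descent` are hypothetical helpers. It checks that closed form against plain projected gradient descent on a target that lies partly outside the box.

```python
import jax
import jax.numpy as jnp

def loss(params, target):
    return jnp.sum((params - target) ** 2)

def reference_solution(target, lower, upper):
    # Closed form for this separable quadratic: clip each coordinate to its box.
    return jnp.clip(target, lower, upper)

def projected_gradient_descent(fun, x0, lower, upper, lr=0.1, steps=200, **kwargs):
    # Plain projected gradient descent, used only as an independent cross-check.
    grad_fn = jax.grad(fun)  # differentiates w.r.t. the first positional argument
    x = x0
    for _ in range(steps):
        x = jnp.clip(x - lr * grad_fn(x, **kwargs), lower, upper)
    return x

lower, upper = jnp.zeros(3), jnp.full(3, 5.0)
target = jnp.array([-1.0, 2.0, 7.0])  # first and last coordinates lie outside the box
x_pgd = projected_gradient_descent(loss, jnp.zeros(3), lower, upper, target=target)
print(reference_solution(target, lower, upper))  # [0. 2. 5.]
print(jnp.allclose(x_pgd, reference_solution(target, lower, upper), atol=1e-5))  # True
```

The same comparison against `reference_solution(target, lower, upper)` can be applied to the `params` returned by `minimize`.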

Solvers

| Solver | Description |
|---|---|
| `ADABK0` | Active-set + AdaBelief; releases 1 constraint per step. Best for noisy landscapes. |
| `ADABK{N}` | Active-set + AdaBelief; releases up to N×10% of constraints per step. |
| `optax_lbfgs` | L-BFGS with zoom line search. Best for smooth landscapes. |
| `adam`, `adabelief`, `adaw`, `sgd` | First-order optax solvers with optional projection. |
| `optimistix_bfgs/lbfgs/ncg_*` | Optimistix solvers. |
| `scipy_tnc`, `scipy_cobyqa` | SciPy solvers via jaxopt (require `jax-cadre[scipy]`). |
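The table gives the ADABK release budget only in outline. The sketch below shows one plausible reading, under explicit assumptions that may not match CADRE: a constraint is a release candidate when it is active but the negative gradient now points back into the box; the N×10% budget of `ADABK{N}` is taken relative to the number of candidates (at least one); and candidates are released in order of gradient magnitude. The helpers `n_releases` and `release` are hypothetical and run eagerly (not under `jit`).

```python
import jax.numpy as jnp

def n_releases(n_candidates, N):
    """Hypothetical per-step release budget.

    N == 0 (ADABK0) frees one constraint; ADABK{N} frees up to N*10% of the
    candidates, but at least one while any candidate remains.
    """
    if n_candidates == 0:
        return 0
    if N == 0:
        return 1
    return min(n_candidates, max(1, (N * n_candidates) // 10))

def release(active, x, grad, lower, upper, N, tol=1e-8):
    """Return a copy of the boolean `active` mask with the strongest candidates freed."""
    at_lower = x <= lower + tol
    at_upper = x >= upper - tol
    # Releasable: from the bound it sits on, -grad now points back into the box.
    inward = (at_lower & (grad < 0)) | (at_upper & (grad > 0))
    candidates = active & inward
    k = n_releases(int(candidates.sum()), N)
    # Rank candidates by gradient magnitude; non-candidates sink to the bottom.
    score = jnp.where(candidates, jnp.abs(grad), -jnp.inf)
    order = jnp.argsort(-score)
    freed = jnp.zeros_like(active).at[order[:k]].set(True)
    return active & ~freed

x = jnp.array([0.0, 0.0, 5.0, 5.0, 2.0])
grad = jnp.array([-3.0, -1.0, 2.0, -4.0, 0.5])
active = jnp.array([True, True, True, True, False])
lower, upper = jnp.zeros(5), jnp.full(5, 5.0)

print(release(active, x, grad, lower, upper, N=0).tolist())   # [False, True, True, True, False]
print(release(active, x, grad, lower, upper, N=10).tolist())  # [False, False, False, True, False]
```

Index 3 stays active in both cases: it sits on the upper bound and its negative gradient still points outward.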

Full solver documentation and ADABK internals: docs
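The ADABK solvers pair the active set with AdaBelief. Below is a minimal pure-JAX sketch of one projected AdaBelief step, written from the published AdaBelief update (Zhuang et al., 2020) rather than from CADRE's code. The names `AdaBeliefState`, `adabelief_init` and `projected_adabelief_step` and the default hyperparameters are assumptions. Clamped variables (the `active` mask) are held fixed, free variables take the AdaBelief step, and the result is clipped to the box.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

class AdaBeliefState(NamedTuple):
    m: jax.Array  # first moment (EMA of gradients)
    s: jax.Array  # "belief": EMA of (g - m)^2
    t: jax.Array  # step counter

def adabelief_init(x):
    return AdaBeliefState(jnp.zeros_like(x), jnp.zeros_like(x), jnp.zeros((), jnp.int32))

def projected_adabelief_step(x, grad, state, active, lower, upper,
                             lr=1e-2, b1=0.9, b2=0.999, eps=1e-16):
    t = state.t + 1
    m = b1 * state.m + (1 - b1) * grad
    s = b2 * state.s + (1 - b2) * (grad - m) ** 2 + eps
    m_hat = m / (1 - b1 ** t)  # bias corrections
    s_hat = s / (1 - b2 ** t)
    update = -lr * m_hat / (jnp.sqrt(s_hat) + eps)
    # Active (clamped) variables are frozen; free ones move, then get projected.
    x_new = jnp.clip(jnp.where(active, x, x + update), lower, upper)
    return x_new, AdaBeliefState(m, s, t)

x = jnp.array([0.0, 2.0, 5.0])
g = jnp.array([1.0, 1.0, -1.0])
active = jnp.array([True, False, True])
lower, upper = jnp.zeros(3), jnp.full(3, 5.0)
x_new, state = projected_adabelief_step(x, g, adabelief_init(x), active, lower, upper, lr=0.09)
print(x_new)  # close to [0.0, 1.9, 5.0]
```

On the first step the bias-corrected belief equals `(b1 * g) ** 2`, so each free coordinate moves by `lr / b1` against the sign of its gradient (ignoring `eps`); with `lr=0.09` that is a step of about 0.1.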

Advanced usage

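```python
import jax
import jax.numpy as jnp
import optimistix as optx
from cadre import get_solver

solver, _ = get_solver("ADABK0", rtol=1e-6, atol=1e-6)

init_params = jnp.zeros(3)  # reuses `loss` and `target` from the quick start
f_struct = jax.eval_shape(lambda: loss(init_params, target))  # shape/dtype of the loss value

# Same argument order as optimistix's minimiser init(fn, y, args, options, f_struct, aux_struct, tags)
state = solver.init(loss, init_params, target, {}, f_struct, None, frozenset())
# ... step manually
```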

License

MIT. See LICENSE.
