Skip to content

Commit fd2d047

Browse files
authored
Merge pull request #28 from CMBSciPol/docs
Docs
2 parents bc9335e + 2574459 commit fd2d047

3 files changed

Lines changed: 70 additions & 2 deletions

File tree

docs/minimization.md

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,50 @@ This prevents numerical under/overflow across the extreme dynamic range between
9595

9696
The step size α is capped at the distance to the nearest bound (α_max), then a line search finds the optimal α in [0, α_max]. If a parameter hits a bound, it becomes an active constraint.
9797

98+
## Termination Control (ADABK solvers)
99+
100+
Four parameters control when ADABK solvers decide to stop. Pass them via the `options` dict of `minimize()`:
101+
102+
```python
103+
final_params, state = minimize(
104+
fn=my_loss_fn,
105+
init_params=params,
106+
solver_name="ADABK0",
107+
lower_bound=lower,
108+
upper_bound=upper,
109+
options={
110+
"cooldown": 50, # default: 20
111+
"min_steps": 200, # default: 10
112+
"verbose_print": True, # default: False
113+
"max_linesearch_steps": 100, # default: 50
114+
},
115+
)
116+
```
117+
118+
| Parameter | Default | Description |
119+
|-----------|---------|-------------|
120+
| `cooldown` | `20` | Steps to suppress termination after a constraint is released. Prevents premature convergence caused by a transient function spike when a bound constraint opens. |
121+
| `min_steps` | `10` | Minimum iterations before termination is ever considered. Useful when the initial gradient is near zero but the landscape is not yet explored. |
122+
| `verbose_print` | `False` | Print per-step diagnostics: current `f`, `f_diff`, `best_f`, cooldown status, and termination decision. Uses `jax.debug.print` so it is JIT-compatible. |
123+
| `max_linesearch_steps` | `50` | Maximum bounded line-search steps per iteration. |
124+
125+
### How termination is decided
126+
127+
Termination requires **all** of the following to hold simultaneously:
128+
129+
1. `f_diff = |f_current − f_prev| < atol + rtol × max(1, |best_f|)` — spike-immune f-change check
130+
2. Cauchy-convergence in y-space (base Optimistix check)
131+
3. Step count ≥ `min_steps`
132+
4. Not inside the cooldown window after the last constraint release
133+
134+
### CLI equivalents
135+
136+
When using `kmeans-model` or `ptep-model`, the same parameters are available as flags:
137+
138+
```bash
139+
kmeans-model ... --cooldown 50 --min-steps 200 --verbose
140+
```
141+
98142
## Conditioning
99143

100144
Conditioning (preconditioning) transforms the optimization problem to improve convergence. It applies two transformations before optimization:

src/furax_cs/optim/minimize.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,8 +220,20 @@ def minimize(
220220
Box constraints.
221221
precondition : bool
222222
Whether to apply parameter transformation and output scaling.
223-
solver_options : dict, optional
224-
Additional arguments passed to the solver factory (get_solver).
223+
options : dict, optional
224+
Extra arguments passed to the solver factory (get_solver).
225+
For active-set solvers (``ADABK{N}`` family) the recognised keys are:
226+
227+
* ``cooldown`` (int, default 20) — steps to suppress termination
228+
after a constraint release.
229+
* ``min_steps`` (int, default 10) — minimum iterations before
230+
termination is considered.
231+
* ``verbose_print`` (bool, default False) — print per-step debug
232+
info via ``jax.debug.print`` (JIT-compatible).
233+
* ``max_linesearch_steps`` (int, default 50) — maximum line-search
234+
steps per iteration (active-set and ``optax_lbfgs`` solvers).
235+
* ``linesearch`` (str) — linesearch variant for ``optax_lbfgs``
236+
(``"zoom"`` or ``"backtracking"``).
225237
**fn_kwargs
226238
Additional arguments passed to fn.
227239

src/furax_cs/optim/solvers.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,18 @@ def get_solver(
371371
Lower bounds for box projection (optax solvers only).
372372
upper : PyTree, optional
373373
Upper bounds for box projection (optax solvers only).
374+
verbose_print : bool
375+
If True, print per-step termination diagnostics for active-set
376+
solvers via ``jax.debug.print`` (JIT-compatible).
377+
min_steps : int
378+
Minimum iterations before termination is considered
379+
(active-set solvers only).
380+
cooldown : int
381+
Steps to suppress termination after a constraint release
382+
(active-set solvers only).
383+
max_linesearch_steps : int
384+
Maximum line-search steps per iteration (active-set and
385+
``optax_lbfgs`` solvers).
374386
375387
Returns
376388
-------

0 commit comments

Comments
 (0)