Skip to content

Commit f0475eb

Browse files
authored
Merge pull request #26 from CMBSciPol/docs
enhance docs
2 parents cc008ed + 56991ce commit f0475eb

12 files changed

Lines changed: 115 additions & 31 deletions

File tree

docs/minimization.md

Lines changed: 104 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,111 @@
77
The `solver_name` argument in the `minimize` function accepts the following:
88

99
### Recommended
10-
* **`active_set`**: **Best for noisy maps.** Uses a projected gradient method with active set constraints. Robust against noise but might be slower on very clean data.
11-
* **`optax_lbfgs`**: **Best for noiseless runs.** L-BFGS with zoom linesearch (Strong Wolfe conditions). Very fast and accurate for smooth, noise-free landscapes.
10+
* **`ADABK0`**: **Best for noisy maps.** Active-set method with AdaBelief direction and Top-K constraint release (K=0, i.e. one constraint released per iteration). Very robust in low-SNR regions. See [How ADABK Works](#how-adabk-works) below.
11+
* **`optax_lbfgs`**: **Best for noiseless runs (systematics).** L-BFGS with zoom linesearch (Strong Wolfe conditions). Very fast and accurate for smooth, noise-free landscapes.
1212

1313
### Other Options
14-
* `optax_lbfgs`: L-BFGS.
15-
* `adam`: Simple Adam optimizer (good for stochastic settings).
16-
* `scipy_tnc`: Wrapper for SciPy's Truncated Newton (TNC).
17-
* `optimistix_bfgs`: Standard BFGS from Optimistix.
18-
* `optimistix_lbfgs`: Standard L-BFGS from Optimistix.
19-
* `optimistix_ncg_*`: Nonlinear Conjugate Gradient variants (`pr`, `hs`, `fr`, `dy`).
14+
15+
**Active set variants** (self-conditioned):
16+
17+
* `ADABK{N}` — AdaBelief + Top-K active set. `N * 0.1` = fraction of constraints released per step. `ADABK0` releases 1 constraint/step (most stable), `ADABK5` releases up to 50%. (see [How ADABK Works](#how-adabk-works) and paper for more info)
18+
* `active_set` — Active set with Adam direction.
19+
* `active_set_sgd` — Active set with SGD direction.
20+
* `active_set_adabelief` — Active set with AdaBelief direction.
21+
* `active_set_adaw` — Active set with AdamW direction.
22+
23+
**Optax L-BFGS:**
24+
25+
* `optax_lbfgs` — L-BFGS with zoom linesearch (default) or backtracking.
26+
27+
**Optax first-order:**
28+
29+
* `adam` — Adam optimizer.
30+
* `sgd` — SGD with backtracking linesearch.
31+
* `adabelief` — AdaBelief optimizer.
32+
* `adaw` / `adamw` — AdamW optimizer.
33+
34+
**Optimistix:**
35+
36+
* `optimistix_bfgs` — Full BFGS.
37+
* `optimistix_lbfgs` — Limited-memory BFGS.
38+
* `optimistix_ncg_pr` — Nonlinear Conjugate Gradient (Polak-Ribière).
39+
* `optimistix_ncg_hs` — Nonlinear Conjugate Gradient (Hestenes-Stiefel).
40+
* `optimistix_ncg_fr` — Nonlinear Conjugate Gradient (Fletcher-Reeves).
41+
* `optimistix_ncg_dy` — Nonlinear Conjugate Gradient (Dai-Yuan).
42+
43+
**SciPy** (self-conditioned):
44+
45+
* `scipy_tnc` — Truncated Newton (TNC).
46+
* `scipy_cobyqa` — COBYQA (derivative-free constrained optimizer).
47+
48+
49+
## How ADABK Works
50+
51+
ADABK (Adaptive AdaBelief with Top-K Active Set, also called **AdaTopK** in the paper) is a JAX-native optimizer that combines the TNC active-set constraint strategy with the AdaBelief adaptive gradient method.
52+
53+
### Internal parameter space
54+
55+
Physical parameters **x** (bounded by **l**, **u**) are mapped to a normalized [0, 1] representation via an affine transform:
56+
57+
**y** = (**x****l**) / (**u****l**)
58+
59+
This normalizes the optimization landscape and ensures consistent step sizes across parameters with different physical scales.
60+
61+
### Active set and pivot vector
62+
63+
Each parameter has a pivot value p_i:
64+
65+
* p_i = −1: parameter is at the lower bound (active constraint)
66+
* p_i = +1: parameter is at the upper bound (active constraint)
67+
* p_i = 0: parameter is free
68+
69+
Only free parameters (p_i = 0) are optimized at each iteration.
70+
71+
### Top-K constraint release
72+
73+
At each iteration, a release score is computed for every active constraint:
74+
75+
score_i = p_i × (−g_i)
76+
77+
A positive score means the negative gradient points into the feasible region — releasing this constraint could decrease the objective. The Top-K fraction K controls how many constraints are released per iteration:
78+
79+
* **K = 0** (`ADABK0`): releases 1 constraint at a time. Most stable, consistently reaches the lowest objective values.
80+
* **K = N** (`ADABK{N}`): releases up to `N × 0.1` fraction of active constraints.
81+
82+
### Projected gradient and AdaBelief direction
83+
84+
Gradients for active constraints are zeroed out: **g_proj** = **g** ⊙ (p = 0). The projected gradient is then fed to AdaBelief, which adapts step sizes based on gradient variance. This makes it better suited to noisy gradient landscapes (low-SNR regions) than classical quasi-Newton methods (L-BFGS, TNC) which tend to reset their curvature history when gradients are unreliable.
85+
86+
### Dynamic state rescaling
87+
88+
When the gradient norm falls outside [10⁻¹⁵, 10¹⁵], the cost function and AdaBelief moment estimates are rescaled:
89+
90+
**m** ← f_scale · **m** , **v** ← f_scale² · **v**
91+
92+
This prevents numerical under/overflow across the extreme dynamic range between the bright Galactic plane and faint high-latitude sky, without resetting the optimizer's momentum.
93+
94+
### Bounded line search
95+
96+
The step size α is capped at the distance to the nearest bound (α_max), then a line search finds the optimal α in [0, α_max]. If a parameter hits a bound, it becomes an active constraint.
97+
98+
## Conditioning
99+
100+
Conditioning (preconditioning) transforms the optimization problem to improve convergence. It applies two transformations before optimization:
101+
102+
1. **Parameter scaling**: min-max normalization to [0, 1] based on bounds.
103+
2. **Gradient scaling**: the objective is scaled by 1/‖∇f‖ at initialization (like SciPy TNC's `fscale`), so the initial gradient norm is ≈ 1.
104+
105+
### Self-conditioned solvers
106+
107+
These solvers handle conditioning internally and ignore the `precondition` flag:
108+
109+
* **Active set variants** (`active_set`, `active_set_sgd`, `active_set_adabelief`, `active_set_adaw`, `ADABK{N}`) — use internal affine transform + dynamic state rescaling.
110+
* **SciPy solvers** (`scipy_tnc`, `scipy_cobyqa`) — SciPy handles bounds and scaling internally.
111+
112+
### Externally conditioned solvers
113+
114+
All other solvers (`optax_lbfgs`, `adam`, `optimistix_*`, etc.) benefit from external conditioning when dealing with poorly scaled problems. Pass `precondition=True` (or a custom scaling function) to `minimize`.
20115

21116
## Minimizing Programmatically
22117

@@ -34,7 +129,7 @@ final_params, state = minimize(
34129
)
35130
```
36131

37-
### Advanced: Steping interactively with Solvers
132+
### Advanced: Stepping Interactively with Solvers
38133

39134
Since most solvers (except SciPy) are JAX-compatible, you can step through the optimization process manually. This is useful for custom logging or adaptive strategies.
40135

src/furax_cs/r_analysis/binning.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
_ALL_PARAM_NAMES = ["beta_dust", "temp_dust", "beta_pl"]
2929

3030

31-
3231
def _squeeze_patches(arr: np.ndarray) -> np.ndarray:
3332
"""Squeeze n_gridpts=1 leading dim from patch arrays if present."""
3433
if arr.ndim > 1:

src/furax_cs/r_analysis/caching.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,8 @@
88
from furax import HomothetyOperator
99
from furax.obs import negative_log_likelihood, sky_signal
1010
from furax.obs.stokes import Stokes
11-
from jaxtyping import Array, Float, Int
12-
1311
from furax_cs.optim import minimize
12+
from jaxtyping import Array, Float, Int
1413

1514

1615
def compute_w(

src/furax_cs/r_analysis/main.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import datasets
44
import matplotlib.pyplot as plt
55
import scienceplots # noqa: F401
6-
76
from furax_cs.data.instruments import get_instrument
87

98
from ..logging_utils import (

src/furax_cs/scripts/bench_bcp.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@
5959
from furax import HomothetyOperator
6060
from furax.obs import negative_log_likelihood
6161
from furax.obs.landscapes import Stokes
62-
6362
from furax_cs import load_from_cache, minimize, save_to_cache
6463
from furax_cs.logging_utils import info
6564

src/furax_cs/scripts/bench_clusters.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030

3131
from furax.obs import negative_log_likelihood, spectral_cmb_variance
3232
from furax.obs.stokes import Stokes
33-
3433
from furax_cs import (
3534
generate_noise_operator,
3635
kmeans_clusters,

src/furax_cs/scripts/compute_mr.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import argparse
22

33
import jax.numpy as jnp
4-
54
from furax_cs import get_mask
65
from furax_cs.multires_clusters import multires_clusters
76

src/furax_cs/scripts/distributed_gridding.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,6 @@
7070
from furax.obs.landscapes import FrequencyLandscape
7171
from furax.obs.operators import NoiseDiagonalOperator
7272
from furax.obs.stokes import Stokes
73-
from jax_grid_search import DistributedGridSearch
74-
from jax_healpy.clustering import (
75-
find_kmeans_clusters,
76-
get_cutout_from_mask,
77-
normalize_by_first_occurrence,
78-
)
79-
8073
from furax_cs import (
8174
MASK_CHOICES,
8275
dump_default_search_space,
@@ -90,6 +83,12 @@
9083
sanitize_mask_name,
9184
)
9285
from furax_cs.logging_utils import info, success
86+
from jax_grid_search import DistributedGridSearch
87+
from jax_healpy.clustering import (
88+
find_kmeans_clusters,
89+
get_cutout_from_mask,
90+
normalize_by_first_occurrence,
91+
)
9392

9493
jax.config.update("jax_enable_x64", True)
9594

src/furax_cs/scripts/fgbuster_model.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,6 @@
6565
)
6666

6767
from furax.obs.stokes import Stokes
68-
from jax_healpy.clustering import get_cutout_from_mask, get_fullmap_from_cutout
69-
7068
from furax_cs import (
7169
MASK_CHOICES,
7270
generate_noise_operator,
@@ -79,6 +77,7 @@
7977
sanitize_mask_name,
8078
)
8179
from furax_cs.logging_utils import info, success
80+
from jax_healpy.clustering import get_cutout_from_mask, get_fullmap_from_cutout
8281

8382
jax.config.update("jax_enable_x64", True)
8483

src/furax_cs/scripts/kmeans_model.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,6 @@
7272
sky_signal,
7373
)
7474
from furax.obs.stokes import Stokes
75-
from jax_healpy.clustering import get_cutout_from_mask, normalize_by_first_occurrence
76-
7775
from furax_cs import (
7876
MASK_CHOICES,
7977
generate_noise_operator,
@@ -87,6 +85,7 @@
8785
sanitize_mask_name,
8886
)
8987
from furax_cs.logging_utils import info, success
88+
from jax_healpy.clustering import get_cutout_from_mask, normalize_by_first_occurrence
9089

9190
jax.config.update("jax_enable_x64", True)
9291

0 commit comments

Comments
 (0)