Skip to content

Commit cd73553

Browse files
authored
Merge pull request #29 from CMBSciPol/update-arxiv
Update arxiv
2 parents fd2d047 + d1e550e commit cd73553

22 files changed

Lines changed: 357 additions & 587 deletions

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
[![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://github.com/pre-commit/pre-commit)
66
[![Documentation](https://img.shields.io/badge/docs-readthedocs-blue?logo=readthedocs)](https://furax-cs.readthedocs.io/en/latest/)
77
[![Results Explorer](https://img.shields.io/badge/%F0%9F%A4%97%20Results-Explorer-yellow?)](https://askabalan-furax-cs-results.hf.space/)
8-
[![arXiv](https://img.shields.io/badge/arXiv-XXXX.XXXXX-b31b1b?logo=arxiv)](https://arxiv.org/abs/XXXX.XXXXX)
8+
[![arXiv](https://img.shields.io/badge/arXiv-2604.08463-b31b1b?logo=arxiv)](https://arxiv.org/abs/2604.08463)
99

1010
**FURAX-CS** (FURAX Component Separation) is a Python package designed to benchmark and implement advanced component separation techniques for Cosmic Microwave Background (CMB) analysis. It leverages **JAX** for high-performance computing on GPUs and implements novel adaptive clustering methods.
1111

docs/api/index.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,15 @@
1818

1919
## Optimization
2020

21+
Optimization is provided by the [CADRE](https://github.com/CMBSciPol/CADRE) package.
22+
See the [minimization docs](../minimization.md) for full solver reference.
23+
2124
```{eval-rst}
22-
.. autofunction:: furax_cs.optim.minimize.minimize
25+
.. autofunction:: cadre.minimize.minimize
2326
```
2427

2528
```{eval-rst}
26-
.. autofunction:: furax_cs.optim.solvers.get_solver
29+
.. autofunction:: cadre.solvers.get_solver
2730
```
2831

2932
## Binning

docs/minimization.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Minimization Solvers
22

3-
`furax-cs` provides a unified interface for various minimization solvers, including wrappers for **Optax**, **Optimistix**, and **SciPy**.
3+
`furax-cs` exposes a unified interface for various minimization solvers via the
4+
[**CADRE**](https://github.com/CMBSciPol/CADRE) package (Constraint-Aware Descent Routine Executor),
5+
which provides the underlying implementations for **Optax**, **Optimistix**, and **SciPy** solvers.
46

57
## Available Solvers
68

@@ -161,7 +163,7 @@ All other solvers (`optax_lbfgs`, `adam`, `optimistix_*`, etc.) benefit from ext
161163

162164
### Single Run
163165
```python
164-
from furax_cs import minimize
166+
from cadre import minimize
165167

166168
final_params, state = minimize(
167169
fn=my_loss_fn,
@@ -183,7 +185,7 @@ Here is an example of running the same optimization problem with multiple solver
183185
import jax
184186
import jax.numpy as jnp
185187
import optimistix as optx
186-
from furax_cs.optim.solvers import get_solver
188+
from cadre.solvers import get_solver
187189
from functools import partial
188190

189191
# Define a simple quadratic loss function

docs/quick_start.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ from furax.obs import (
3636
from furax.obs.stokes import Stokes
3737
from jax_healpy.clustering import get_cutout_from_mask, get_fullmap_from_cutout
3838

39-
from furax_cs import generate_noise_operator, kmeans_clusters, minimize
39+
from cadre import minimize
40+
from furax_cs import generate_noise_operator, kmeans_clusters
4041
from furax_cs.data import (
4142
get_instrument,
4243
get_mask,

notebooks/06_Optimizer_Benchmark.ipynb

Lines changed: 2 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"source": [
77
"# Optimizer Benchmark\n",
88
"\n",
9-
"This notebook compares different optimization solvers from `furax_cs.optim` on the Rosenbrock function.\n",
9+
"This notebook compares different optimization solvers from `cadre` on the Rosenbrock function.\n",
1010
"\n",
1111
"## Test Matrix\n",
1212
"\n",
@@ -45,7 +45,7 @@
4545
"import jax.numpy as jnp\n",
4646
"import matplotlib.pyplot as plt\n",
4747
"import pandas as pd\n",
48-
"from furax_cs.optim import minimize\n",
48+
"from cadre import minimize\n",
4949
"\n",
5050
"# Enable 64-bit precision\n",
5151
"jax.config.update(\"jax_enable_x64\", True)"
@@ -196,28 +196,6 @@
196196
"plt.show()"
197197
]
198198
},
199-
{
200-
"cell_type": "code",
201-
"execution_count": 13,
202-
"metadata": {},
203-
"outputs": [
204-
{
205-
"data": {
206-
"text/plain": [
207-
"Array([ 35.178079 , 839.05471775, 946.67697787, 1124.11267378,\n",
208-
" 1344.27794419, 1591.60078999, 2023.58745027, 2241.40393647,\n",
209-
" 2429.79396429, 2553.55842245], dtype=float64)"
210-
]
211-
},
212-
"execution_count": 13,
213-
"metadata": {},
214-
"output_type": "execute_result"
215-
}
216-
],
217-
"source": [
218-
"eigvals"
219-
]
220-
},
221199
{
222200
"cell_type": "code",
223201
"execution_count": 36,
@@ -255,32 +233,6 @@
255233
"plt.show()"
256234
]
257235
},
258-
{
259-
"cell_type": "code",
260-
"execution_count": 32,
261-
"metadata": {},
262-
"outputs": [
263-
{
264-
"data": {
265-
"text/plain": [
266-
"Array([[ 2., 0., 0., ..., 0., 0., 0.],\n",
267-
" [ 0., 202., 0., ..., 0., 0., 0.],\n",
268-
" [ 0., 0., 202., ..., 0., 0., 0.],\n",
269-
" ...,\n",
270-
" [ 0., 0., 0., ..., 202., 0., 0.],\n",
271-
" [ 0., 0., 0., ..., 0., 202., 0.],\n",
272-
" [ 0., 0., 0., ..., 0., 0., 200.]], dtype=float64)"
273-
]
274-
},
275-
"execution_count": 32,
276-
"metadata": {},
277-
"output_type": "execute_result"
278-
}
279-
],
280-
"source": [
281-
"H"
282-
]
283-
},
284236
{
285237
"cell_type": "code",
286238
"execution_count": 13,

notebooks/80-plot_runs.ipynb

Lines changed: 19 additions & 24 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@ dependencies = [
1717
"pysm3",
1818
"astropy",
1919
"optimistix",
20-
"cobyqa",
21-
"jaxopt",
20+
"jax-cadre",
2221
"scienceplots"
2322
]
2423
description = "GPU Powered CMB Parametric Component Seperation using Furax and JAX"
@@ -30,7 +29,7 @@ version = "0.1.1"
3029

3130
[project.optional-dependencies]
3231
all = [
33-
"furax-cs[plotting,dev,benchmarks,io,docs]"
32+
"furax-cs[plotting,dev,benchmarks,io,docs,scipy]"
3433
]
3534
benchmarks = ["jax-hpc-profiler>=v0.3.2"]
3635
dev = [
@@ -53,6 +52,9 @@ plotting = [
5352
"seaborn",
5453
"scienceplots"
5554
]
55+
scipy = [
56+
"jax-cadre[scipy]"
57+
]
5658

5759
[project.scripts]
5860
bench-bcp = "furax_cs.scripts.bench_bcp:main"

slurm/runners/11-binning.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ RESULTS_DIR="RESULTS/KMEANS_C1D1S1"
6969
OUTPUT_DIR="RESULTS/BINNING_C1D1S1"
7070

7171
# 3 optimal runs (same as section_45.sh)
72-
RUNS=(-r 'BD7500_TD500_BS500_GAL020' 'BD10000_TD6000_BS300_GAL040' 'BD10000_TD7000_BS500_GAL060')
72+
RUNS=(-r 'BD7000_TD500_BS500_GAL020' 'BD10000_TD500_BS150_GAL040' 'BD10000_TD2500_BS300_GAL060')
73+
7374

7475
# =============================================================================
7576
# Phase 1 & 2: For each bin config, bin then run kmeans

src/furax_cs/__init__.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,18 @@
22

33
from importlib import metadata
44

5+
from cadre import (
6+
SOLVER_NAMES,
7+
ScipyMinimizeState,
8+
apply_projection,
9+
condition,
10+
get_solver,
11+
lbfgs_backtrack,
12+
lbfgs_zoom,
13+
minimize,
14+
scipy_minimize,
15+
)
16+
517
from . import r_analysis
618
from .binning import bin_parameter_map
719
from .data import (
@@ -28,17 +40,6 @@
2840
from .kmeans_clusters import kmeans_clusters
2941
from .multires_clusters import multires_clusters
3042
from .noise import generate_noise_operator
31-
from .optim import (
32-
SOLVER_NAMES,
33-
ScipyMinimizeState,
34-
apply_projection,
35-
condition,
36-
get_solver,
37-
lbfgs_backtrack,
38-
lbfgs_zoom,
39-
minimize,
40-
scipy_minimize,
41-
)
4243

4344
__all__ = [
4445
"bin_parameter_map",

src/furax_cs/optim/__init__.py

Lines changed: 5 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,8 @@
1-
"""
2-
Optimization utilities for FURAX component separation.
3-
4-
This package provides:
5-
- L-BFGS solvers with zoom and backtracking linesearch
6-
- Box projection transformation for constrained optimization
7-
- Unified optimization interface supporting optax, optimistix, and scipy
8-
- Function conditioning (parameter transformation and gradient-based scaling)
1+
"""Compatibility shim — optimization has moved to the cadre package.
92
10-
Example usage:
11-
>>> from furax_cs.optim import minimize
12-
>>>
13-
>>> # Simple optimization
14-
>>> params, state = minimize(
15-
... fn=objective,
16-
... init_params={'beta': 1.5},
17-
... solver_name='optax_lbfgs',
18-
... )
3+
Import from cadre directly:
4+
from cadre import minimize, get_solver, condition, ...
195
"""
206

21-
from .minimize import ScipyMinimizeState, minimize, scipy_minimize
22-
from .solvers import (
23-
SOLVER_NAMES,
24-
apply_projection,
25-
get_solver,
26-
lbfgs_backtrack,
27-
lbfgs_zoom,
28-
)
29-
from .utils import condition
30-
31-
__all__ = [
32-
"SOLVER_NAMES",
33-
"ScipyMinimizeState",
34-
"apply_projection",
35-
"condition",
36-
"get_solver",
37-
"lbfgs_backtrack",
38-
"lbfgs_zoom",
39-
"scipy_minimize",
40-
"minimize",
41-
]
7+
from cadre import * # noqa: F401, F403
8+
from cadre import __all__ # noqa: F401

0 commit comments

Comments
 (0)