Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
[![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://github.com/pre-commit/pre-commit)
[![Documentation](https://img.shields.io/badge/docs-readthedocs-blue?logo=readthedocs)](https://furax-cs.readthedocs.io/en/latest/)
[![Results Explorer](https://img.shields.io/badge/%F0%9F%A4%97%20Results-Explorer-yellow?)](https://askabalan-furax-cs-results.hf.space/)
[![arXiv](https://img.shields.io/badge/arXiv-XXXX.XXXXX-b31b1b?logo=arxiv)](https://arxiv.org/abs/XXXX.XXXXX)
[![arXiv](https://img.shields.io/badge/arXiv-2604.08463-b31b1b?logo=arxiv)](https://arxiv.org/abs/2604.08463)

**FURAX-CS** (FURAX Component Separation) is a Python package designed to benchmark and implement advanced component separation techniques for Cosmic Microwave Background (CMB) analysis. It leverages **JAX** for high-performance computing on GPUs and implements novel adaptive clustering methods.

Expand Down
7 changes: 5 additions & 2 deletions docs/api/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@

## Optimization

Optimization is provided by the [CADRE](https://github.com/CMBSciPol/CADRE) package.
See the [minimization docs](../minimization.md) for full solver reference.

```{eval-rst}
.. autofunction:: furax_cs.optim.minimize.minimize
.. autofunction:: cadre.minimize.minimize
```

```{eval-rst}
.. autofunction:: furax_cs.optim.solvers.get_solver
.. autofunction:: cadre.solvers.get_solver
```

## Binning
Expand Down
8 changes: 5 additions & 3 deletions docs/minimization.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Minimization Solvers

`furax-cs` provides a unified interface for various minimization solvers, including wrappers for **Optax**, **Optimistix**, and **SciPy**.
`furax-cs` exposes a unified interface for various minimization solvers via the
[**CADRE**](https://github.com/CMBSciPol/CADRE) package (Constraint-Aware Descent Routine Executor),
which provides the underlying implementations for **Optax**, **Optimistix**, and **SciPy** solvers.

## Available Solvers

Expand Down Expand Up @@ -161,7 +163,7 @@ All other solvers (`optax_lbfgs`, `adam`, `optimistix_*`, etc.) benefit from ext

### Single Run
```python
from furax_cs import minimize
from cadre import minimize

final_params, state = minimize(
fn=my_loss_fn,
Expand All @@ -183,7 +185,7 @@ Here is an example of running the same optimization problem with multiple solver
import jax
import jax.numpy as jnp
import optimistix as optx
from furax_cs.optim.solvers import get_solver
from cadre.solvers import get_solver
from functools import partial

# Define a simple quadratic loss function
Expand Down
3 changes: 2 additions & 1 deletion docs/quick_start.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ from furax.obs import (
from furax.obs.stokes import Stokes
from jax_healpy.clustering import get_cutout_from_mask, get_fullmap_from_cutout

from furax_cs import generate_noise_operator, kmeans_clusters, minimize
from cadre import minimize
from furax_cs import generate_noise_operator, kmeans_clusters
from furax_cs.data import (
get_instrument,
get_mask,
Expand Down
52 changes: 2 additions & 50 deletions notebooks/06_Optimizer_Benchmark.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"source": [
"# Optimizer Benchmark\n",
"\n",
"This notebook compares different optimization solvers from `furax_cs.optim` on the Rosenbrock function.\n",
"This notebook compares different optimization solvers from `cadre` on the Rosenbrock function.\n",
"\n",
"## Test Matrix\n",
"\n",
Expand Down Expand Up @@ -45,7 +45,7 @@
"import jax.numpy as jnp\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"from furax_cs.optim import minimize\n",
"from cadre import minimize\n",
"\n",
"# Enable 64-bit precision\n",
"jax.config.update(\"jax_enable_x64\", True)"
Expand Down Expand Up @@ -196,28 +196,6 @@
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Array([ 35.178079 , 839.05471775, 946.67697787, 1124.11267378,\n",
" 1344.27794419, 1591.60078999, 2023.58745027, 2241.40393647,\n",
" 2429.79396429, 2553.55842245], dtype=float64)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"eigvals"
]
},
{
"cell_type": "code",
"execution_count": 36,
Expand Down Expand Up @@ -255,32 +233,6 @@
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Array([[ 2., 0., 0., ..., 0., 0., 0.],\n",
" [ 0., 202., 0., ..., 0., 0., 0.],\n",
" [ 0., 0., 202., ..., 0., 0., 0.],\n",
" ...,\n",
" [ 0., 0., 0., ..., 202., 0., 0.],\n",
" [ 0., 0., 0., ..., 0., 202., 0.],\n",
" [ 0., 0., 0., ..., 0., 0., 200.]], dtype=float64)"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"H"
]
},
{
"cell_type": "code",
"execution_count": 13,
Expand Down
43 changes: 19 additions & 24 deletions notebooks/80-plot_runs.ipynb

Large diffs are not rendered by default.

8 changes: 5 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ dependencies = [
"pysm3",
"astropy",
"optimistix",
"cobyqa",
"jaxopt",
"jax-cadre",
"scienceplots"
]
description = "GPU Powered CMB Parametric Component Seperation using Furax and JAX"
Expand All @@ -30,7 +29,7 @@ version = "0.1.1"

[project.optional-dependencies]
all = [
"furax-cs[plotting,dev,benchmarks,io,docs]"
"furax-cs[plotting,dev,benchmarks,io,docs,scipy]"
]
benchmarks = ["jax-hpc-profiler>=v0.3.2"]
dev = [
Expand All @@ -53,6 +52,9 @@ plotting = [
"seaborn",
"scienceplots"
]
scipy = [
"jax-cadre[scipy]"
]

[project.scripts]
bench-bcp = "furax_cs.scripts.bench_bcp:main"
Expand Down
3 changes: 2 additions & 1 deletion slurm/runners/11-binning.sh
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ RESULTS_DIR="RESULTS/KMEANS_C1D1S1"
OUTPUT_DIR="RESULTS/BINNING_C1D1S1"

# 3 optimal runs (same as section_45.sh)
RUNS=(-r 'BD7500_TD500_BS500_GAL020' 'BD10000_TD6000_BS300_GAL040' 'BD10000_TD7000_BS500_GAL060')
RUNS=(-r 'BD7000_TD500_BS500_GAL020' 'BD10000_TD500_BS150_GAL040' 'BD10000_TD2500_BS300_GAL060')


# =============================================================================
# Phase 1 & 2: For each bin config, bin then run kmeans
Expand Down
23 changes: 12 additions & 11 deletions src/furax_cs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,18 @@

from importlib import metadata

from cadre import (
SOLVER_NAMES,
ScipyMinimizeState,
apply_projection,
condition,
get_solver,
lbfgs_backtrack,
lbfgs_zoom,
minimize,
scipy_minimize,
)

from . import r_analysis
from .binning import bin_parameter_map
from .data import (
Expand All @@ -28,17 +40,6 @@
from .kmeans_clusters import kmeans_clusters
from .multires_clusters import multires_clusters
from .noise import generate_noise_operator
from .optim import (
SOLVER_NAMES,
ScipyMinimizeState,
apply_projection,
condition,
get_solver,
lbfgs_backtrack,
lbfgs_zoom,
minimize,
scipy_minimize,
)

__all__ = [
"bin_parameter_map",
Expand Down
43 changes: 5 additions & 38 deletions src/furax_cs/optim/__init__.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,8 @@
"""
Optimization utilities for FURAX component separation.

This package provides:
- L-BFGS solvers with zoom and backtracking linesearch
- Box projection transformation for constrained optimization
- Unified optimization interface supporting optax, optimistix, and scipy
- Function conditioning (parameter transformation and gradient-based scaling)
"""Compatibility shim — optimization has moved to the cadre package.

Example usage:
>>> from furax_cs.optim import minimize
>>>
>>> # Simple optimization
>>> params, state = minimize(
... fn=objective,
... init_params={'beta': 1.5},
... solver_name='optax_lbfgs',
... )
Import from cadre directly:
from cadre import minimize, get_solver, condition, ...
"""

from .minimize import ScipyMinimizeState, minimize, scipy_minimize
from .solvers import (
SOLVER_NAMES,
apply_projection,
get_solver,
lbfgs_backtrack,
lbfgs_zoom,
)
from .utils import condition

__all__ = [
"SOLVER_NAMES",
"ScipyMinimizeState",
"apply_projection",
"condition",
"get_solver",
"lbfgs_backtrack",
"lbfgs_zoom",
"scipy_minimize",
"minimize",
]
from cadre import * # noqa: F401, F403
from cadre import __all__ # noqa: F401
Loading
Loading