
Commit 7587007

Update progress bar (#19)
* refactor: remove continuous optimization features
* ci: update python versions to 3.10-3.13
* ci: add notebook execution workflow
* docs: add notebooks workflow badge
* docs: update README to reflect removal of continuous optimization
* Add notebook dependencies and update CI workflows to include main branch
* test: enforce strict MPI requirement in distributed notebook
* ci: install libopenmpi-dev for notebook workflow
* Update notebooks workflow and dependencies
1 parent b456e7a commit 7587007

13 files changed

Lines changed: 140 additions & 2615 deletions

.github/workflows/notebooks.yml

Lines changed: 46 additions & 0 deletions
```diff
@@ -0,0 +1,46 @@
+name: Notebooks
+
+on:
+  push:
+    branches: [ "main" ]
+  pull_request:
+    branches: [ "main" ]
+
+jobs:
+  test-notebooks:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.10", "3.11", "3.12", "3.13"]
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install system dependencies
+        run: |
+          sudo apt-get update
+          sudo apt-get install -y libopenmpi-dev
+          python -m pip install --upgrade pip
+          pip install setuptools cython mpi4py
+
+      - name: Install dependencies
+        run: |
+          pip install .[notebooks]
+          pip install nbconvert ipykernel
+
+      - name: Execute notebooks
+        run: |
+          mkdir -p executed_notebooks
+          find examples -name "*.ipynb" -print0 | xargs -0 -I {} jupyter nbconvert --to notebook --execute {} --output-dir executed_notebooks
+
+
+      - name: Upload executed notebooks
+        if: always()
+        uses: actions/upload-artifact@v4
+        with:
+          name: executed-notebooks-${{ matrix.python-version }}
+          path: executed_notebooks/*.ipynb
```
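The new workflow's "Execute notebooks" step drives `jupyter nbconvert` from the shell. The same run can be reproduced locally through nbconvert's Python API; a minimal sketch, assuming a 600-second timeout and the default `python3` kernel (neither is pinned by the workflow):

```python
from pathlib import Path

import nbformat
from nbconvert.preprocessors import ExecutePreprocessor

# Mirror the CI step: execute every notebook under examples/ and write
# the executed copies to executed_notebooks/.
out_dir = Path("executed_notebooks")
out_dir.mkdir(exist_ok=True)

for path in Path("examples").rglob("*.ipynb"):
    nb = nbformat.read(str(path), as_version=4)
    # Timeout and kernel name are assumptions, not taken from the workflow.
    ep = ExecutePreprocessor(timeout=600, kernel_name="python3")
    # Run cells with the notebook's own directory as the working path,
    # matching how `jupyter nbconvert --execute` resolves relative paths.
    ep.preprocess(nb, {"metadata": {"path": str(path.parent)}})
    nbformat.write(nb, str(out_dir / path.name))
```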

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
```diff
@@ -12,7 +12,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ['3.10', '3.11', '3.12']
+        python-version: ['3.10', '3.11', '3.12', '3.13']

     steps:
       - name: Checkout code
```

README.md

Lines changed: 10 additions & 223 deletions
````diff
@@ -1,34 +1,32 @@
-# Distributed Grid Search & Continuous Optimization using JAX
+# Distributed Grid Search using JAX
 
-[![Testing](https://github.com/CMBSciPol/jax-grid-search/actions/workflows/tests.yml/badge.svg)](https://github.com/CMBSciPol/jax-grid-search/actions/workflows/tests.yml)
+[![Tests](https://github.com/CMBSciPol/jax-grid-search/actions/workflows/tests.yml/badge.svg)](https://github.com/CMBSciPol/jax-grid-search/actions/workflows/tests.yml)
+[![Notebooks](https://img.shields.io/github/actions/workflow/status/CMBSciPol/jax-grid-search/notebooks.yml?logo=jupyter&label=notebooks)](https://github.com/CMBSciPol/jax-grid-search/actions/workflows/notebooks.yml)
 [![Code Formatting](https://github.com/CMBSciPol/jax-grid-search/actions/workflows/formatting.yml/badge.svg)](https://github.com/CMBSciPol/jax-grid-search/actions/workflows/formatting.yml)
 [![Upload Python Package](https://github.com/CMBSciPol/jax-grid-search/actions/workflows/python-publish.yml/badge.svg)](https://github.com/CMBSciPol/jax-grid-search/actions/workflows/python-publish.yml)
 [![PyPI version](https://badge.fury.io/py/jax-grid-search.svg)](https://badge.fury.io/py/jax-grid-search)
 [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
 <a href="https://doi.org/10.5281/zenodo.17674777"><img src="https://zenodo.org/badge/917061582.svg" alt="DOI"></a>
 
-
-
 ## About
 
 This package is designed to minimize likelihoods computed by [FURAX](https://github.com/CMBSciPol/furax), a JAX-based CMB analysis framework. It provides distributed grid search capabilities specifically optimized for:
 
 - **Spatial spectral index variability:** Efficiently explore parameter spaces for spatially-varying spectral indices in foreground models
 - **Foreground component optimization:** Test and compare different foreground component configurations to find the optimal model choice
-- **Likelihood model optimization:** Systematically search through discrete model configurations and continuously optimize their parameters
+- **Likelihood model optimization:** Systematically search through discrete model configurations
 
-The distributed grid search is built to handle the computational demands of CMB likelihood analysis, leveraging JAX's performance and enabling efficient parallel exploration of both discrete and continuous parameter spaces.
+The distributed grid search is built to handle the computational demands of CMB likelihood analysis, leveraging JAX's performance and enabling efficient parallel exploration of discrete parameter spaces.
+
+> **Note:** Continuous optimization features (formerly `optimize`) have been moved to [furax-cs](https://github.com/CMBSciPol/furax-cs). Please use `furax_cs.minimize` for gradient-based optimization.
 
 ---
 
-This repository provides two complementary optimization tools:
+This repository provides:
 
 1. **Distributed Grid Search for Discrete Optimization:**
    Explore a parameter space by evaluating a user-defined objective function on a grid of discrete values. The search runs in parallel across available processes, automatically handling batching, progress tracking, and result aggregation.
 
-2. **Continuous Optimization with Optax:**
-   Minimize continuous functions using gradient-based methods (such as LBFGS). This routine leverages Optax for iterative parameter updates and includes built-in progress monitoring.
-
 ---
 
 ## Getting Started
@@ -47,7 +45,7 @@ pip install jax_grid_search
 
 For comprehensive tutorials and hands-on examples, see the **[examples directory](./examples/)** which contains:
 
-- **5 interactive Jupyter notebooks** covering basic to advanced concepts
+- **Interactive Jupyter notebooks** covering basic to advanced concepts
 - **Distributed computing examples** with MPI setup
 - **Complete API demonstrations** with visualization
 
@@ -57,7 +55,7 @@ For comprehensive tutorials and hands-on examples, see the **[examples directory
 
 ## Usage Examples
 
-### 1. Distributed Grid Search (Discrete Optimization)
+### Distributed Grid Search (Discrete Optimization)
 
 Define your objective function and parameter grid, then run a distributed grid search. The objective function must return a dictionary with a `"value"` key.
 
````
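The surviving grid-search section states one API contract: the objective must return a dictionary with a `"value"` key. Below is a minimal sketch of a conforming objective and parameter grid; the function body and grid values are illustrative, and the call that launches the distributed search itself is shown in the examples notebooks.

```python
import jax.numpy as jnp

# Illustrative objective: the only requirement stated in the README is
# that it returns a dictionary with a "value" key.
def objective(x, y):
    return {"value": (x - 1.0) ** 2 + (y + 2.0) ** 2}

# A discrete grid of candidate values to search over (hypothetical values).
param_grid = {
    "x": jnp.linspace(-2.0, 2.0, 41),
    "y": jnp.linspace(-4.0, 0.0, 41),
}
```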
````diff
@@ -192,217 +190,6 @@ def image_filter_objective(kernel, bias_matrix):
 
 See [02-advanced-grid-search.ipynb](./examples/02-advanced-grid-search.ipynb) for complete examples with visualization.
 
-### 2. Continuous Optimization using Optax
-
-Use the continuous optimization routine to minimize a function with gradient-based methods (e.g., LBFGS). The example below minimizes a simple quadratic function.
-
-```python
-import jax.numpy as jnp
-import optax
-from jax_grid_search import optimize
-from jax_progress import TqdmProgressMeter
-
-# Define a continuous objective function (e.g., quadratic)
-def quadratic(x):
-    return jnp.sum((x - 3.0) ** 2)
-
-# Initial parameters and an optimizer (e.g., LBFGS)
-init_params = jnp.array([0.0])
-optimizer = optax.lbfgs()
-
-# Run continuous optimization with progress monitoring (optional)
-best_params, opt_state = optimize(
-    init_params,
-    quadratic,
-    opt=optimizer,
-    max_iter=50,
-    tol=1e-10,
-    progress=TqdmProgressMeter(total=50)
-)
-
-print("Optimized Parameters:", best_params)
-```
-
-#### Using Different Optimizers
-
-The library supports various Optax optimizers beyond LBFGS:
-
-```python
-import optax
-from jax_grid_search import optimize
-from jax_progress import TqdmProgressMeter
-
-def rosenbrock(x):
-    # Classic optimization test function
-    return 100 * (x[1] - x[0]**2)**2 + (1 - x[0])**2
-
-init_params = jnp.array([-1.0, 1.0])
-
-# Try different optimizers
-optimizers = {
-    "LBFGS": optax.lbfgs(),
-    "Adam": optax.adam(learning_rate=0.01),
-    "SGD": optax.sgd(learning_rate=0.1),
-    "RMSprop": optax.rmsprop(learning_rate=0.01)
-}
-
-for name, optimizer in optimizers.items():
-    result, state = optimize(
-        init_params, rosenbrock, optimizer,
-        max_iter=1000, tol=1e-8,
-        progress=TqdmProgressMeter(total=1000)
-    )
-    print(f"{name}: {result}, final value: {rosenbrock(result)}")
-```
-
-#### Parameter Bounds and Constraints
-
-Use box constraints to limit parameter values during optimization:
-
-```python
-from jax_progress import TqdmProgressMeter
-
-# Constrain parameters to [0, 10] range
-lower_bounds = jnp.array([0.0, 0.0])
-upper_bounds = jnp.array([10.0, 10.0])
-
-result, state = optimize(
-    init_params,
-    objective_function,
-    optax.adam(0.1),
-    max_iter=100,
-    tol=1e-6,
-    progress=TqdmProgressMeter(total=100),
-    lower_bound=lower_bounds,
-    upper_bound=upper_bounds
-)
-```
-
-#### Update History and Debugging
-
-Track optimization progress for analysis and debugging:
-
-```python
-from jax_progress import TqdmProgressMeter
-
-result, state = optimize(
-    init_params,
-    objective_function,
-    optax.lbfgs(),
-    max_iter=100,
-    tol=1e-8,
-    progress=TqdmProgressMeter(total=100),
-    log_updates=True  # Enable update history logging
-)
-
-# Plot optimization history
-import matplotlib.pyplot as plt
-if state.update_history is not None:
-    history = state.update_history
-    plt.figure(figsize=(12, 4))
-
-    plt.subplot(1, 2, 1)
-    plt.plot(history[:, 0])
-    plt.ylabel('Update Norm')
-    plt.xlabel('Iteration')
-    plt.yscale('log')
-
-    plt.subplot(1, 2, 2)
-    plt.plot(history[:, 1])
-    plt.ylabel('Objective Value')
-    plt.xlabel('Iteration')
-    plt.show()
-```
-
-#### Running multiple optimization tasks with vmap
-
-You can run multiple optimization tasks in parallel using `jax.vmap`. This is useful when optimizing multiple functions or parameters simultaneously.
-
-(This is very useful for simulating multiple noise realizations for example)
-
-The jax-progress library automatically handles vmap tracking internally.
-
-```python
-import jax
-import jax.numpy as jnp
-import optax
-from jax_grid_search import optimize
-from jax_progress import TqdmProgressMeter
-
-# Define multiple objective functions
-def objective_fn(x, normal):
-    return jnp.sum(((x - 3.0) ** 2) + normal)
-
-progress = TqdmProgressMeter(total=50)
-
-def solve_one(seed):
-    init_params = jnp.array([0.0])
-    normal = jax.random.normal(jax.random.PRNGKey(seed), init_params.shape)
-    optimizer = optax.lbfgs()
-    # Run continuous optimization with progress monitoring (optional)
-    best_params, opt_state = optimize(
-        init_params,
-        objective_fn,
-        opt=optimizer,
-        max_iter=50,
-        tol=1e-4,
-        progress=progress,
-        normal=normal
-    )
-
-    return best_params
-
-jax.vmap(solve_one)(jnp.arange(10))
-```
-
-### 3. Function Conditioning
-
-Improve optimization performance by transforming parameters to similar scales and normalizing outputs. This is essential for problems with parameters in different ranges (e.g., temperature vs spectral index) or large objective values (e.g., chi-square with many pixels).
-
-```python
-import jax.numpy as jnp
-import optax
-from jax_grid_search import condition, optimize
-from jax_progress import TqdmProgressMeter
-
-# Function with parameters in different scales and large output
-npix = 12 * 64**2  # HEALPix pixels
-
-def objective(params):
-    # Simulate chi-square scaled by number of pixels
-    return npix * ((params['temp'] - 20)**2 + (params['beta'] - 1.5)**2)
-
-# Apply conditioning: parameter scaling + output normalization
-lower = {'temp': 10.0, 'beta': 0.5}
-upper = {'temp': 40.0, 'beta': 3.0}
-
-conditioned_fn, to_opt, from_opt = condition(
-    objective,
-    lower=lower,
-    upper=upper,
-    factor=npix  # Normalize by problem size
-)
-
-# Transform parameters to [0,1] space, optimize, then transform back
-init_opt = to_opt({'temp': 15.0, 'beta': 1.0})
-
-result_opt, _ = optimize(
-    init_opt,
-    conditioned_fn,
-    optax.lbfgs(),
-    max_iter=100,
-    tol=1e-6,
-    progress=TqdmProgressMeter(total=100)
-)
-
-result = from_opt(result_opt)  # Back to physical space
-print(f"Optimized: temp={result['temp']:.2f}, beta={result['beta']:.3f}")
-```
-### 4. Optimizing Likelihood parameters and models
-
-You can use the continuous optimization to optimize the parameters of a model that is defined in a function.
-For performance purposes, you need to make sure that the discrete parameters that can control the likelihood model can be jitted (using `lax.cond` for example or other lax control flow functions).
-
 ## Citation
 
 ```
````
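For users migrating off the removed `optimize` routine, the new README note points to `furax_cs.minimize`. Below is a sketch of what the migration might look like, modeled on the deleted quadratic example; only the `furax_cs.minimize` name comes from the note, so the import path and call signature are assumptions to check against the furax-cs documentation.

```python
import jax.numpy as jnp
import optax
from furax_cs import minimize  # assumed import path; the note only names furax_cs.minimize

def quadratic(x):
    return jnp.sum((x - 3.0) ** 2)

# Keyword arguments mirror the removed jax_grid_search.optimize call;
# the actual furax-cs signature may differ.
best_params, state = minimize(
    jnp.array([0.0]),   # initial parameters
    quadratic,          # objective to minimize
    opt=optax.lbfgs(),  # gradient-based optimizer, as before
    max_iter=50,
    tol=1e-10,
)
```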

examples/01-basic-grid-search.ipynb

Lines changed: 14 additions & 7 deletions
Large diffs are not rendered by default.
