This package is designed to minimize likelihoods computed by [FURAX](https://github.com/CMBSciPol/furax), a JAX-based CMB analysis framework. It provides distributed grid search capabilities specifically optimized for:
- **Spatial spectral index variability:** Efficiently explore parameter spaces for spatially varying spectral indices in foreground models
- **Foreground component optimization:** Test and compare different foreground component configurations to find the optimal model choice
- **Likelihood model optimization:** Systematically search through discrete model configurations

The distributed grid search is built to handle the computational demands of CMB likelihood analysis, leveraging JAX's performance and enabling efficient parallel exploration of discrete parameter spaces.
> **Note:** Continuous optimization features (formerly `optimize`) have been moved to [furax-cs](https://github.com/CMBSciPol/furax-cs). Please use `furax_cs.minimize` for gradient-based optimization.
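For reference, here is a minimal migration sketch. It assumes `furax_cs.minimize` keeps the call signature of the former `optimize` routine (initial parameters, objective, an Optax optimizer, `max_iter`, `tol`); check the furax-cs documentation for the actual interface.

```python
import jax.numpy as jnp
import optax
from furax_cs import minimize  # replaces: from jax_grid_search import optimize

# Toy quadratic objective with its minimum at x = 3.
def quadratic(x):
    return jnp.sum((x - 3.0) ** 2)

# Assumption: minimize mirrors the old optimize(init_params, fn, opt=..., ...) API.
best_params, opt_state = minimize(
    jnp.array([0.0]),
    quadratic,
    opt=optax.lbfgs(),
    max_iter=50,
    tol=1e-10,
)
```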
---
This repository provides:
1. **Distributed Grid Search for Discrete Optimization:**

   Explore a parameter space by evaluating a user-defined objective function on a grid of discrete values. The search runs in parallel across available processes, automatically handling batching, progress tracking, and result aggregation.
---
## Getting Started
Install the package with pip:

```bash
pip install jax_grid_search
```
For comprehensive tutorials and hands-on examples, see the **[examples directory](./examples/)** which contains:
- **Interactive Jupyter notebooks** covering basic to advanced concepts
- **Distributed computing examples** with MPI setup
- **Complete API demonstrations** with visualization
### 1. Distributed Grid Search
Define your objective function and parameter grid, then run a distributed grid search. The objective function must return a dictionary with a `"value"` key.
See [02-advanced-grid-search.ipynb](./examples/02-advanced-grid-search.ipynb) for complete examples with visualization.
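The sketch below illustrates the shape of such a search. The `DistributedGridSearch` class and its `run` method are hypothetical stand-ins for the actual API (see the notebooks); only the `"value"` convention for the objective's return dictionary is taken from the documentation above.

```python
import jax.numpy as jnp
from jax_grid_search import DistributedGridSearch  # hypothetical entry point

# The objective must return a dictionary with a "value" key.
def objective(beta_dust, beta_sync):
    chi2 = (beta_dust - 1.54) ** 2 + (beta_sync + 3.0) ** 2  # toy likelihood
    return {"value": chi2}

# Discrete grid of candidate values for each parameter.
grid = {
    "beta_dust": jnp.linspace(1.0, 2.0, 51),
    "beta_sync": jnp.linspace(-3.5, -2.5, 51),
}

# Hypothetical call: evaluate the objective at every grid point, in parallel
# across the available processes, and collect the best result.
search = DistributedGridSearch(objective, grid)
best = search.run()
```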
### 2. Optimizing Likelihood Parameters and Models
You can optimize the parameters of a likelihood model defined in a function by combining the grid search over discrete model configurations with continuous optimization (now provided by [furax-cs](https://github.com/CMBSciPol/furax-cs)).
For performance reasons, make sure that the discrete parameters controlling the likelihood model can be jitted (for example, by using `lax.cond` or other `lax` control-flow functions), as sketched below.
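As an illustration (the likelihood below is a made-up toy, not the FURAX one), `lax.cond` lets a boolean model switch stay inside a single jitted function, so every discrete configuration on the grid reuses the same compiled code:

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def neg_log_likelihood(params, use_curvature):
    # Branch that includes a curvature term in the foreground model.
    def with_curvature(p):
        return jnp.sum((p["beta"] - 1.5) ** 2 + 0.1 * p["curv"] ** 2)

    # Branch without the curvature term.
    def without_curvature(p):
        return jnp.sum((p["beta"] - 1.5) ** 2)

    # lax.cond keeps the discrete choice traceable under jit.
    return lax.cond(use_curvature, with_curvature, without_curvature, params)

params = {"beta": jnp.array([1.0, 2.0]), "curv": jnp.array([0.3, -0.2])}
print(neg_log_likelihood(params, jnp.array(True)))
print(neg_log_likelihood(params, jnp.array(False)))
```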