Skip to content

Commit b8a6b95

Browse files
committed
Fix AIR convergence and align ADEV Bernoulli probs semantics
1 parent 8b540f2 commit b8a6b95

8 files changed

Lines changed: 449 additions & 518 deletions

File tree

README.md

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,32 @@ pixi run python -m examples.air.main train \
364364
- `--dataset synthetic` (default): samples from the AIR prior (no extra framework dependencies).
365365
- `--dataset multi-mnist --data-path /path/to/multi_mnist_uint8.npz`: load pre-generated multi-MNIST arrays.
366366

367+
**Reliable multi-MNIST acquisition**:
368+
```bash
369+
# Generates/downloads and writes examples/air/data/multi_mnist_uint8.npz
370+
pixi run air-fetch-data
371+
```
372+
373+
This uses `pyro.contrib.examples.multi_mnist` in the `perfbench-pyro` environment.
374+
375+
**GPU execution (recommended for AIR)**:
376+
```bash
377+
# Verify CUDA JAX is available
378+
pixi run -e cuda cuda-info
379+
380+
# Train one estimator on multi-MNIST (2048 examples)
381+
pixi run air-train-gpu
382+
383+
# Compare estimators on multi-MNIST (2048 examples)
384+
pixi run air-compare-gpu
385+
```
386+
387+
Notes:
388+
- AIR on large batches/examples is memory-heavy; use `--eval-batch-size` to bound eval-time memory.
389+
- On systems with constrained `/tmp`, set `TMPDIR=/dev/shm` and disable Triton GEMM autotuning:
390+
`XLA_FLAGS='--xla_gpu_enable_triton_gemm=false --xla_gpu_autotune_level=0'`.
391+
- Our current CLI defaults are tuned for smoke/repro runs. Paper-scale runs may require longer epochs and more data.
392+
367393
### Curve Fitting with Outlier Detection
368394

369395
**What it does**: Polynomial regression with robust outlier detection, demonstrating:

0 commit comments

Comments
 (0)