Commit 8b540f2

feat: add AIR case study and harden ADEV float0 handling
- port a GenJAX-only AIR example (core/main/figs) with docs, tasks, and tests
- align ADEV primitive JVP dispatch with JAX float0->Zero canonicalization
- add staged flip_mvd regression coverage for seed+vmap+jit bool-cast paths
- document AD zero-tangent canonicalization behavior in ADEV guides
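
As background for the float0 item in the commit message: JAX gives tangents of integer- and boolean-valued inputs the special dtype `float0`, which carries no information and behaves as a symbolic zero. The sketch below shows this with plain JAX only. It does not reproduce the ADEV dispatch code changed by this commit, and the function `scale` and its inputs are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


def scale(x, n):
    # n is an integer input; only x carries a meaningful tangent.
    return x * n.astype(jnp.float32)


x = jnp.float32(2.0)
n = jnp.int32(3)

# Forward mode: the tangent paired with an integer primal must have dtype float0.
n_tangent = np.zeros(n.shape, dtype=jax.dtypes.float0)
y, y_dot = jax.jvp(scale, (x, n), (jnp.float32(1.0), n_tangent))

# Reverse mode: the "gradient" with respect to the integer input is also float0.
n_grad = jax.grad(scale, argnums=1, allow_int=True)(x, n)

print(y, y_dot, n_grad.dtype)
```

Custom JVP rules that receive such inputs generally need to treat a `float0` tangent the same way as an absent (zero) tangent, which appears to be the case the commit message refers to.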
1 parent 0a59045 commit 8b540f2

13 files changed

Lines changed: 1652 additions & 28 deletions

README.md

Lines changed: 29 additions & 0 deletions
@@ -339,6 +339,31 @@ pixi run -e faircoin python -m examples.faircoin.main \
 
 **Outputs**: `figs/faircoin_combined_posterior_and_timing_obs50_samples2000.pdf`
 
+### AIR Estimator Study (GenJAX-only)
+
+**What it does**: Trains an AIR-style latent-variable model and compares GenJAX discrete-gradient estimators (`enum`, `reinforce`, `mvd`, `hybrid`) under a shared objective/training loop.
+
+**Figures in the paper**: Port of the PLDI'24 AIR estimator experiment (GenJAX path only).
+
+**Commands**:
+```bash
+# quick smoke comparison (small architecture + synthetic prior samples)
+pixi run air-compare
+
+# single-estimator run with CSV outputs
+pixi run python -m examples.air.main train \
+--estimator enum \
+--small-config \
+--num-examples 512 \
+--epochs 6 \
+--history-csv figs/air_enum_history.csv \
+--summary-csv figs/air_enum_summary.csv
+```
+
+**Dataset modes**:
+- `--dataset synthetic` (default): samples from the AIR prior (no extra framework dependencies).
+- `--dataset multi-mnist --data-path /path/to/multi_mnist_uint8.npz`: load pre-generated multi-MNIST arrays.
+
 ### Curve Fitting with Outlier Detection
 
 **What it does**: Polynomial regression with robust outlier detection, demonstrating:
@@ -477,6 +502,10 @@ Running without CUDA executes the same probabilistic program, but the SMC benchm
 
 All figures are saved to `figs/`:
 
+### AIR
+- `air_*_history.csv` - Per-epoch objective/accuracy/time logs (optional)
+- `air_*_summary.csv` - Estimator summary table (optional)
+
 ### Faircoin
 - `faircoin_combined_posterior_and_timing_obs50_samples2000.pdf` - Framework comparison (timing + posterior accuracy)
 
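
The `--dataset multi-mnist` mode in the README hunk takes a pre-generated NPZ file, but this diff does not show which array names the loader expects. Before passing a file to `--data-path`, it can help to list what the archive contains. A minimal sketch with NumPy follows; the keys `images` and `counts`, their shapes, and the helper `describe_npz` are invented for illustration and are not taken from the repository.

```python
import os
import tempfile

import numpy as np


def describe_npz(path):
    """Return {array_name: (shape, dtype_string)} for every array in an NPZ file."""
    with np.load(path) as archive:
        return {
            name: (archive[name].shape, str(archive[name].dtype))
            for name in archive.files
        }


# Build a tiny stand-in file; a real multi-MNIST archive would be far larger.
tmp_dir = tempfile.mkdtemp()
demo_path = os.path.join(tmp_dir, "demo_multi_mnist_uint8.npz")
np.savez(
    demo_path,
    images=np.zeros((4, 50, 50), dtype=np.uint8),  # hypothetical key and shape
    counts=np.zeros((4,), dtype=np.int64),         # hypothetical key and shape
)

summary = describe_npz(demo_path)
for name, (shape, dtype) in summary.items():
    print(f"{name}: shape={shape}, dtype={dtype}")
```

Pointing `describe_npz` at the real file would show its actual keys, which can then be checked against whatever `examples/air/core.py` reads.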

examples/AGENTS.md

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@ Each case has its own `AGENTS.md`—read it before editing that case.
 
 ## Case Study Map
 
+- `air/`: AIR estimator comparison (GenJAX-only port from PLDI'24 artifact)
 - `cone/`: ADEV objective variants (ELBO / IWAE / HVI-style families)
 - `curvefit/`: polynomial regression + scalable inference + outlier modeling
 - `faircoin/`: Beta–Bernoulli baseline benchmarking

examples/air/AGENTS.md

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+# AIR Case Study Guide
+
+This case ports the GenJAX AIR estimator experiment from the PLDI'24 programmable-VI artifact.
+
+## Purpose
+
+Evaluate discrete-gradient estimators in an AIR-style latent-variable model:
+
+- `enum`
+- `reinforce`
+- `mvd`
+- `hybrid` (MVD for early `z_pres` sites, ENUM on the final site)
+
+## Key Files
+
+- `core.py`: model/guide definitions, STN transforms, objectives, training/eval utilities
+- `main.py`: CLI (`train`, `compare`)
+
+## CLI Commands
+
+```bash
+# Quick smoke comparison on synthetic data
+pixi run python -m examples.air.main compare --small-config --num-examples 128 --epochs 2
+
+# Train a single estimator and save history
+pixi run python -m examples.air.main train \
+--estimator enum \
+--small-config \
+--num-examples 256 \
+--epochs 4 \
+--history-csv figs/air_enum_history.csv
+
+# Use pre-generated multi-MNIST data (multi_mnist_uint8.npz)
+pixi run python -m examples.air.main compare \
+--dataset multi-mnist \
+--data-path /path/to/multi_mnist_uint8.npz \
+--num-examples 2000
+```
+
+## Notes
+
+- Default dataset mode is `synthetic` (samples from the AIR prior), so the case runs without Pyro/Torch.
+- For `multi-mnist`, supply an existing NPZ file.
+- Keep heavy logic in `core.py`; keep `main.py` as orchestration only.
+
+## Tests
+
+- `tests/test_air_example.py`
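
The estimator names listed under Purpose correspond to standard ways of differentiating an expectation over a discrete variable, and for a single Bernoulli site each can be written down directly. The sketch below uses plain JAX rather than GenJAX's ADEV primitives. The names `f`, `expected_f`, `enum_grad`, `mvd_grad`, and `reinforce_grad` are hypothetical, and `core.py` may organize its estimators quite differently.

```python
import jax
import jax.numpy as jnp


def f(b):
    # Hypothetical downstream cost as a function of a binary latent b.
    return (b - 0.3) ** 2


def expected_f(p):
    # Enumeration: sum over both outcomes, weighted by their probabilities.
    return p * f(1.0) + (1.0 - p) * f(0.0)


# enum: differentiate the exact, enumerated expectation.
enum_grad = jax.grad(expected_f)


def mvd_grad(p):
    # Measure-valued derivative for Bernoulli(p): d/dp E[f(b)] = f(1) - f(0).
    # Two evaluations of f, no sampling; the result does not depend on p.
    return f(1.0) - f(0.0)


def reinforce_grad(key, p, num_samples=20_000):
    # Score-function estimator: average of f(b) * d/dp log Bernoulli(b; p).
    b = jax.random.bernoulli(key, p, (num_samples,)).astype(jnp.float32)
    score = b / p - (1.0 - b) / (1.0 - p)
    return jnp.mean(f(b) * score)


p = 0.4
exact = enum_grad(p)
mvd = mvd_grad(p)
noisy = reinforce_grad(jax.random.PRNGKey(0), p)
print(exact, mvd, noisy)
```

All three target the same derivative. Enumeration and MVD are exact for this single site, while the score-function estimate is unbiased but noisy. A hybrid scheme, as described above, applies one estimator at some `z_pres` sites and another at the final site.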

examples/air/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+"""AIR case study (GenJAX-only port of the PLDI'24 artifact experiment)."""
