Skip to content

Commit 3eadf11

Browse files
committed
exact_log_prob keyword, rademacher noise, examples install packages, version
1 parent 2787055 commit 3eadf11

File tree

19 files changed

+415
-63
lines changed

19 files changed

+415
-63
lines changed

README.md

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,17 +77,21 @@ Install via
7777
pip install sbgm
7878
```
7979

80-
See [examples](https://github.com/homerjed/sbgm/tree/main/examples).
80+
and for the [examples](https://github.com/homerjed/sbgm/tree/main/examples), run
8181

82-
To run on the `cifar10` image dataset, try something like
82+
```
83+
pip install .[examples]
84+
```
85+
86+
To fit a diffusion model to the `cifar10` image dataset, try something like
8387

8488
```python
8589
import sbgm
8690
import data
8791
import configs
8892

89-
datasets_path = "."
90-
root_dir = "."
93+
datasets_path = "./"
94+
root_dir = "./"
9195

9296
config = configs.cifar10_config()
9397

@@ -129,7 +133,7 @@ model = sbgm.train.train(
129133
* UNet and transformer score network implementations,
130134
* VP, SubVP and VE SDEs (neural network $\beta(t)$ and $\sigma(t)$ functions are on the list!),
131135
* Multi-modal conditioning (basically just optional parameter and image conditioning methods),
132-
* Checkpointing optimiser and model,
136+
* Checkpointing for optimiser and model,
133137
* Multi-device training and sampling.
134138

135139
### Samples

configs/cifar10.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def cifar10_config():
3535
# Sampling
3636
config.use_ema = False
3737
config.sample_size = 5
38-
config.exact_logp = False
38+
config.exact_log_prob = False
3939
config.ode_sample = True
4040
config.eu_sample = True
4141

configs/flowers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def flowers_config():
3535

3636
# Sampling
3737
config.sample_size = 5
38-
config.exact_logp = False
38+
config.exact_log_prob = False
3939
config.ode_sample = True
4040
config.eu_sample = True
4141

configs/grfs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def grfs_config():
3434
# Sampling
3535
config.use_ema = False
3636
config.sample_size = 5
37-
config.exact_logp = False
37+
config.exact_log_prob = False
3838
config.ode_sample = True
3939
config.eu_sample = True
4040

configs/mnist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def mnist_config():
3434

3535
# Sampling
3636
config.sample_size = 8
37-
config.exact_logp = False
37+
config.exact_log_prob = False
3838
config.ode_sample = True
3939
config.eu_sample = True
4040
config.use_ema = False

configs/moons.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def moons_config():
3131
# Sampling
3232
config.use_ema = True
3333
config.sample_size = 64 # Squared in sampling
34-
config.exact_logp = True
34+
config.exact_log_prob = True
3535
config.ode_sample = False
3636
config.eu_sample = False
3737

configs/quijote.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def quijote_config():
3434
# Sampling
3535
config.use_ema = False
3636
config.sample_size = 5
37-
config.exact_logp = False
37+
config.exact_log_prob = False
3838
config.ode_sample = True
3939
config.eu_sample = True
4040

examples/cifar10.ipynb

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
},
3737
{
3838
"cell_type": "code",
39-
"execution_count": 3,
39+
"execution_count": null,
4040
"metadata": {},
4141
"outputs": [],
4242
"source": [
@@ -78,7 +78,7 @@
7878
"# Sampling\n",
7979
"use_ema = False\n",
8080
"sample_size = 5 # Squared for a grid\n",
81-
"exact_logp = False\n",
81+
"exact_log_prob = False\n",
8282
"ode_sample = True # Sample the ODE during training\n",
8383
"eu_sample = True # Euler-Maruyama sample the SDE during training\n",
8484
"\n",
@@ -459,7 +459,7 @@
459459
},
460460
{
461461
"cell_type": "code",
462-
"execution_count": 13,
462+
"execution_count": null,
463463
"metadata": {},
464464
"outputs": [
465465
{
@@ -475,7 +475,7 @@
475475
],
476476
"source": [
477477
"log_likelihood_fn = sbgm.ode.get_log_likelihood_fn(\n",
478-
" model, sde, dataset.data_shape, exact_logp=False, n_eps=64\n",
478+
" model, sde, data_shape=dataset.data_shape, exact_log_prob=False, n_eps=64\n",
479479
")\n",
480480
"L_X = log_likelihood_fn(X[0], None, A[0], key)\n",
481481
"\n",

examples/grfs.ipynb

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
},
4545
{
4646
"cell_type": "code",
47-
"execution_count": 3,
47+
"execution_count": null,
4848
"metadata": {},
4949
"outputs": [],
5050
"source": [
@@ -82,7 +82,7 @@
8282
"# Sampling\n",
8383
"use_ema = False\n",
8484
"sample_size = 5 # Squared for a grid\n",
85-
"exact_logp = False\n",
85+
"exact_log_prob = False\n",
8686
"ode_sample = True # Sample the ODE during training\n",
8787
"eu_sample = True # Euler-Maruyama sample the SDE during training\n",
8888
"\n",
@@ -415,7 +415,7 @@
415415
},
416416
{
417417
"cell_type": "code",
418-
"execution_count": 11,
418+
"execution_count": null,
419419
"metadata": {},
420420
"outputs": [
421421
{
@@ -433,7 +433,7 @@
433433
"key, key_L = jr.split(key)\n",
434434
"\n",
435435
"log_likelihood_fn = sbgm.ode.get_log_likelihood_fn(\n",
436-
" model, sde, dataset.data_shape, exact_logp=True\n",
436+
" model, sde, data_shape=dataset.data_shape, exact_log_prob=True\n",
437437
")\n",
438438
"L_X = log_likelihood_fn(X[0], Q[0], A[0], key_L)\n",
439439
"\n",

examples/grfs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
# Sampling
5757
use_ema = False
5858
sample_size = 5 # Squared for a grid
59-
exact_logp = False
59+
exact_log_prob = False
6060
ode_sample = True # Sample the ODE during training
6161
eu_sample = True # Euler-Maruyama sample the SDE during training
6262

@@ -245,7 +245,7 @@ def diffuse(x, t, eps):
245245
key, key_L = jr.split(key)
246246

247247
log_likelihood_fn = sbgm.ode.get_log_likelihood_fn(
248-
model, sde, dataset.data_shape, exact_logp=True
248+
model, sde, dataset.data_shape, exact_log_prob=True
249249
)
250250
L_X = log_likelihood_fn(X[0], Q[0], A[0], key_L)
251251

0 commit comments

Comments
 (0)