Skip to content

Commit 7e78b0b

Browse files
committed
grf log prob keyword
1 parent 3d95e27 commit 7e78b0b

File tree

3 files changed

+4
-3
lines changed

3 files changed

+4
-3
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ __unet.py
1111
cifar10.ipynb
1212
grfs.ipynb
1313
simple.ipynb
14-
mnist_clouds.py
14+
mnist_clouds.py
15+
ldm.py

examples/grfs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def diffuse(x, t, eps):
243243
key, key_L = jr.split(key)
244244

245245
log_likelihood_fn = sbgm.ode.get_log_likelihood_fn(
246-
model, sde, dataset.data_shape, exact_log_prob=True
246+
model, sde, data_shape=dataset.data_shape, exact_log_prob=True
247247
)
248248
L_X = log_likelihood_fn(X[0], Q[0], A[0], key_L)
249249

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "sbgm"
3-
version = "0.0.33"
3+
version = "0.0.34"
44
description = "Score-based Diffusion models in JAX."
55
readme = "README.md"
66
requires-python ="~=3.12"

0 commit comments

Comments
 (0)