Skip to content

Commit 586d190

Browse files
committed
merged unets, consistent q_dim name in models
1 parent 1b140c8 commit 586d190

File tree

10 files changed

+199
-628
lines changed

10 files changed

+199
-628
lines changed

.gitignore

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,4 @@ __pycache__/
33
imgs/
44
exps/
55
_fisher.py
6-
_set_transformer.py
7-
sgm.egg-info/
8-
main_o.py
6+
_set_transformer.py

README.md

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ model = sbgm.train.train(
118118
* UNet and transformer score network implementations,
119119
* VP, SubVP and VE SDEs (neural network $\beta(t)$ and $\sigma(t)$ functions are on the list!),
120120
* Multi-modal conditioning (basically just optional parameter and image conditioning methods),
121+
* Checkpointing optimiser and model,
121122
* Multi-device training and sampling.
122123

123124
### Samples
@@ -169,8 +170,4 @@ ODE sampling
169170
primaryClass={stat.ML},
170171
url={https://arxiv.org/abs/2101.09258},
171172
}
172-
```
173-
174-
<!-- <p align="center">
175-
<img src="figs/flowers_eu.png" width="350" title="hover text">
176-
</p> -->
173+
```

configs/grfs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def grfs_config():
1212

1313
# Model
1414
config.model = model = ml_collections.ConfigDict()
15-
model.model_type = "UNetXY"
15+
model.model_type = "UNet"
1616
model.is_biggan = False
1717
model.dim_mults = [1, 1, 1]
1818
model.hidden_size = 128

configs/quijote.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ def quijote_config():
88

99
# Data
1010
config.dataset_name = "quijote"
11-
config.n_pix = 32
11+
config.n_pix = 64
1212

1313
# Model
1414
config.model = model = ml_collections.ConfigDict()

data/quijote.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def quijote(key, n_pix, split=0.5):
5656
key_train, key_valid = jr.split(key)
5757

5858
data_shape = (1, n_pix, n_pix)
59-
context_shape = (1, n_pix, n_pix)
59+
context_shape = None #(1, n_pix, n_pix)
6060
parameter_dim = 5
6161

6262
X, A = get_quijote_data(n_pix)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "sbgm"
3-
version = "0.0.11"
3+
version = "0.0.12"
44
description = "Score-based Diffusion models in JAX."
55
readme = "README.md"
66
requires-python ="~=3.12"

sbgm/models/__init__.py

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from typing import Sequence
22
import equinox as eqx
33
from jaxtyping import Key
4+
import numpy as np
45
import ml_collections
56

67
from ._mixer import Mixer2d
78
from ._mlp import ResidualNetwork
89
from ._unet import UNet
9-
from ._unet_xy import UNetXY
1010

1111

1212
def get_model(
@@ -17,6 +17,12 @@ def get_model(
1717
parameter_dim: int,
1818
config: ml_collections.ConfigDict
1919
) -> eqx.Module:
20+
# Grab channel assuming 'q' is a map like x
21+
if context_shape is not None:
22+
context_channels, *_ = context_shape  # context_shape is a plain tuple like (1, n_pix, n_pix) — it has no .shape attribute
23+
else:
24+
context_channels = None
25+
2026
if model_type == "Mixer":
2127
model = Mixer2d(
2228
data_shape,
@@ -26,7 +32,7 @@ def get_model(
2632
mix_hidden_size=config.model.mix_hidden_size,
2733
num_blocks=config.model.num_blocks,
2834
t1=config.t1,
29-
q_dim=context_shape,
35+
q_dim=context_channels,
3036
a_dim=parameter_dim,
3137
key=model_key
3238
)
@@ -42,22 +48,7 @@ def get_model(
4248
num_res_blocks=config.model.num_res_blocks,
4349
attn_resolutions=config.model.attn_resolutions,
4450
final_activation=config.model.final_activation,
45-
a_dim=parameter_dim,
46-
key=model_key
47-
)
48-
if model_type == "UNetXY":
49-
model = UNetXY(
50-
data_shape=data_shape,
51-
is_biggan=config.model.is_biggan,
52-
dim_mults=config.model.dim_mults,
53-
hidden_size=config.model.hidden_size,
54-
heads=config.model.heads,
55-
dim_head=config.model.dim_head,
56-
dropout_rate=config.model.dropout_rate,
57-
num_res_blocks=config.model.num_res_blocks,
58-
attn_resolutions=config.model.attn_resolutions,
59-
final_activation=config.model.final_activation,
60-
q_dim=context_shape[0], # Just grab channel assuming 'q' is a map like x
51+
q_dim=context_channels,
6152
a_dim=parameter_dim,
6253
key=model_key
6354
)
@@ -68,9 +59,11 @@ def get_model(
6859
depth=config.model.depth,
6960
activation=config.model.activation,
7061
dropout_p=config.model.dropout_p,
71-
y_dim=parameter_dim,
62+
q_dim=parameter_dim,
7263
key=model_key
7364
)
65+
if model_type == "CCT":
66+
raise NotImplementedError
7467
if model_type == "DiT":
7568
raise NotImplementedError
7669
return model

sbgm/models/_mlp.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def __init__(
4040
in_size: int,
4141
width_size: int,
4242
depth: int,
43-
y_dim: int,
43+
q_dim: int,
4444
activation: Callable,
4545
dropout_p: float = 0.,
4646
*,
@@ -49,11 +49,11 @@ def __init__(
4949
""" Time-embedding may be necessary """
5050
in_key, *net_keys, out_key = jr.split(key, 2 + depth)
5151
self._in = Linear(
52-
in_size + y_dim + 1, width_size, key=in_key
52+
in_size + q_dim + 1, width_size, key=in_key
5353
)
5454
layers = [
5555
Linear(
56-
width_size + y_dim + 1, width_size, key=_key
56+
width_size + q_dim + 1, width_size, key=_key
5757
)
5858
for _key in net_keys
5959
]

0 commit comments

Comments
 (0)