Skip to content

Commit 399c4d7

Browse files
committed
update
1 parent 70727f4 commit 399c4d7

File tree

11 files changed

+198
-94
lines changed

11 files changed

+198
-94
lines changed

configs/quijote.py

Lines changed: 86 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,107 @@
11
import ml_collections
2+
import jax.numpy as jnp
3+
4+
# def quijote_config():
5+
# config = ml_collections.ConfigDict()
6+
7+
# config.seed = 0
8+
9+
# # Data
10+
# config.dataset_name = "quijote"
11+
# config.n_pix = 64
12+
13+
# # Model
14+
# config.model = model = ml_collections.ConfigDict()
15+
# model.model_type = "UNet"
16+
# model.is_biggan = False
17+
# model.dim_mults = [1, 1, 1]
18+
# model.hidden_size = 128
19+
# model.heads = 4
20+
# model.dim_head = 64
21+
# model.dropout_rate = 0.3
22+
# model.num_res_blocks = 2
23+
# model.attn_resolutions = [8, 32, 64]
24+
# model.final_activation = None
25+
26+
# # SDE
27+
# config.sde = sde = ml_collections.ConfigDict()
28+
# sde.sde = "VP"
29+
# sde.t1 = 8.
30+
# sde.t0 = 1e-5
31+
# sde.dt = 0.1
32+
# sde.beta_integral = lambda t: t
33+
# # sde: SDE = VPSDE(beta_integral, dt=dt, t0=t0, t1=t1)
34+
35+
# # Sampling
36+
# config.use_ema = False
37+
# config.sample_size = 5
38+
# config.exact_logp = False
39+
# config.ode_sample = True
40+
# config.eu_sample = True
41+
42+
# # Optimisation hyperparameters
43+
# config.start_step = 0
44+
# config.n_steps = 1_000_000
45+
# config.lr = 1e-4
46+
# config.batch_size = 32
47+
# config.sample_and_save_every = 1_000
48+
# config.opt = "adabelief"
49+
# config.opt_kwargs = {}
50+
# config.num_workers = 8
51+
52+
# # Other
53+
# config.cmap = "gnuplot"
54+
55+
# return config
256

357

458
def quijote_config():
559
config = ml_collections.ConfigDict()
660

7-
config.seed = 0
61+
config.seed = 0
862

963
# Data
10-
config.dataset_name = "quijote"
11-
config.n_pix = 64
64+
config.dataset_name = "quijote"
65+
config.n_pix = 64
1266

1367
# Model
1468
config.model = model = ml_collections.ConfigDict()
15-
model.model_type = "UNet"
16-
model.is_biggan = False
17-
model.dim_mults = [1, 1, 1]
18-
model.hidden_size = 128
19-
model.heads = 4
20-
model.dim_head = 64
21-
model.dropout_rate = 0.3
22-
model.num_res_blocks = 2
23-
model.attn_resolutions = [8, 32, 64]
24-
model.final_activation = None
69+
model.model_type = "Mixer"
70+
model.patch_size = 2
71+
model.hidden_size = 1024
72+
model.mix_patch_size = 512
73+
model.mix_hidden_size = 1024
74+
model.num_blocks = 5
75+
model.t1 = 10.
76+
model.final_activation = None #"tanh"
2577

2678
# SDE
2779
config.sde = sde = ml_collections.ConfigDict()
28-
sde.sde = "VP"
29-
sde.t1 = 8.
30-
sde.t0 = 1e-5
31-
sde.dt = 0.1
32-
sde.beta_integral = lambda t: t
33-
# sde: SDE = VPSDE(beta_integral, dt=dt, t0=t0, t1=t1)
80+
sde.sde = "VP"
81+
sde.t1 = model.t1
82+
sde.t0 = 0.
83+
sde.dt = 0.1
84+
sde.beta_integral = lambda t: t
85+
sde.weight_fn = lambda t: 1. - jnp.exp(-sde.beta_integral(t))
3486

3587
# Sampling
36-
config.use_ema = False
37-
config.sample_size = 5
38-
config.exact_logp = False
39-
config.ode_sample = True
40-
config.eu_sample = True
88+
config.use_ema = False
89+
config.sample_size = 5
90+
config.exact_logp = False
91+
config.ode_sample = True
92+
config.eu_sample = True
4193

4294
# Optimisation hyperparameters
43-
config.start_step = 0
44-
config.n_steps = 1_000_000
45-
config.lr = 1e-4
46-
config.batch_size = 32
47-
config.sample_and_save_every = 1_000
48-
config.opt = "adabelief"
49-
config.opt_kwargs = {}
50-
config.num_workers = 8
95+
config.start_step = 0
96+
config.n_steps = 1_000_000
97+
config.lr = 1e-4
98+
config.batch_size = 32
99+
config.sample_and_save_every = 5_000
100+
config.opt = "adabelief"
101+
config.opt_kwargs = {}
102+
config.num_workers = 8
51103

52104
# Other
53-
config.cmap = "gnuplot"
105+
config.cmap = "gist_stern"
54106

55-
return config
107+
return config

data/quijote.py

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def get_quijote_labels() -> Array:
5252
return Q
5353

5454

55-
def quijote(key, n_pix, split=0.5):
55+
def quijote(key, n_pix, split=0.9):
5656
key_train, key_valid = jr.split(key)
5757

5858
data_shape = (1, n_pix, n_pix)
@@ -65,32 +65,39 @@ def quijote(key, n_pix, split=0.5):
6565

6666
min = X.min()
6767
max = X.max()
68-
X = (X - min) / (max - min) # ... -> [0, 1]
68+
# X = (X - min) / (max - min) # ... -> [0, 1]
69+
# X = 2.0 * (X - min) / (max - min) - 1.0 # ... -> [-1, 1]
70+
X = (X - X.mean()) / X.std()
6971

7072
# min = Q.min()
7173
# max = Q.max()
7274
# Q = (Q - min) / (max - min) # ... -> [0, 1]
7375

74-
scaler = Scaler() # [0,1] -> [-1,1]
76+
n_train = int(split * len(X))
7577

76-
train_transform = transforms.Compose(
77-
[
78-
transforms.RandomHorizontalFlip(),
79-
transforms.RandomVerticalFlip(),
80-
transforms.Lambda(scaler.forward)
81-
]
82-
)
83-
valid_transform = transforms.Compose(
84-
[transforms.Lambda(scaler.forward)]
85-
)
78+
# scaler = Scaler() # [0,1] -> [-1,1]
8679

87-
n_train = int(split * len(X))
88-
train_dataset = MapDataset(
89-
(X[:n_train], A[:n_train]), transform=train_transform
90-
)
91-
valid_dataset = MapDataset(
92-
(X[n_train:], A[n_train:]), transform=valid_transform
93-
)
80+
# train_transform = transforms.Compose(
81+
# [
82+
# transforms.RandomHorizontalFlip(),
83+
# transforms.RandomVerticalFlip(),
84+
# # transforms.Lambda(scaler.forward)
85+
# ]
86+
# )
87+
# valid_transform = transforms.Compose(
88+
# [
89+
# transforms.RandomHorizontalFlip(),
90+
# transforms.RandomVerticalFlip(),
91+
# # transforms.Lambda(scaler.forward)
92+
# ]
93+
# )
94+
95+
# train_dataset = MapDataset(
96+
# (X[:n_train], A[:n_train]), transform=train_transform
97+
# )
98+
# valid_dataset = MapDataset(
99+
# (X[n_train:], A[n_train:]), transform=valid_transform
100+
# )
94101
# train_dataloader = TorchDataLoader(
95102
# train_dataset,
96103
# data_shape=data_shape,
@@ -128,6 +135,6 @@ def label_fn(key, n):
128135
data_shape=data_shape,
129136
context_shape=context_shape,
130137
parameter_dim=parameter_dim,
131-
scaler=scaler,
138+
scaler=None, #scaler,
132139
label_fn=label_fn
133140
)

examples/run_from_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def main():
5151
dataset,
5252
config,
5353
reload_opt_state=reload_opt_state,
54+
plot_train_data=True,
5455
sharding=sharding,
5556
save_dir=root_dir
5657
)

paper/paper.bib

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,4 +303,14 @@ @misc{batzolis
303303
archivePrefix={arXiv},
304304
primaryClass={cs.LG},
305305
url={https://arxiv.org/abs/2111.13606},
306+
}
307+
308+
@misc{mixer,
309+
title={MLP-Mixer: An all-MLP Architecture for Vision},
310+
author={Ilya Tolstikhin and Neil Houlsby and Alexander Kolesnikov and Lucas Beyer and Xiaohua Zhai and Thomas Unterthiner and Jessica Yung and Andreas Steiner and Daniel Keysers and Jakob Uszkoreit and Mario Lucic and Alexey Dosovitskiy},
311+
year={2021},
312+
eprint={2105.01601},
313+
archivePrefix={arXiv},
314+
primaryClass={cs.CV},
315+
url={https://arxiv.org/abs/2105.01601},
306316
}

0 commit comments

Comments
 (0)