Commit f72f124

DiT, a bit of clean up
1 parent d69c1b9

File tree: 21 files changed, +1051 −410 lines


.gitignore

Lines changed: 1 addition & 0 deletions

@@ -4,6 +4,7 @@ __pycache__/
 imgs/
 exps/
 _fisher.py
+sbgm/_sbgm.py
 _set_transformer.py
 .pytest_cacche/
 __unet.py

data/cifar10.py

Lines changed: 14 additions & 5 deletions

@@ -18,21 +18,23 @@ def convert_torch_to_in_memory(dataset):
 
 
 def cifar10(path: str, key: Key, *, in_memory: bool = True) -> ScalerDataset:
+
     key_train, key_valid = jr.split(key)
 
     n_pix = 32 # Native resolution for CIFAR10
     data_shape = (3, n_pix, n_pix)
+    context_shape = None
     parameter_dim = 1
     n_classes = 10
 
-    scaler = Scaler(x_min=0., x_max=1.)
+    scaler = Normer()
 
     train_transform = transforms.Compose(
         [
             transforms.Resize((n_pix, n_pix)),
             transforms.RandomHorizontalFlip(),
             transforms.ToTensor(),
-            transforms.Lambda(scaler.forward) # [0,1] -> [-1,1]
+            transforms.Lambda(scaler.forward)
         ]
     )
     valid_transform = transforms.Compose(
@@ -64,7 +66,6 @@ def cifar10(path: str, key: Key, *, in_memory: bool = True) -> ScalerDataset:
         At = At.astype(jnp.float32)
         Av = Av.astype(jnp.float32)
 
-        # process_fn = Scaler(x_min=Xt.min(), x_max=Xt.max())
         process_fn = Normer(x_mean=Xt.mean(), x_std=Xt.std())
 
         train_dataloader = InMemoryDataLoader(
@@ -76,10 +77,18 @@ def cifar10(path: str, key: Key, *, in_memory: bool = True) -> ScalerDataset:
         process_fn = Scaler(x_min=0., x_max=1.)
 
         train_dataloader = TorchDataLoader(
-            train_dataset, data_shape, parameter_dim=parameter_dim, key=key_train
+            train_dataset,
+            data_shape=data_shape,
+            context_shape=context_shape,
+            parameter_dim=parameter_dim,
+            key=key_train
         )
         valid_dataloader = TorchDataLoader(
-            valid_dataset, data_shape, parameter_dim=parameter_dim, key=key_valid
+            valid_dataset,
+            data_shape=data_shape,
+            context_shape=context_shape,
+            parameter_dim=parameter_dim,
+            key=key_valid
         )
 
     def label_fn(key, n):
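
The substantive change above is swapping the affine Scaler (which maps [0, 1] to [-1, 1]) for a standardising Normer, and passing the new context_shape argument through to TorchDataLoader. Neither transform class appears in this diff; a minimal sketch of both, inferred only from their usage here (constructor arguments plus a forward method) and therefore an approximation of the real classes in data/utils, not their actual implementation:

from dataclasses import dataclass

@dataclass
class Scaler:
    """Sketch: affine map [x_min, x_max] -> [-1, 1]."""
    x_min: float = 0.
    x_max: float = 1.

    def forward(self, x):
        return 2. * (x - self.x_min) / (self.x_max - self.x_min) - 1.

    def reverse(self, y):
        return 0.5 * (y + 1.) * (self.x_max - self.x_min) + self.x_min

@dataclass
class Normer:
    """Sketch: standardisation (x - mean) / std."""
    x_mean: float = 0.
    x_std: float = 1.

    def forward(self, x):
        return (x - self.x_mean) / self.x_std

    def reverse(self, y):
        return y * self.x_std + self.x_mean

With the zero-mean, unit-std defaults assumed here, Normer() in the transform pipeline is a no-op, so the CIFAR10 images would stay in [0, 1] after ToTensor; the repo's real class may choose different defaults.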

data/flowers.py

Lines changed: 3 additions & 2 deletions

@@ -8,7 +8,9 @@
 
 
 def flowers(path: str, key: Key, n_pix: int) -> ScalerDataset:
+
     key_train, key_valid = jr.split(key)
+
     data_shape = (3, n_pix, n_pix)
     parameter_dim = 1
     n_classes = 102
@@ -19,7 +21,6 @@ def flowers(path: str, key: Key, n_pix: int) -> ScalerDataset:
         [
             transforms.Resize((n_pix, n_pix)),
             transforms.RandomCrop(n_pix, padding=4, padding_mode='reflect'),
-            # transforms.Grayscale(),
             transforms.RandomHorizontalFlip(),
             transforms.RandomVerticalFlip(),
             transforms.ToTensor(),
@@ -30,11 +31,11 @@ def flowers(path: str, key: Key, n_pix: int) -> ScalerDataset:
         [
             transforms.Resize((n_pix, n_pix)),
             transforms.RandomCrop(n_pix, padding=4, padding_mode='reflect'),
-            # transforms.Grayscale(),
             transforms.ToTensor(),
             transforms.Lambda(scaler.forward)
         ]
     )
+
     train_dataset = datasets.Flowers102(
         os.path.join(path, "datasets/flowers/"),
         split="train",

data/grfs.py

Lines changed: 15 additions & 14 deletions

@@ -9,7 +9,7 @@
 from torchvision import transforms
 import powerbox
 
-from .utils import Scaler, ScalerDataset, TorchDataLoader, InMemoryDataLoader
+from .utils import Scaler, Normer, ScalerDataset, TorchDataLoader, InMemoryDataLoader
 
 data_dir = "/project/ls-gruen/users/jed.homer/data/fields/"
 
@@ -114,15 +114,10 @@ def grfs(
 
     print("\nFields data:", X.shape, Q.shape)
 
-    min = X.min()
-    max = X.max()
-    X = (X - min) / (max - min) # ... -> [0, 1]
+    X = (X - jnp.mean(X, axis=0)) / jnp.std(X, axis=0) # Standardize fields
+    Q = (Q - jnp.mean(Q, axis=0)) / jnp.std(Q, axis=0) # Standardize fields
 
-    # min = Q.min()
-    # max = Q.max()
-    # Q = (Q - min) / (max - min) # ... -> [0, 1]
-
-    scaler = Scaler() # [0,1] -> [-1,1]
+    scaler = Normer() #Scaler() # [0,1] -> [-1,1]
 
     n_train = int(split * n_fields)
 
@@ -152,18 +147,24 @@ def grfs(
         (X[n_train:], Q[n_train:], A[n_train:]), transform=valid_transform
     )
     train_dataloader = TorchDataLoader(
-        train_dataset, data_shape, parameter_dim=parameter_dim, key=key_train
+        train_dataset,
+        data_shape=data_shape,
+        context_shape=context_shape,
+        parameter_dim=parameter_dim,
+        key=key_train
     )
     valid_dataloader = TorchDataLoader(
-        valid_dataset, data_shape, parameter_dim=parameter_dim, key=key_valid
+        valid_dataset,
+        data_shape=data_shape,
+        context_shape=context_shape,
+        parameter_dim=parameter_dim,
+        key=key_valid
     )
 
     def label_fn(key: Key[jnp.ndarray, "..."], n: int) -> Tuple[Array, Array]:
         Q, A = get_grf_labels(n_pix)
         ix = jr.choice(key, jnp.arange(len(Q)), (n,))
-        Q = Q[ix]
-        A = A[ix]
-        return Q, A
+        return Q[ix], A[ix]
 
     return ScalerDataset(
         name="grfs",

data/mnist.py

Lines changed: 2 additions & 0 deletions

@@ -13,7 +13,9 @@ def tensor_to_array(tensor: Tensor) -> Array:
 
 
 def mnist(path:str, key: Key, *, in_memory: bool = True) -> ScalerDataset:
+
     key_train, key_valid = jr.split(key)
+
     n_pix = 28
     data_shape = (1, n_pix, n_pix)
     parameter_dim = 1

data/moons.py

Lines changed: 1 addition & 11 deletions

@@ -10,6 +10,7 @@ def key_to_seed(key):
 
 def moons(key):
     key_train, key_valid = jr.split(key)
+
     data_shape = (2,)
     context_shape = None
     parameter_dim = 1
@@ -24,17 +25,6 @@ def moons(key):
     Yv = Yv[:, jnp.newaxis].astype(jnp.float32)
 
     process_fn = Normer(Xt.mean(), Xt.std())
-
-    # min = Xt.min()
-    # max = Xt.max()
-    # mean = Xt.mean()
-    # std = Xt.std()
-
-    # (We do need to handle normalisation ourselves though.)
-    # train_data = (Xt - min) / (max - min)
-    # valid_data = (Xv - min) / (max - min)
-    # train_data = (Xt - mean) / std
-    # valid_data = (Xv - mean) / std
 
     train_dataloader = InMemoryDataLoader(
         X=jnp.asarray(Xt), Q=jnp.asarray(Yt), A=None, process_fn=process_fn, key=key_train
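
The block of commented-out normalisation experiments is dropped; what survives is standardisation with train-set statistics only, applied to both splits through process_fn. A quick round-trip check of that pattern, reusing the hypothetical Normer sketched earlier and assuming the moons data comes from sklearn's make_moons (the actual data source is not shown in this diff):

import jax.numpy as jnp
from sklearn.datasets import make_moons

Xt, Yt = make_moons(512, noise=0.05)
Xt = jnp.asarray(Xt, dtype=jnp.float32)

process_fn = Normer(Xt.mean(), Xt.std())  # statistics from the train split only

Xn = process_fn.forward(Xt)
assert jnp.allclose(process_fn.reverse(Xn), Xt, atol=1e-5)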
