|
9 | 9 | from torchvision import transforms |
10 | 10 | import powerbox |
11 | 11 |
|
12 | | -from .utils import Scaler, ScalerDataset, TorchDataLoader, InMemoryDataLoader |
| 12 | +from .utils import Scaler, Normer, ScalerDataset, TorchDataLoader, InMemoryDataLoader |
13 | 13 |
|
14 | 14 | data_dir = "/project/ls-gruen/users/jed.homer/data/fields/" |
15 | 15 |
|
@@ -114,15 +114,10 @@ def grfs( |
114 | 114 |
|
115 | 115 | print("\nFields data:", X.shape, Q.shape) |
116 | 116 |
|
117 | | - min = X.min() |
118 | | - max = X.max() |
119 | | - X = (X - min) / (max - min) # ... -> [0, 1] |
| 117 | + X = (X - jnp.mean(X, axis=0)) / jnp.std(X, axis=0) # Standardize fields |
| 118 | + Q = (Q - jnp.mean(Q, axis=0)) / jnp.std(Q, axis=0) # Standardize fields |
120 | 119 |
|
121 | | - # min = Q.min() |
122 | | - # max = Q.max() |
123 | | - # Q = (Q - min) / (max - min) # ... -> [0, 1] |
124 | | - |
125 | | - scaler = Scaler() # [0,1] -> [-1,1] |
| 120 | + scaler = Normer() #Scaler() # [0,1] -> [-1,1] |
126 | 121 |
|
127 | 122 | n_train = int(split * n_fields) |
128 | 123 |
|
@@ -152,18 +147,24 @@ def grfs( |
152 | 147 | (X[n_train:], Q[n_train:], A[n_train:]), transform=valid_transform |
153 | 148 | ) |
154 | 149 | train_dataloader = TorchDataLoader( |
155 | | - train_dataset, data_shape, parameter_dim=parameter_dim, key=key_train |
| 150 | + train_dataset, |
| 151 | + data_shape=data_shape, |
| 152 | + context_shape=context_shape, |
| 153 | + parameter_dim=parameter_dim, |
| 154 | + key=key_train |
156 | 155 | ) |
157 | 156 | valid_dataloader = TorchDataLoader( |
158 | | - valid_dataset, data_shape, parameter_dim=parameter_dim, key=key_valid |
| 157 | + valid_dataset, |
| 158 | + data_shape=data_shape, |
| 159 | + context_shape=context_shape, |
| 160 | + parameter_dim=parameter_dim, |
| 161 | + key=key_valid |
159 | 162 | ) |
160 | 163 |
|
161 | 164 | def label_fn(key: Key[jnp.ndarray, "..."], n: int) -> Tuple[Array, Array]: |
162 | 165 | Q, A = get_grf_labels(n_pix) |
163 | 166 | ix = jr.choice(key, jnp.arange(len(Q)), (n,)) |
164 | | - Q = Q[ix] |
165 | | - A = A[ix] |
166 | | - return Q, A |
| 167 | + return Q[ix], A[ix] |
167 | 168 |
|
168 | 169 | return ScalerDataset( |
169 | 170 | name="grfs", |
|
0 commit comments