Skip to content

Commit 94919de

Browse files
committed
Fix FNO and rollout
- Update FNO and EncoderProcessorDecoder to expect no temporal dim - Update return from rollout to stack on new first dimension for rollout windows
1 parent 2ff6ddf commit 94919de

4 files changed

Lines changed: 106 additions & 17 deletions

File tree

notebooks/00_exploration.ipynb

Lines changed: 92 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,11 @@
2828
"outputs": [],
2929
"source": [
3030
"\n",
31-
"from autoemulate.simulations.advection_diffusion import AdvectionDiffusion\n",
31+
"from autoemulate.simulations.reaction_diffusion import ReactionDiffusion as Sim\n",
3232
"\n",
33-
"sim = AdvectionDiffusion(return_timeseries=True, log_level=\"error\")\n",
33+
"sim = Sim(return_timeseries=True, log_level=\"error\")\n",
3434
"\n",
35-
"def generate_split(\n",
36-
" simulator: AdvectionDiffusion, n_train: int = 4, n_valid: int = 2, n_test: int = 2\n",
37-
"):\n",
35+
"def generate_split(simulator: Sim, n_train: int = 1, n_valid: int = 1, n_test: int = 1):\n",
3836
" \"\"\"Generate training, validation, and test splits from the simulator.\"\"\"\n",
3937
" train = simulator.forward_samples_spatiotemporal(n_train)\n",
4038
" valid = simulator.forward_samples_spatiotemporal(n_valid)\n",
@@ -62,8 +60,10 @@
6260
"source": [
6361
"from auto_cast.data.datamodule import SpatioTemporalDataModule\n",
6462
"\n",
63+
"n_steps_input = 4\n",
64+
"n_steps_output = 1\n",
6565
"datamodule = SpatioTemporalDataModule(\n",
66-
" data=combined_data, data_path=None, n_steps_input=4, n_steps_output=1, batch_size=16\n",
66+
" data=combined_data, data_path=None, n_steps_input=n_steps_input, n_steps_output=n_steps_output, batch_size=16\n",
6767
")"
6868
]
6969
},
@@ -100,8 +100,13 @@
100100
"from auto_cast.models.encoder_processor_decoder import EncoderProcessorDecoder\n",
101101
"from auto_cast.nn.fno import FNOProcessor\n",
102102
"\n",
103+
"batch = next(iter(datamodule.train_dataloader()))\n",
104+
"n_channels = batch.input_fields.shape[-1]\n",
103105
"processor = FNOProcessor(\n",
104-
" in_channels=1, out_channels=1, n_modes=(16, 16, 1), hidden_channels=64\n",
106+
" in_channels=n_channels * n_steps_input,\n",
107+
" out_channels=n_channels * n_steps_output,\n",
108+
" n_modes=(16, 16),\n",
109+
" hidden_channels=64,\n",
105110
")\n",
106111
"encoder = PermuteConcat(with_constants=False)\n",
107112
"decoder = ChannelsLast()\n",
@@ -113,30 +118,41 @@
113118
]
114119
},
115120
{
116-
"cell_type": "markdown",
121+
"cell_type": "code",
122+
"execution_count": null,
117123
"id": "8",
118124
"metadata": {},
125+
"outputs": [],
126+
"source": [
127+
"model(batch).shape"
128+
]
129+
},
130+
{
131+
"cell_type": "markdown",
132+
"id": "9",
133+
"metadata": {},
119134
"source": [
120135
"### Run trainer\n"
121136
]
122137
},
123138
{
124139
"cell_type": "code",
125140
"execution_count": null,
126-
"id": "9",
141+
"id": "10",
127142
"metadata": {},
128143
"outputs": [],
129144
"source": [
130145
"import lightning as L\n",
131146
"\n",
132-
"device = \"mps\" # \"cpu\"\n",
133-
"trainer = L.Trainer(max_epochs=5, accelerator=device, log_every_n_steps=10)\n",
147+
"# device = \"mps\" # \"cpu\"\n",
148+
"device = \"cpu\"\n",
149+
"trainer = L.Trainer(max_epochs=1, accelerator=device, log_every_n_steps=10)\n",
134150
"trainer.fit(model, datamodule.train_dataloader(), datamodule.val_dataloader())"
135151
]
136152
},
137153
{
138154
"cell_type": "markdown",
139-
"id": "10",
155+
"id": "11",
140156
"metadata": {},
141157
"source": [
142158
"### Run the evaluation"
@@ -145,12 +161,75 @@
145161
{
146162
"cell_type": "code",
147163
"execution_count": null,
148-
"id": "11",
164+
"id": "12",
149165
"metadata": {},
150166
"outputs": [],
151167
"source": [
152168
"trainer.test(model, datamodule.test_dataloader())"
153169
]
170+
},
171+
{
172+
"cell_type": "markdown",
173+
"id": "13",
174+
"metadata": {},
175+
"source": [
176+
"### Example rollout"
177+
]
178+
},
179+
{
180+
"cell_type": "code",
181+
"execution_count": null,
182+
"id": "14",
183+
"metadata": {},
184+
"outputs": [],
185+
"source": [
186+
"# A single element is the full trajectory\n",
187+
"batch = next(iter(datamodule.rollout_test_dataloader()))"
188+
]
189+
},
190+
{
191+
"cell_type": "code",
192+
"execution_count": null,
193+
"id": "15",
194+
"metadata": {},
195+
"outputs": [],
196+
"source": [
197+
"# First n_steps_input are inputs\n",
198+
"print(batch.input_fields.shape)\n",
199+
"# Remaining n_steps_output are outputs\n",
200+
"print(batch.output_fields.shape)"
201+
]
202+
},
203+
{
204+
"cell_type": "code",
205+
"execution_count": null,
206+
"id": "16",
207+
"metadata": {},
208+
"outputs": [],
209+
"source": [
210+
"# Run rollout on one trajectory\n",
211+
"preds, trues = model.rollout(batch)"
212+
]
213+
},
214+
{
215+
"cell_type": "code",
216+
"execution_count": null,
217+
"id": "17",
218+
"metadata": {},
219+
"outputs": [],
220+
"source": [
221+
"print(preds.shape)"
222+
]
223+
},
224+
{
225+
"cell_type": "code",
226+
"execution_count": null,
227+
"id": "18",
228+
"metadata": {},
229+
"outputs": [],
230+
"source": [
231+
"print(trues.shape)\n"
232+
]
154233
}
155234
],
156235
"metadata": {

src/auto_cast/decoders/channels_last.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@ class ChannelsLast(Decoder):
99

1010
def forward(self, x: Tensor) -> Tensor:
1111
"""Forward pass through the ChannelsLast decoder."""
12-
return rearrange(x, "b c t w h -> b t w h c")
12+
return rearrange(x, "b c w h -> b 1 w h c")

src/auto_cast/encoders/permute_concat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,4 @@ def forward(self, batch: Batch) -> Tensor:
2626
scalars = rearrange(scalars, "b c -> b c 1 1 1")
2727
scalars = scalars.expand(b, -1, t, w, h)
2828
x = torch.cat([x, scalars], dim=1)
29-
return x
29+
return rearrange(x, "b c t w h -> b (c t) w h")

src/auto_cast/processors/rollout.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,19 @@ def rollout(self, batch: BatchT) -> RolloutOutput:
4141

4242
current_batch = self._advance_batch(current_batch, next_inputs, self.stride)
4343

44-
predictions = torch.stack(pred_outs)
44+
# Stack along a new axis after batch representing number of rollout windows R
45+
# Each window R contains n_steps_output time steps T.
46+
# For example with:
47+
# - batch size B=16
48+
# - rollout windows R=10
49+
# - n_steps_output T=2 per window,
50+
# - spatial dimensions W=16, H=8
51+
# - channels C=2
52+
# The output shapes will be:
53+
# (B, R, T, W, H, C) = (16, 10, 2, 16, 8, 2)
54+
predictions = torch.stack(pred_outs, dim=1) # (B, R, T, spatial, C)
4555
if true_outs:
46-
return predictions, torch.stack(true_outs)
56+
return predictions, torch.stack(true_outs, dim=1) # (B, R, T, spatial, C)
4757
return predictions, None
4858

4959
@abstractmethod

0 commit comments

Comments
 (0)