Skip to content

Commit 8cf2e18

Browse files
committed
Add config to support different stride at train and eval
1 parent ee9d0c8 commit 8cf2e18

6 files changed

Lines changed: 28 additions & 34 deletions

File tree

configs/encoder_processor_decoder.yaml

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,9 @@ output:
1414
training:
1515
n_steps_input: 1
1616
n_steps_output: 4
17-
stride: 4
17+
train_stride: 1
18+
eval_stride: 4
19+
stride: 4 # Default stride for backward compatibility
1820
autoencoder_checkpoint: null
1921
freeze_autoencoder: false
2022

configs/model/encoder_processor_decoder.yaml

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,8 @@ defaults:
66

77
learning_rate: 0.0001
88
train_processor_only: true
9+
stride: ${training.train_stride}
10+
eval_stride: ${training.eval_stride}
911
teacher_forcing_ratio: 0.5
1012
max_rollout_steps: 10
1113
loss_func:

configs/processor/diffusion.yaml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@ schedule:
1313
_target_: azula.noise.VPSchedule
1414
denoiser_type: karras
1515
teacher_forcing_ratio: 0.0
16-
stride: ${training.stride}
16+
stride: ${training.train_stride}
1717
max_rollout_steps: ${training.n_steps_output}
1818
learning_rate: 0.0001
1919
n_steps_output: null

configs/processor/flow_matching.yaml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
_target_: autocast.processors.flow_matching.FlowMatchingProcessor
2-
stride: ${training.stride}
2+
stride: ${training.train_stride}
33
teacher_forcing_ratio: 0.0
44
max_rollout_steps: ${training.n_steps_output}
55
learning_rate: 0.0001
@@ -11,7 +11,7 @@ backbone:
1111
in_channels: null
1212
out_channels: null
1313
cond_channels: null
14-
mod_features: 200
14+
mod_features: 256
1515
hid_channels: [32, 64, 128]
1616
hid_blocks: [2, 2, 2]
1717
spatial: 2

configs/processor/fno.yaml

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,4 +4,7 @@ out_channels: null
44
n_modes: [16, 16]
55
hidden_channels: 64
66
n_layers: 4
7-
learning_rate: 0.001
7+
learning_rate: 0.0001
8+
stride: ${training.train_stride}
9+
teacher_forcing_ratio: 0.0
10+
max_rollout_steps: ${training.n_steps_output}

notebooks/00_quickstart.ipynb

Lines changed: 16 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -39,10 +39,11 @@
3939
"from autocast.data.datamodule import SpatioTemporalDataModule, TheWellDataModule\n",
4040
"from autocast.metrics.spatiotemporal import MAE, MSE, RMSE\n",
4141
"\n",
42-
"THE_WELL = True\n",
42+
"THE_WELL = False\n",
4343
"n_steps_input = 1\n",
4444
"n_steps_output = 4\n",
45-
"stride = n_steps_output"
45+
"train_stride = 1\n",
46+
"eval_stride = 4"
4647
]
4748
},
4849
{
@@ -62,9 +63,9 @@
6263
"source": [
6364
"\n",
6465
"if not THE_WELL:\n",
65-
" # simulation_name = \"reaction_diffusion\"\n",
66+
" simulation_name = \"reaction_diffusion\"\n",
6667
" # simulation_name = \"advection_diffusion\"\n",
67-
" simulation_name = \"advection_diffusion_multichannel\"\n",
68+
" # simulation_name = \"advection_diffusion_multichannel\"\n",
6869
"\n",
6970
" if simulation_name == \"advection_diffusion_multichannel\":\n",
7071
" # Override to use multichannel version\n",
@@ -107,8 +108,8 @@
107108
" pickle.dump(combined_data, f)\n",
108109
"\n",
109110
" datamodule = SpatioTemporalDataModule(\n",
110-
" data=combined_data,\n",
111-
" data_path=None,\n",
111+
" # data=combined_data,\n",
112+
" data_path=\"../datasets/reaction_diffusion\",\n",
112113
" n_steps_input=n_steps_input,\n",
113114
" n_steps_output=n_steps_output,\n",
114115
" stride=n_steps_output,\n",
@@ -221,7 +222,7 @@
221222
" schedule=VPSchedule(), # accepted for API parity, not used internally\n",
222223
" n_steps_output=n_steps_output,\n",
223224
" n_channels_out=n_channels,\n",
224-
" stride=n_steps_output,\n",
225+
" stride=train_stride,\n",
225226
" flow_ode_steps=4,\n",
226227
" )\n",
227228
"else:\n",
@@ -232,21 +233,18 @@
232233
" schedule=VPSchedule(),\n",
233234
" n_steps_output=n_steps_output,\n",
234235
" n_channels_out=n_channels,\n",
235-
" stride=n_steps_output,\n",
236236
" )\n",
237237
"\n",
238238
"encoder = IdentityEncoder()\n",
239239
"decoder = IdentityDecoder()\n",
240240
"model = EncoderProcessorDecoder(\n",
241241
" encoder_decoder=EncoderDecoder(encoder=encoder, decoder=decoder),\n",
242242
" processor=processor,\n",
243-
" stride=stride,\n",
244243
" train_processor_only=True,\n",
245-
" # learning_rate=1e-5,\n",
246244
" learning_rate=1e-4,\n",
247-
" #test_metrics = [MSE(), MAE(), RMSE()]\n",
245+
" test_metrics = [MSE(), MAE(), RMSE()]\n",
248246
")\n",
249-
"maybe_watch_model(logger, model, watch)"
247+
"maybe_watch_model(logger, model, watch)\n"
250248
]
251249
},
252250
{
@@ -341,21 +339,10 @@
341339
"id": "19",
342340
"metadata": {},
343341
"outputs": [],
344-
"source": [
345-
"# Set max rollout steps based on batch output shape\n",
346-
"# model.max_rollout_steps = batch.output_fields.shape[1] // (n_steps_output * 2)\n",
347-
"model.max_rollout_steps = 20"
348-
]
349-
},
350-
{
351-
"cell_type": "code",
352-
"execution_count": null,
353-
"id": "20",
354-
"metadata": {},
355-
"outputs": [],
356342
"source": [
357343
"# Run rollout on one trajectory\n",
358-
"preds, trues = model.rollout(batch, free_running_only=True)\n",
344+
"model.max_rollout_steps = 20\n",
345+
"preds, trues = model.rollout(batch, stride=eval_stride, free_running_only=True)\n",
359346
"\n",
360347
"print(preds.shape)\n",
361348
"assert trues is not None\n",
@@ -365,7 +352,7 @@
365352
{
366353
"cell_type": "code",
367354
"execution_count": null,
368-
"id": "21",
355+
"id": "20",
369356
"metadata": {},
370357
"outputs": [],
371358
"source": [
@@ -374,7 +361,7 @@
374361
"assert trues is not None\n",
375362
"assert preds.shape == trues.shape\n",
376363
"mse = MSE()\n",
377-
"mse_error_spatial = mse.score(preds, trues)\n",
364+
"mse_error_spatial = mse(preds, trues)\n",
378365
"mse_error = mse(preds, trues)\n",
379366
"print(\"MSE spatial has shape (B,T,C):\", mse_error_spatial.shape)\n",
380367
"print(\"MSE overall is a single scalar:\", mse_error.shape)"
@@ -383,7 +370,7 @@
383370
{
384371
"cell_type": "code",
385372
"execution_count": null,
386-
"id": "22",
373+
"id": "21",
387374
"metadata": {},
388375
"outputs": [],
389376
"source": [
@@ -415,7 +402,7 @@
415402
{
416403
"cell_type": "code",
417404
"execution_count": null,
418-
"id": "23",
405+
"id": "22",
419406
"metadata": {},
420407
"outputs": [],
421408
"source": []

0 commit comments

Comments
 (0)