|
39 | 39 | "from autocast.data.datamodule import SpatioTemporalDataModule, TheWellDataModule\n", |
40 | 40 | "from autocast.metrics.spatiotemporal import MAE, MSE, RMSE\n", |
41 | 41 | "\n", |
42 | | - "THE_WELL = True\n", |
| 42 | + "THE_WELL = False\n", |
43 | 43 | "n_steps_input = 1\n", |
44 | 44 | "n_steps_output = 4\n", |
45 | | - "stride = n_steps_output" |
| 45 | + "train_stride = 1\n", |
| 46 | + "eval_stride = 4" |
46 | 47 | ] |
47 | 48 | }, |
48 | 49 | { |
|
62 | 63 | "source": [ |
63 | 64 | "\n", |
64 | 65 | "if not THE_WELL:\n", |
65 | | - " # simulation_name = \"reaction_diffusion\"\n", |
| 66 | + " simulation_name = \"reaction_diffusion\"\n", |
66 | 67 | " # simulation_name = \"advection_diffusion\"\n", |
67 | | - " simulation_name = \"advection_diffusion_multichannel\"\n", |
| 68 | + " # simulation_name = \"advection_diffusion_multichannel\"\n", |
68 | 69 | "\n", |
69 | 70 | " if simulation_name == \"advection_diffusion_multichannel\":\n", |
70 | 71 | " # Override to use multichannel version\n", |
|
107 | 108 | " pickle.dump(combined_data, f)\n", |
108 | 109 | "\n", |
109 | 110 | " datamodule = SpatioTemporalDataModule(\n", |
110 | | - " data=combined_data,\n", |
111 | | - " data_path=None,\n", |
| 111 | + " # data=combined_data,\n", |
| 112 | + " data_path=\"../datasets/reaction_diffusion\",\n", |
112 | 113 | " n_steps_input=n_steps_input,\n", |
113 | 114 | " n_steps_output=n_steps_output,\n", |
114 | 115 | " stride=n_steps_output,\n", |
|
221 | 222 | " schedule=VPSchedule(), # accepted for API parity, not used internally\n", |
222 | 223 | " n_steps_output=n_steps_output,\n", |
223 | 224 | " n_channels_out=n_channels,\n", |
224 | | - " stride=n_steps_output,\n", |
| 225 | + " stride=train_stride,\n", |
225 | 226 | " flow_ode_steps=4,\n", |
226 | 227 | " )\n", |
227 | 228 | "else:\n", |
|
232 | 233 | " schedule=VPSchedule(),\n", |
233 | 234 | " n_steps_output=n_steps_output,\n", |
234 | 235 | " n_channels_out=n_channels,\n", |
235 | | - " stride=n_steps_output,\n", |
236 | 236 | " )\n", |
237 | 237 | "\n", |
238 | 238 | "encoder = IdentityEncoder()\n", |
239 | 239 | "decoder = IdentityDecoder()\n", |
240 | 240 | "model = EncoderProcessorDecoder(\n", |
241 | 241 | " encoder_decoder=EncoderDecoder(encoder=encoder, decoder=decoder),\n", |
242 | 242 | " processor=processor,\n", |
243 | | - " stride=stride,\n", |
244 | 243 | " train_processor_only=True,\n", |
245 | | - " # learning_rate=1e-5,\n", |
246 | 244 | " learning_rate=1e-4,\n", |
247 | | - " #test_metrics = [MSE(), MAE(), RMSE()]\n", |
| 245 | + " test_metrics = [MSE(), MAE(), RMSE()]\n", |
248 | 246 | ")\n", |
249 | | - "maybe_watch_model(logger, model, watch)" |
| 247 | + "maybe_watch_model(logger, model, watch)\n" |
250 | 248 | ] |
251 | 249 | }, |
252 | 250 | { |
|
341 | 339 | "id": "19", |
342 | 340 | "metadata": {}, |
343 | 341 | "outputs": [], |
344 | | - "source": [ |
345 | | - "# Set max rollout steps based on batch output shape\n", |
346 | | - "# model.max_rollout_steps = batch.output_fields.shape[1] // (n_steps_output * 2)\n", |
347 | | - "model.max_rollout_steps = 20" |
348 | | - ] |
349 | | - }, |
350 | | - { |
351 | | - "cell_type": "code", |
352 | | - "execution_count": null, |
353 | | - "id": "20", |
354 | | - "metadata": {}, |
355 | | - "outputs": [], |
356 | 342 | "source": [ |
357 | 343 | "# Run rollout on one trajectory\n", |
358 | | - "preds, trues = model.rollout(batch, free_running_only=True)\n", |
| 344 | + "model.max_rollout_steps = 20\n", |
| 345 | + "preds, trues = model.rollout(batch, stride=eval_stride, free_running_only=True)\n", |
359 | 346 | "\n", |
360 | 347 | "print(preds.shape)\n", |
361 | 348 | "assert trues is not None\n", |
|
365 | 352 | { |
366 | 353 | "cell_type": "code", |
367 | 354 | "execution_count": null, |
368 | | - "id": "21", |
| 355 | + "id": "20", |
369 | 356 | "metadata": {}, |
370 | 357 | "outputs": [], |
371 | 358 | "source": [ |
|
374 | 361 | "assert trues is not None\n", |
375 | 362 | "assert preds.shape == trues.shape\n", |
376 | 363 | "mse = MSE()\n", |
377 | | - "mse_error_spatial = mse.score(preds, trues)\n", |
| 364 | + "mse_error_spatial = mse(preds, trues)\n", |
378 | 365 | "mse_error = mse(preds, trues)\n", |
379 | 366 | "print(\"MSE spatial has shape (B,T,C):\", mse_error_spatial.shape)\n", |
380 | 367 | "print(\"MSE overall is a single scalar:\", mse_error.shape)" |
|
383 | 370 | { |
384 | 371 | "cell_type": "code", |
385 | 372 | "execution_count": null, |
386 | | - "id": "22", |
| 373 | + "id": "21", |
387 | 374 | "metadata": {}, |
388 | 375 | "outputs": [], |
389 | 376 | "source": [ |
|
415 | 402 | { |
416 | 403 | "cell_type": "code", |
417 | 404 | "execution_count": null, |
418 | | - "id": "23", |
| 405 | + "id": "22", |
419 | 406 | "metadata": {}, |
420 | 407 | "outputs": [], |
421 | 408 | "source": [] |
|
0 commit comments