|
36 | 36 | "from autocast.data.advection_diffusion import (\n", |
37 | 37 | " AdvectionDiffusion as AdvectionDiffusionMultichannel,\n", |
38 | 38 | ")\n", |
| 39 | + "from autocast.data.datamodule import SpatioTemporalDataModule, TheWellDataModule\n", |
39 | 40 | "from autocast.metrics.spatiotemporal import MAE, MSE, RMSE\n", |
40 | 41 | "\n", |
41 | | - "# simulation_name = \"reaction_diffusion\"\n", |
42 | | - "# simulation_name = \"advection_diffusion\"\n", |
43 | | - "simulation_name = \"advection_diffusion_multichannel\"\n", |
44 | | - "\n", |
45 | | - "if simulation_name == \"advection_diffusion_multichannel\":\n", |
46 | | - " # Override to use multichannel version\n", |
47 | | - " Sim = AdvectionDiffusionMultichannel\n", |
48 | | - "if simulation_name == \"reaction_diffusion\":\n", |
49 | | - " Sim = ReactionDiffusion\n", |
50 | | - "if simulation_name == \"advection_diffusion\":\n", |
51 | | - " Sim = AdvectionDiffusion\n", |
52 | | - "\n", |
53 | | - "sim = Sim(return_timeseries=True, log_level=\"error\")\n", |
54 | | - "\n", |
55 | | - "def generate_split(simulator, n_train: int = 200, n_valid: int = 20, n_test: int = 20):\n", |
56 | | - " \"\"\"Generate training, validation, and test splits from the simulator.\"\"\"\n", |
57 | | - " train = simulator.forward_samples_spatiotemporal(n_train)\n", |
58 | | - " valid = simulator.forward_samples_spatiotemporal(n_valid)\n", |
59 | | - " test = simulator.forward_samples_spatiotemporal(n_test)\n", |
60 | | - " return {\"train\": train, \"valid\": valid, \"test\": test}\n", |
61 | | - "\n", |
62 | | - "\n", |
63 | | - "# Cache file path\n", |
64 | | - "cache_file = Path(f\"{simulation_name}_cache.pkl\")\n", |
65 | | - "\n", |
66 | | - "# Load from cache if it exists, otherwise generate and save\n", |
67 | | - "if cache_file.exists():\n", |
68 | | - " print(f\"Loading cached simulation data from {cache_file}\")\n", |
69 | | - " with open(cache_file, \"rb\") as f:\n", |
70 | | - " combined_data = pickle.load(f)\n", |
71 | | - " for key in ['data', 'constant_scalars', 'constant_fields']:\n", |
72 | | - " combined_data[\"test\"][key] = (\n", |
73 | | - " combined_data[\"test\"][key][:5]\n", |
74 | | - " if combined_data[\"test\"][key] is not None else None\n", |
75 | | - " )\n", |
76 | | - "else:\n", |
77 | | - " print(\"Generating simulation data...\")\n", |
78 | | - " combined_data = generate_split(sim)\n", |
79 | | - " print(f\"Saving simulation data to {cache_file}\")\n", |
80 | | - " with open(cache_file, \"wb\") as f:\n", |
81 | | - " pickle.dump(combined_data, f)\n" |
| 42 | + "THE_WELL = True\n", |
| 43 | + "n_steps_input = 1\n", |
| 44 | + "n_steps_output = 4\n", |
| 45 | + "stride = n_steps_output" |
| 46 | + ] |
| 47 | + }, |
| 48 | + { |
| 49 | + "cell_type": "markdown", |
| 50 | + "id": "3", |
| 51 | + "metadata": {}, |
| 52 | + "source": [ |
| 53 | + "### Read combined data into datamodule" |
82 | 54 | ] |
83 | 55 | }, |
84 | 56 | { |
85 | 57 | "cell_type": "code", |
86 | 58 | "execution_count": null, |
87 | | - "id": "3", |
| 59 | + "id": "4", |
88 | 60 | "metadata": {}, |
89 | 61 | "outputs": [], |
90 | 62 | "source": [ |
91 | | - "from autocast.logging import create_wandb_logger, maybe_watch_model\n", |
92 | | - "from autocast.logging.wandb import create_notebook_logger\n", |
93 | 63 | "\n", |
94 | | - "logger, watch = create_notebook_logger(\n", |
95 | | - " project=\"autocast-notebooks\",\n", |
96 | | - " name=f\"00_01_exploration_{simulation_name}\",\n", |
97 | | - " tags=[\"notebook\", simulation_name]\n", |
98 | | - ")" |
| 64 | + "if not THE_WELL:\n", |
| 65 | + " # simulation_name = \"reaction_diffusion\"\n", |
| 66 | + " # simulation_name = \"advection_diffusion\"\n", |
| 67 | + " simulation_name = \"advection_diffusion_multichannel\"\n", |
| 68 | + "\n", |
| 69 | + " if simulation_name == \"advection_diffusion_multichannel\":\n", |
| 70 | + " # Override to use multichannel version\n", |
| 71 | + " Sim = AdvectionDiffusionMultichannel\n", |
| 72 | + " if simulation_name == \"reaction_diffusion\":\n", |
| 73 | + " Sim = ReactionDiffusion\n", |
| 74 | + " if simulation_name == \"advection_diffusion\":\n", |
| 75 | + " Sim = AdvectionDiffusion\n", |
| 76 | + "\n", |
| 77 | + " sim = Sim(return_timeseries=True, log_level=\"error\")\n", |
| 78 | + "\n", |
| 79 | + " def generate_split(\n", |
| 80 | + " simulator, n_train: int = 200, n_valid: int = 20, n_test: int = 20\n", |
| 81 | + " ):\n", |
| 82 | + " \"\"\"Generate training, validation, and test splits from the simulator.\"\"\"\n", |
| 83 | + " train = simulator.forward_samples_spatiotemporal(n_train)\n", |
| 84 | + " valid = simulator.forward_samples_spatiotemporal(n_valid)\n", |
| 85 | + " test = simulator.forward_samples_spatiotemporal(n_test)\n", |
| 86 | + " return {\"train\": train, \"valid\": valid, \"test\": test}\n", |
| 87 | + "\n", |
| 88 | + " # Cache file path\n", |
| 89 | + " cache_file = Path(f\"{simulation_name}_cache.pkl\")\n", |
| 90 | + "\n", |
| 91 | + " # Load from cache if it exists, otherwise generate and save\n", |
| 92 | + " if cache_file.exists():\n", |
| 93 | + " print(f\"Loading cached simulation data from {cache_file}\")\n", |
| 94 | + " with open(cache_file, \"rb\") as f:\n", |
| 95 | + " combined_data = pickle.load(f)\n", |
| 96 | + " for key in [\"data\", \"constant_scalars\", \"constant_fields\"]:\n", |
| 97 | + " combined_data[\"test\"][key] = (\n", |
| 98 | + " combined_data[\"test\"][key][:5]\n", |
| 99 | + " if combined_data[\"test\"][key] is not None\n", |
| 100 | + " else None\n", |
| 101 | + " )\n", |
| 102 | + " else:\n", |
| 103 | + " print(\"Generating simulation data...\")\n", |
| 104 | + " combined_data = generate_split(sim)\n", |
| 105 | + " print(f\"Saving simulation data to {cache_file}\")\n", |
| 106 | + " with open(cache_file, \"wb\") as f:\n", |
| 107 | + " pickle.dump(combined_data, f)\n", |
| 108 | + "\n", |
| 109 | + " datamodule = SpatioTemporalDataModule(\n", |
| 110 | + " data=combined_data,\n", |
| 111 | + " data_path=None,\n", |
| 112 | + " n_steps_input=n_steps_input,\n", |
| 113 | + " n_steps_output=n_steps_output,\n", |
| 114 | + " stride=n_steps_output,\n", |
| 115 | + " batch_size=16,\n", |
| 116 | + " )\n", |
| 117 | + "else:\n", |
| 118 | + " simulation_name = \"turbulent_radiative_layer_2D\"\n", |
| 119 | + " datamodule = TheWellDataModule(\n", |
| 120 | + " well_base_path=\"../../autocast/datasets/\",\n", |
| 121 | + " well_dataset_name=simulation_name,\n", |
| 122 | + " n_steps_input=n_steps_input,\n", |
| 123 | + " n_steps_output=n_steps_output,\n", |
| 124 | + " min_dt_stride=1,\n", |
| 125 | + " use_normalization=True,\n", |
| 126 | + " )\n" |
99 | 127 | ] |
100 | 128 | }, |
101 | 129 | { |
102 | 130 | "cell_type": "markdown", |
103 | | - "id": "4", |
| 131 | + "id": "5", |
104 | 132 | "metadata": {}, |
105 | 133 | "source": [ |
106 | | - "### Read combined data into datamodule\n" |
| 134 | + "### Set-up logging" |
107 | 135 | ] |
108 | 136 | }, |
109 | 137 | { |
110 | 138 | "cell_type": "code", |
111 | 139 | "execution_count": null, |
112 | | - "id": "5", |
| 140 | + "id": "6", |
113 | 141 | "metadata": {}, |
114 | 142 | "outputs": [], |
115 | 143 | "source": [ |
116 | 144 | "\n", |
117 | | - "from autocast.data.datamodule import SpatioTemporalDataModule\n", |
| 145 | + "from autocast.logging import create_wandb_logger, maybe_watch_model\n", |
| 146 | + "from autocast.logging.wandb import create_notebook_logger\n", |
118 | 147 | "\n", |
119 | | - "n_steps_input = 1\n", |
120 | | - "n_steps_output = 4\n", |
121 | | - "stride = n_steps_output\n", |
122 | | - "datamodule = SpatioTemporalDataModule(\n", |
123 | | - " data=combined_data,\n", |
124 | | - " data_path=None,\n", |
125 | | - " n_steps_input=n_steps_input,\n", |
126 | | - " n_steps_output=n_steps_output,\n", |
127 | | - " stride=n_steps_output,\n", |
128 | | - " batch_size=16,\n", |
| 148 | + "logger, watch = create_notebook_logger(\n", |
| 149 | + " project=\"autocast-notebooks\",\n", |
| 150 | + " name=f\"00_01_exploration_{simulation_name}\",\n", |
| 151 | + " tags=[\"notebook\", simulation_name],\n", |
129 | 152 | ")" |
130 | 153 | ] |
131 | 154 | }, |
132 | 155 | { |
133 | 156 | "cell_type": "markdown", |
134 | | - "id": "6", |
135 | | - "metadata": {}, |
136 | | - "source": [ |
137 | | - "### Example batch\n" |
138 | | - ] |
139 | | - }, |
140 | | - { |
141 | | - "cell_type": "code", |
142 | | - "execution_count": null, |
143 | 157 | "id": "7", |
144 | 158 | "metadata": {}, |
145 | | - "outputs": [], |
146 | 159 | "source": [ |
147 | | - "len(datamodule.train_dataset) / 50\n" |
| 160 | + "### Example shape and batch\n" |
148 | 161 | ] |
149 | 162 | }, |
150 | 163 | { |
|
164 | 177 | "metadata": {}, |
165 | 178 | "outputs": [], |
166 | 179 | "source": [ |
167 | | - "\n", |
168 | 180 | "batch = next(iter(datamodule.train_dataloader()))\n", |
169 | 181 | "\n", |
170 | 182 | "batch.input_fields.shape" |
|
201 | 213 | " hid_blocks=(2, 2, 2),\n", |
202 | 214 | " spatial=2,\n", |
203 | 215 | " periodic=False,\n", |
204 | | - " )\n", |
| 216 | + ")\n", |
205 | 217 | "\n", |
206 | 218 | "if processor_name == \"flow_matching\":\n", |
207 | 219 | " processor = FlowMatchingProcessor(\n", |
|
211 | 223 | " n_channels_out=n_channels,\n", |
212 | 224 | " stride=n_steps_output,\n", |
213 | 225 | " flow_ode_steps=4,\n", |
214 | | - " )\n", |
| 226 | + " )\n", |
215 | 227 | "else:\n", |
216 | 228 | " from autocast.processors.diffusion import DiffusionProcessor\n", |
217 | 229 | "\n", |
|
221 | 233 | " n_steps_output=n_steps_output,\n", |
222 | 234 | " n_channels_out=n_channels,\n", |
223 | 235 | " stride=n_steps_output,\n", |
224 | | - " )\n", |
| 236 | + " )\n", |
225 | 237 | "\n", |
226 | 238 | "encoder = IdentityEncoder()\n", |
227 | 239 | "decoder = IdentityDecoder()\n", |
|
233 | 245 | " # learning_rate=1e-5,\n", |
234 | 246 | " learning_rate=1e-4,\n", |
235 | 247 | " #test_metrics = [MSE(), MAE(), RMSE()]\n", |
236 | | - " )\n", |
| 248 | + ")\n", |
237 | 249 | "maybe_watch_model(logger, model, watch)" |
238 | 250 | ] |
239 | 251 | }, |
|
266 | 278 | "\n", |
267 | 279 | "device = \"mps\" # \"cpu\"\n", |
268 | 280 | "# device = \"cpu\"\n", |
269 | | - "trainer = L.Trainer(max_epochs=4, accelerator=device, log_every_n_steps=10, logger=logger)\n", |
| 281 | + "trainer = L.Trainer(\n", |
| 282 | + " max_epochs=4, accelerator=device, log_every_n_steps=10, logger=logger\n", |
| 283 | + ")\n", |
270 | 284 | "trainer.fit(model, datamodule.train_dataloader(), datamodule.val_dataloader())\n", |
271 | 285 | "trainer.save_checkpoint(f\"./{simulation_name}_{processor_name}_model.ckpt\")" |
272 | 286 | ] |
|
379 | 393 | "\n", |
380 | 394 | "batch_idx = 0\n", |
381 | 395 | "if simulation_name == \"advection_diffusion_multichannel\":\n", |
382 | | - " channel_names=[\"vorticity\", \"velocity_x\", \"velocity_y\", \"streamfunction\"]\n", |
| 396 | + " channel_names = [\"vorticity\", \"velocity_x\", \"velocity_y\", \"streamfunction\"]\n", |
383 | 397 | "elif simulation_name == \"advection_diffusion\":\n", |
384 | | - " channel_names=[\"vorticity\"]\n", |
| 398 | + " channel_names = [\"vorticity\"]\n", |
385 | 399 | "elif simulation_name == \"reaction_diffusion\":\n", |
386 | | - " channel_names=[\"U\", \"V\"]\n", |
| 400 | + " channel_names = [\"U\", \"V\"]\n", |
387 | 401 | "else:\n", |
388 | | - " channel_names=None\n", |
| 402 | + " channel_names = None\n", |
389 | 403 | "\n", |
390 | 404 | "anim = plot_spatiotemporal_video(\n", |
391 | 405 | " pred=preds,\n", |
|
0 commit comments