Skip to content

Commit f2bb7ed

Browse files
authored
Merge pull request #98 from alan-turing-institute/82-the-well
Add support for The Well datasets (#82)
2 parents 93c073d + 6f7366d commit f2bb7ed

5 files changed

Lines changed: 252 additions & 98 deletions

File tree

notebooks/00_01_exploration_diffusion_reaction.ipynb

Lines changed: 98 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -36,115 +36,128 @@
3636
"from autocast.data.advection_diffusion import (\n",
3737
" AdvectionDiffusion as AdvectionDiffusionMultichannel,\n",
3838
")\n",
39+
"from autocast.data.datamodule import SpatioTemporalDataModule, TheWellDataModule\n",
3940
"from autocast.metrics.spatiotemporal import MAE, MSE, RMSE\n",
4041
"\n",
41-
"# simulation_name = \"reaction_diffusion\"\n",
42-
"# simulation_name = \"advection_diffusion\"\n",
43-
"simulation_name = \"advection_diffusion_multichannel\"\n",
44-
"\n",
45-
"if simulation_name == \"advection_diffusion_multichannel\":\n",
46-
" # Override to use multichannel version\n",
47-
" Sim = AdvectionDiffusionMultichannel\n",
48-
"if simulation_name == \"reaction_diffusion\":\n",
49-
" Sim = ReactionDiffusion\n",
50-
"if simulation_name == \"advection_diffusion\":\n",
51-
" Sim = AdvectionDiffusion\n",
52-
"\n",
53-
"sim = Sim(return_timeseries=True, log_level=\"error\")\n",
54-
"\n",
55-
"def generate_split(simulator, n_train: int = 200, n_valid: int = 20, n_test: int = 20):\n",
56-
" \"\"\"Generate training, validation, and test splits from the simulator.\"\"\"\n",
57-
" train = simulator.forward_samples_spatiotemporal(n_train)\n",
58-
" valid = simulator.forward_samples_spatiotemporal(n_valid)\n",
59-
" test = simulator.forward_samples_spatiotemporal(n_test)\n",
60-
" return {\"train\": train, \"valid\": valid, \"test\": test}\n",
61-
"\n",
62-
"\n",
63-
"# Cache file path\n",
64-
"cache_file = Path(f\"{simulation_name}_cache.pkl\")\n",
65-
"\n",
66-
"# Load from cache if it exists, otherwise generate and save\n",
67-
"if cache_file.exists():\n",
68-
" print(f\"Loading cached simulation data from {cache_file}\")\n",
69-
" with open(cache_file, \"rb\") as f:\n",
70-
" combined_data = pickle.load(f)\n",
71-
" for key in ['data', 'constant_scalars', 'constant_fields']:\n",
72-
" combined_data[\"test\"][key] = (\n",
73-
" combined_data[\"test\"][key][:5]\n",
74-
" if combined_data[\"test\"][key] is not None else None\n",
75-
" )\n",
76-
"else:\n",
77-
" print(\"Generating simulation data...\")\n",
78-
" combined_data = generate_split(sim)\n",
79-
" print(f\"Saving simulation data to {cache_file}\")\n",
80-
" with open(cache_file, \"wb\") as f:\n",
81-
" pickle.dump(combined_data, f)\n"
42+
"THE_WELL = True\n",
43+
"n_steps_input = 1\n",
44+
"n_steps_output = 4\n",
45+
"stride = n_steps_output"
46+
]
47+
},
48+
{
49+
"cell_type": "markdown",
50+
"id": "3",
51+
"metadata": {},
52+
"source": [
53+
"### Read combined data into datamodule"
8254
]
8355
},
8456
{
8557
"cell_type": "code",
8658
"execution_count": null,
87-
"id": "3",
59+
"id": "4",
8860
"metadata": {},
8961
"outputs": [],
9062
"source": [
91-
"from autocast.logging import create_wandb_logger, maybe_watch_model\n",
92-
"from autocast.logging.wandb import create_notebook_logger\n",
9363
"\n",
94-
"logger, watch = create_notebook_logger(\n",
95-
" project=\"autocast-notebooks\",\n",
96-
" name=f\"00_01_exploration_{simulation_name}\",\n",
97-
" tags=[\"notebook\", simulation_name]\n",
98-
")"
64+
"if not THE_WELL:\n",
65+
" # simulation_name = \"reaction_diffusion\"\n",
66+
" # simulation_name = \"advection_diffusion\"\n",
67+
" simulation_name = \"advection_diffusion_multichannel\"\n",
68+
"\n",
69+
" if simulation_name == \"advection_diffusion_multichannel\":\n",
70+
" # Override to use multichannel version\n",
71+
" Sim = AdvectionDiffusionMultichannel\n",
72+
" if simulation_name == \"reaction_diffusion\":\n",
73+
" Sim = ReactionDiffusion\n",
74+
" if simulation_name == \"advection_diffusion\":\n",
75+
" Sim = AdvectionDiffusion\n",
76+
"\n",
77+
" sim = Sim(return_timeseries=True, log_level=\"error\")\n",
78+
"\n",
79+
" def generate_split(\n",
80+
" simulator, n_train: int = 200, n_valid: int = 20, n_test: int = 20\n",
81+
" ):\n",
82+
" \"\"\"Generate training, validation, and test splits from the simulator.\"\"\"\n",
83+
" train = simulator.forward_samples_spatiotemporal(n_train)\n",
84+
" valid = simulator.forward_samples_spatiotemporal(n_valid)\n",
85+
" test = simulator.forward_samples_spatiotemporal(n_test)\n",
86+
" return {\"train\": train, \"valid\": valid, \"test\": test}\n",
87+
"\n",
88+
" # Cache file path\n",
89+
" cache_file = Path(f\"{simulation_name}_cache.pkl\")\n",
90+
"\n",
91+
" # Load from cache if it exists, otherwise generate and save\n",
92+
" if cache_file.exists():\n",
93+
" print(f\"Loading cached simulation data from {cache_file}\")\n",
94+
" with open(cache_file, \"rb\") as f:\n",
95+
" combined_data = pickle.load(f)\n",
96+
" for key in [\"data\", \"constant_scalars\", \"constant_fields\"]:\n",
97+
" combined_data[\"test\"][key] = (\n",
98+
" combined_data[\"test\"][key][:5]\n",
99+
" if combined_data[\"test\"][key] is not None\n",
100+
" else None\n",
101+
" )\n",
102+
" else:\n",
103+
" print(\"Generating simulation data...\")\n",
104+
" combined_data = generate_split(sim)\n",
105+
" print(f\"Saving simulation data to {cache_file}\")\n",
106+
" with open(cache_file, \"wb\") as f:\n",
107+
" pickle.dump(combined_data, f)\n",
108+
"\n",
109+
" datamodule = SpatioTemporalDataModule(\n",
110+
" data=combined_data,\n",
111+
" data_path=None,\n",
112+
" n_steps_input=n_steps_input,\n",
113+
" n_steps_output=n_steps_output,\n",
114+
" stride=n_steps_output,\n",
115+
" batch_size=16,\n",
116+
" )\n",
117+
"else:\n",
118+
" simulation_name = \"turbulent_radiative_layer_2D\"\n",
119+
" datamodule = TheWellDataModule(\n",
120+
" well_base_path=\"../../autocast/datasets/\",\n",
121+
" well_dataset_name=simulation_name,\n",
122+
" n_steps_input=n_steps_input,\n",
123+
" n_steps_output=n_steps_output,\n",
124+
" min_dt_stride=1,\n",
125+
" use_normalization=True,\n",
126+
" )\n"
99127
]
100128
},
101129
{
102130
"cell_type": "markdown",
103-
"id": "4",
131+
"id": "5",
104132
"metadata": {},
105133
"source": [
106-
"### Read combined data into datamodule\n"
134+
"### Set up logging"
107135
]
108136
},
109137
{
110138
"cell_type": "code",
111139
"execution_count": null,
112-
"id": "5",
140+
"id": "6",
113141
"metadata": {},
114142
"outputs": [],
115143
"source": [
116144
"\n",
117-
"from autocast.data.datamodule import SpatioTemporalDataModule\n",
145+
"from autocast.logging import create_wandb_logger, maybe_watch_model\n",
146+
"from autocast.logging.wandb import create_notebook_logger\n",
118147
"\n",
119-
"n_steps_input = 1\n",
120-
"n_steps_output = 4\n",
121-
"stride = n_steps_output\n",
122-
"datamodule = SpatioTemporalDataModule(\n",
123-
" data=combined_data,\n",
124-
" data_path=None,\n",
125-
" n_steps_input=n_steps_input,\n",
126-
" n_steps_output=n_steps_output,\n",
127-
" stride=n_steps_output,\n",
128-
" batch_size=16,\n",
148+
"logger, watch = create_notebook_logger(\n",
149+
" project=\"autocast-notebooks\",\n",
150+
" name=f\"00_01_exploration_{simulation_name}\",\n",
151+
" tags=[\"notebook\", simulation_name],\n",
129152
")"
130153
]
131154
},
132155
{
133156
"cell_type": "markdown",
134-
"id": "6",
135-
"metadata": {},
136-
"source": [
137-
"### Example batch\n"
138-
]
139-
},
140-
{
141-
"cell_type": "code",
142-
"execution_count": null,
143157
"id": "7",
144158
"metadata": {},
145-
"outputs": [],
146159
"source": [
147-
"len(datamodule.train_dataset) / 50\n"
160+
"### Example shape and batch\n"
148161
]
149162
},
150163
{
@@ -164,7 +177,6 @@
164177
"metadata": {},
165178
"outputs": [],
166179
"source": [
167-
"\n",
168180
"batch = next(iter(datamodule.train_dataloader()))\n",
169181
"\n",
170182
"batch.input_fields.shape"
@@ -201,7 +213,7 @@
201213
" hid_blocks=(2, 2, 2),\n",
202214
" spatial=2,\n",
203215
" periodic=False,\n",
204-
" )\n",
216+
")\n",
205217
"\n",
206218
"if processor_name == \"flow_matching\":\n",
207219
" processor = FlowMatchingProcessor(\n",
@@ -211,7 +223,7 @@
211223
" n_channels_out=n_channels,\n",
212224
" stride=n_steps_output,\n",
213225
" flow_ode_steps=4,\n",
214-
" )\n",
226+
" )\n",
215227
"else:\n",
216228
" from autocast.processors.diffusion import DiffusionProcessor\n",
217229
"\n",
@@ -221,7 +233,7 @@
221233
" n_steps_output=n_steps_output,\n",
222234
" n_channels_out=n_channels,\n",
223235
" stride=n_steps_output,\n",
224-
" )\n",
236+
" )\n",
225237
"\n",
226238
"encoder = IdentityEncoder()\n",
227239
"decoder = IdentityDecoder()\n",
@@ -233,7 +245,7 @@
233245
" # learning_rate=1e-5,\n",
234246
" learning_rate=1e-4,\n",
235247
" #test_metrics = [MSE(), MAE(), RMSE()]\n",
236-
" )\n",
248+
")\n",
237249
"maybe_watch_model(logger, model, watch)"
238250
]
239251
},
@@ -266,7 +278,9 @@
266278
"\n",
267279
"device = \"mps\" # \"cpu\"\n",
268280
"# device = \"cpu\"\n",
269-
"trainer = L.Trainer(max_epochs=4, accelerator=device, log_every_n_steps=10, logger=logger)\n",
281+
"trainer = L.Trainer(\n",
282+
" max_epochs=4, accelerator=device, log_every_n_steps=10, logger=logger\n",
283+
")\n",
270284
"trainer.fit(model, datamodule.train_dataloader(), datamodule.val_dataloader())\n",
271285
"trainer.save_checkpoint(f\"./{simulation_name}_{processor_name}_model.ckpt\")"
272286
]
@@ -379,13 +393,13 @@
379393
"\n",
380394
"batch_idx = 0\n",
381395
"if simulation_name == \"advection_diffusion_multichannel\":\n",
382-
" channel_names=[\"vorticity\", \"velocity_x\", \"velocity_y\", \"streamfunction\"]\n",
396+
" channel_names = [\"vorticity\", \"velocity_x\", \"velocity_y\", \"streamfunction\"]\n",
383397
"elif simulation_name == \"advection_diffusion\":\n",
384-
" channel_names=[\"vorticity\"]\n",
398+
" channel_names = [\"vorticity\"]\n",
385399
"elif simulation_name == \"reaction_diffusion\":\n",
386-
" channel_names=[\"U\", \"V\"]\n",
400+
" channel_names = [\"U\", \"V\"]\n",
387401
"else:\n",
388-
" channel_names=None\n",
402+
" channel_names = None\n",
389403
"\n",
390404
"anim = plot_spatiotemporal_video(\n",
391405
" pred=preds,\n",

src/autocast/data/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from autocast.data.datamodule import SpatioTemporalDataModule, TheWellDataModule
2+
from autocast.data.dataset import SpatioTemporalDataset, TheWell
3+
4+
__all__ = [
5+
"SpatioTemporalDataModule",
6+
"SpatioTemporalDataset",
7+
"TheWell",
8+
"TheWellDataModule",
9+
]

0 commit comments

Comments
 (0)