Skip to content

Commit d3c48c8

Browse files
committed
Add exploration notebook but with diffusion
1 parent a863c68 commit d3c48c8

1 file changed

Lines changed: 340 additions & 0 deletions

File tree

Lines changed: 340 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,340 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "0",
6+
"metadata": {},
7+
"source": [
8+
"## AutoCast encoder-processor-decoder model API Exploration\n",
9+
"\n",
10+
"This notebook aims to explore the end-to-end API.\n"
11+
]
12+
},
13+
{
14+
"cell_type": "markdown",
15+
"id": "1",
16+
"metadata": {},
17+
"source": [
18+
"### Example dataaset\n",
19+
"\n",
20+
"We use the `AdvectionDiffusion` dataset as an example dataset to illustrate training and evaluation of models. This dataset simulates the advection-diffusion equation in 2D.\n"
21+
]
22+
},
23+
{
24+
"cell_type": "code",
25+
"execution_count": null,
26+
"id": "2",
27+
"metadata": {},
28+
"outputs": [],
29+
"source": [
30+
"import pickle\n",
31+
"from pathlib import Path\n",
32+
"\n",
33+
"from autoemulate.simulations.advection_diffusion import AdvectionDiffusion\n",
34+
"from autoemulate.simulations.reaction_diffusion import ReactionDiffusion\n",
35+
"\n",
36+
"simulation_name = \"reaction_diffusion\"\n",
37+
"\n",
38+
"Sim = (\n",
39+
" ReactionDiffusion if simulation_name == \"reaction_diffusion\" else AdvectionDiffusion\n",
40+
")\n",
41+
"sim = Sim(return_timeseries=True, log_level=\"error\")\n",
42+
"\n",
43+
"\n",
44+
"def generate_split(simulator, n_train: int = 50, n_valid: int = 2, n_test: int = 2):\n",
45+
" \"\"\"Generate training, validation, and test splits from the simulator.\"\"\"\n",
46+
" train = simulator.forward_samples_spatiotemporal(n_train)\n",
47+
" valid = simulator.forward_samples_spatiotemporal(n_valid)\n",
48+
" test = simulator.forward_samples_spatiotemporal(n_test)\n",
49+
" return {\"train\": train, \"valid\": valid, \"test\": test}\n",
50+
"\n",
51+
"\n",
52+
"# Cache file path\n",
53+
"cache_file = Path(f\"{simulation_name}_cache.pkl\")\n",
54+
"\n",
55+
"# Load from cache if it exists, otherwise generate and save\n",
56+
"if cache_file.exists():\n",
57+
" print(f\"Loading cached simulation data from {cache_file}\")\n",
58+
" with open(cache_file, \"rb\") as f:\n",
59+
" combined_data = pickle.load(f)\n",
60+
"else:\n",
61+
" print(\"Generating simulation data...\")\n",
62+
" combined_data = generate_split(sim)\n",
63+
" print(f\"Saving simulation data to {cache_file}\")\n",
64+
" with open(cache_file, \"wb\") as f:\n",
65+
" pickle.dump(combined_data, f)\n"
66+
]
67+
},
68+
{
69+
"cell_type": "markdown",
70+
"id": "3",
71+
"metadata": {},
72+
"source": [
73+
"### Read combined data into datamodule\n"
74+
]
75+
},
76+
{
77+
"cell_type": "code",
78+
"execution_count": null,
79+
"id": "4",
80+
"metadata": {},
81+
"outputs": [],
82+
"source": [
83+
"from auto_cast.data.datamodule import SpatioTemporalDataModule\n",
84+
"\n",
85+
"n_steps_input = 1\n",
86+
"n_steps_output = 4\n",
87+
"stride = 4\n",
88+
"datamodule = SpatioTemporalDataModule(\n",
89+
" data=combined_data,\n",
90+
" data_path=None,\n",
91+
" n_steps_input=n_steps_input,\n",
92+
" n_steps_output=n_steps_output,\n",
93+
" stride=n_steps_output,\n",
94+
" batch_size=16,\n",
95+
")"
96+
]
97+
},
98+
{
99+
"cell_type": "markdown",
100+
"id": "5",
101+
"metadata": {},
102+
"source": [
103+
"### Example batch\n"
104+
]
105+
},
106+
{
107+
"cell_type": "code",
108+
"execution_count": null,
109+
"id": "6",
110+
"metadata": {},
111+
"outputs": [],
112+
"source": [
113+
"batch = next(iter(datamodule.train_dataloader()))\n",
114+
"\n",
115+
"# batch"
116+
]
117+
},
118+
{
119+
"cell_type": "code",
120+
"execution_count": null,
121+
"id": "7",
122+
"metadata": {},
123+
"outputs": [],
124+
"source": [
125+
"from azula.noise import VPSchedule\n",
126+
"\n",
127+
"from auto_cast.decoders.identity import IdentityDecoder\n",
128+
"from auto_cast.encoders.identity import IdentityEncoder\n",
129+
"from auto_cast.models.encoder_decoder import EncoderDecoder\n",
130+
"from auto_cast.models.encoder_processor_decoder import EPDTrainProcessor\n",
131+
"from auto_cast.nn.unet import TemporalUNetBackbone\n",
132+
"from auto_cast.processors.diffusion import DiffusionProcessor\n",
133+
"\n",
134+
"# from auto_cast.processors.fno import FNOProcessor\n",
135+
"\n",
136+
"batch = next(iter(datamodule.train_dataloader()))\n",
137+
"n_channels = batch.input_fields.shape[-1]\n",
138+
"# processor = FNOProcessor(\n",
139+
"# in_channels=n_channels * n_steps_input,\n",
140+
"# out_channels=n_channels * n_steps_output,\n",
141+
"# n_modes=(16, 16),\n",
142+
"# hidden_channels=64,\n",
143+
"# stride=n_steps_output,\n",
144+
"# max_rollout_steps=100,\n",
145+
"# )\n",
146+
"processor = DiffusionProcessor(\n",
147+
" backbone=TemporalUNetBackbone(\n",
148+
" in_channels=n_channels * n_steps_output,\n",
149+
" out_channels=n_channels * n_steps_output,\n",
150+
" cond_channels=n_channels * n_steps_input,\n",
151+
" mod_features=256,\n",
152+
" hid_channels=(32, 64, 128),\n",
153+
" hid_blocks=(2, 2, 2),\n",
154+
" spatial=2,\n",
155+
" periodic=False,\n",
156+
" ),\n",
157+
" schedule=VPSchedule(),\n",
158+
" n_steps_output=n_steps_output,\n",
159+
" n_channels_out=n_channels,\n",
160+
" stride=n_steps_output\n",
161+
")\n",
162+
"# encoder = PermuteConcat(with_constants=False)\n",
163+
"# decoder = ChannelsLast(output_channels=n_channels, time_steps=n_steps_output)\n",
164+
"\n",
165+
"encoder = IdentityEncoder()\n",
166+
"decoder = IdentityDecoder()\n",
167+
"model = EPDTrainProcessor(\n",
168+
" encoder_decoder=EncoderDecoder(encoder=encoder, decoder=decoder),\n",
169+
" processor=processor,\n",
170+
" stride=stride,\n",
171+
")"
172+
]
173+
},
174+
{
175+
"cell_type": "code",
176+
"execution_count": null,
177+
"id": "8",
178+
"metadata": {},
179+
"outputs": [],
180+
"source": [
181+
"model(batch).shape"
182+
]
183+
},
184+
{
185+
"cell_type": "code",
186+
"execution_count": null,
187+
"id": "9",
188+
"metadata": {},
189+
"outputs": [],
190+
"source": [
191+
"dl = datamodule.train_dataloader()\n"
192+
]
193+
},
194+
{
195+
"cell_type": "code",
196+
"execution_count": null,
197+
"id": "10",
198+
"metadata": {},
199+
"outputs": [],
200+
"source": [
201+
"batch.input_fields.shape[0]*len(dl)"
202+
]
203+
},
204+
{
205+
"cell_type": "markdown",
206+
"id": "11",
207+
"metadata": {},
208+
"source": [
209+
"### Run trainer\n"
210+
]
211+
},
212+
{
213+
"cell_type": "code",
214+
"execution_count": null,
215+
"id": "12",
216+
"metadata": {},
217+
"outputs": [],
218+
"source": [
219+
"import lightning as L\n",
220+
"\n",
221+
"device = \"mps\" # \"cpu\"\n",
222+
"# device = \"cpu\"\n",
223+
"trainer = L.Trainer(max_epochs=3, accelerator=device, log_every_n_steps=10)\n",
224+
"trainer.fit(model, datamodule.train_dataloader(), datamodule.val_dataloader())"
225+
]
226+
},
227+
{
228+
"cell_type": "markdown",
229+
"id": "13",
230+
"metadata": {},
231+
"source": [
232+
"### Run the evaluation\n"
233+
]
234+
},
235+
{
236+
"cell_type": "code",
237+
"execution_count": null,
238+
"id": "14",
239+
"metadata": {},
240+
"outputs": [],
241+
"source": [
242+
"trainer.test(model, datamodule.test_dataloader())"
243+
]
244+
},
245+
{
246+
"cell_type": "markdown",
247+
"id": "15",
248+
"metadata": {},
249+
"source": [
250+
"### Example rollout\n"
251+
]
252+
},
253+
{
254+
"cell_type": "code",
255+
"execution_count": null,
256+
"id": "16",
257+
"metadata": {},
258+
"outputs": [],
259+
"source": [
260+
"# A single element is the full trajectory\n",
261+
"batch = next(iter(datamodule.rollout_test_dataloader()))"
262+
]
263+
},
264+
{
265+
"cell_type": "code",
266+
"execution_count": null,
267+
"id": "17",
268+
"metadata": {},
269+
"outputs": [],
270+
"source": [
271+
"# First n_steps_input are inputs\n",
272+
"print(batch.input_fields.shape)\n",
273+
"# Remaining n_steps_output are outputs\n",
274+
"print(batch.output_fields.shape)"
275+
]
276+
},
277+
{
278+
"cell_type": "code",
279+
"execution_count": null,
280+
"id": "18",
281+
"metadata": {},
282+
"outputs": [],
283+
"source": [
284+
"# Run rollout on one trajectory\n",
285+
"preds, trues = model.rollout(batch, free_running_only=True)\n",
286+
"\n",
287+
"print(preds.shape)\n",
288+
"assert trues is not None\n",
289+
"print(trues.shape)"
290+
]
291+
},
292+
{
293+
"cell_type": "code",
294+
"execution_count": null,
295+
"id": "19",
296+
"metadata": {},
297+
"outputs": [],
298+
"source": [
299+
"from IPython.display import HTML\n",
300+
"\n",
301+
"from auto_cast.utils import plot_spatiotemporal_video\n",
302+
"\n",
303+
"anim = plot_spatiotemporal_video(\n",
304+
" pred=preds,\n",
305+
" true=trues,\n",
306+
")\n",
307+
"HTML(anim.to_jshtml())"
308+
]
309+
},
310+
{
311+
"cell_type": "code",
312+
"execution_count": null,
313+
"id": "20",
314+
"metadata": {},
315+
"outputs": [],
316+
"source": []
317+
}
318+
],
319+
"metadata": {
320+
"kernelspec": {
321+
"display_name": ".venv",
322+
"language": "python",
323+
"name": "python3"
324+
},
325+
"language_info": {
326+
"codemirror_mode": {
327+
"name": "ipython",
328+
"version": 3
329+
},
330+
"file_extension": ".py",
331+
"mimetype": "text/x-python",
332+
"name": "python",
333+
"nbconvert_exporter": "python",
334+
"pygments_lexer": "ipython3",
335+
"version": "3.12.12"
336+
}
337+
},
338+
"nbformat": 4,
339+
"nbformat_minor": 5
340+
}

0 commit comments

Comments
 (0)