Skip to content

Commit 1a669a4

Browse files
authored
Merge pull request #46 from alan-turing-institute/45-encoder-decoder
- Add U-Net (#5) - Add trainable encoder-decoder (#45) - Add tensor type hints (#47)
2 parents f5dfad9 + 59a4c4b commit 1a669a4

35 files changed

Lines changed: 2940 additions & 174 deletions

notebooks/00_exploration.ipynb

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
"source": [
1818
"### Example dataset\n",
1919
"\n",
20-
"We use the `AdvectionDiffusion` dataset as an example dataset to illustrate training and evaluation of models. This dataset simulates the advection-diffusion equation in 2D."
20+
"We use the `AdvectionDiffusion` dataset as an example dataset to illustrate training and evaluation of models. This dataset simulates the advection-diffusion equation in 2D.\n"
2121
]
2222
},
2323
{
@@ -27,12 +27,12 @@
2727
"metadata": {},
2828
"outputs": [],
2929
"source": [
30-
"\n",
31-
"from autoemulate.simulations.reaction_diffusion import ReactionDiffusion as Sim\n",
30+
"from autoemulate.simulations.advection_diffusion import AdvectionDiffusion as Sim\n",
3231
"\n",
3332
"sim = Sim(return_timeseries=True, log_level=\"error\")\n",
3433
"\n",
35-
"def generate_split(simulator: Sim, n_train: int = 1, n_valid: int = 1, n_test: int = 1):\n",
34+
"\n",
35+
"def generate_split(simulator: Sim, n_train: int = 10, n_valid: int = 2, n_test: int = 2):\n",
3636
" \"\"\"Generate training, validation, and test splits from the simulator.\"\"\"\n",
3737
" train = simulator.forward_samples_spatiotemporal(n_train)\n",
3838
" valid = simulator.forward_samples_spatiotemporal(n_valid)\n",
@@ -63,7 +63,11 @@
6363
"n_steps_input = 4\n",
6464
"n_steps_output = 1\n",
6565
"datamodule = SpatioTemporalDataModule(\n",
66-
" data=combined_data, data_path=None, n_steps_input=n_steps_input, n_steps_output=n_steps_output, batch_size=16\n",
66+
" data=combined_data,\n",
67+
" data_path=None,\n",
68+
" n_steps_input=n_steps_input,\n",
69+
" n_steps_output=n_steps_output,\n",
70+
" batch_size=16,\n",
6771
")"
6872
]
6973
},
@@ -112,7 +116,9 @@
112116
"decoder = ChannelsLast(output_channels=n_channels, time_steps=n_steps_output)\n",
113117
"\n",
114118
"model = EncoderProcessorDecoder.from_encoder_processor_decoder(\n",
115-
" encoder_decoder=EncoderDecoder(encoder=encoder, decoder=decoder),\n",
119+
" encoder_decoder=EncoderDecoder.from_encoder_decoder(\n",
120+
" encoder=encoder, decoder=decoder\n",
121+
" ),\n",
116122
" processor=processor,\n",
117123
")"
118124
]
@@ -144,8 +150,8 @@
144150
"source": [
145151
"import lightning as L\n",
146152
"\n",
147-
"# device = \"mps\" # \"cpu\"\n",
148-
"device = \"cpu\"\n",
153+
"device = \"mps\" # \"cpu\"\n",
154+
"# device = \"cpu\"\n",
149155
"trainer = L.Trainer(max_epochs=1, accelerator=device, log_every_n_steps=10)\n",
150156
"trainer.fit(model, datamodule.train_dataloader(), datamodule.val_dataloader())"
151157
]
@@ -155,7 +161,7 @@
155161
"id": "11",
156162
"metadata": {},
157163
"source": [
158-
"### Run the evaluation"
164+
"### Run the evaluation\n"
159165
]
160166
},
161167
{
@@ -173,7 +179,7 @@
173179
"id": "13",
174180
"metadata": {},
175181
"source": [
176-
"### Example rollout"
182+
"### Example rollout\n"
177183
]
178184
},
179185
{
@@ -208,7 +214,7 @@
208214
"outputs": [],
209215
"source": [
210216
"# Run rollout on one trajectory\n",
211-
"preds, trues = model.rollout(batch)"
217+
"preds, trues = model.rollout(batch, free_running_only=True)"
212218
]
213219
},
214220
{
@@ -228,6 +234,7 @@
228234
"metadata": {},
229235
"outputs": [],
230236
"source": [
237+
"assert trues is not None\n",
231238
"print(trues.shape)\n"
232239
]
233240
}

notebooks/01_encoder_decoder.ipynb

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"id": "0",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"# from autoemulate.simulations.advection_diffusion import AdvectionDiffusion\n",
11+
"from autoemulate.simulations.reaction_diffusion import ReactionDiffusion as Sim\n",
12+
"\n",
13+
"sim = Sim(return_timeseries=True, log_level=\"error\")\n",
14+
"\n",
15+
"\n",
16+
"def generate_split(simulator: Sim, n_train: int = 4, n_valid: int = 2, n_test: int = 2):\n",
17+
" \"\"\"Generate training, validation, and test splits from the simulator.\"\"\"\n",
18+
" train = simulator.forward_samples_spatiotemporal(n_train)\n",
19+
" valid = simulator.forward_samples_spatiotemporal(n_valid)\n",
20+
" test = simulator.forward_samples_spatiotemporal(n_test)\n",
21+
" return {\"train\": train, \"valid\": valid, \"test\": test}\n",
22+
"\n",
23+
"\n",
24+
"combined_data = generate_split(sim)"
25+
]
26+
},
27+
{
28+
"cell_type": "code",
29+
"execution_count": null,
30+
"id": "1",
31+
"metadata": {},
32+
"outputs": [],
33+
"source": [
34+
"from auto_cast.data.datamodule import SpatioTemporalDataModule\n",
35+
"\n",
36+
"datamodule = SpatioTemporalDataModule(\n",
37+
" data=combined_data,\n",
38+
" data_path=None,\n",
39+
" n_steps_input=1,\n",
40+
" n_steps_output=0,\n",
41+
" batch_size=16,\n",
42+
" autoencoder_mode=True,\n",
43+
")"
44+
]
45+
},
46+
{
47+
"cell_type": "code",
48+
"execution_count": null,
49+
"id": "2",
50+
"metadata": {},
51+
"outputs": [],
52+
"source": [
53+
"batch = next(iter(datamodule.train_dataloader()))\n"
54+
]
55+
},
56+
{
57+
"cell_type": "code",
58+
"execution_count": null,
59+
"id": "3",
60+
"metadata": {},
61+
"outputs": [],
62+
"source": [
63+
"# Check input field shape: batch of single frames with two channels\n",
64+
"batch.input_fields.shape"
65+
]
66+
},
67+
{
68+
"cell_type": "code",
69+
"execution_count": null,
70+
"id": "4",
71+
"metadata": {},
72+
"outputs": [],
73+
"source": [
74+
"import torch\n",
75+
"\n",
76+
"torch.allclose(batch.input_fields, batch.output_fields)\n"
77+
]
78+
},
79+
{
80+
"cell_type": "code",
81+
"execution_count": null,
82+
"id": "5",
83+
"metadata": {},
84+
"outputs": [],
85+
"source": [
86+
"from auto_cast.decoders.dc import DCDecoder\n",
87+
"from auto_cast.encoders.dc import DCEncoder\n",
88+
"from auto_cast.models.ae import AE\n",
89+
"\n",
90+
"channels = batch.input_fields.shape[-1]\n",
91+
"\n",
92+
"encoder = DCEncoder(\n",
93+
" in_channels=channels,\n",
94+
" out_channels=16,\n",
95+
" hid_channels=(32, 64),\n",
96+
" spatial=2,\n",
97+
" hid_blocks=(2, 2),\n",
98+
" pixel_shuffle=False,\n",
99+
")\n",
100+
"\n",
101+
"decoder = DCDecoder(\n",
102+
" in_channels=16,\n",
103+
" out_channels=channels,\n",
104+
" hid_channels=(64, 32),\n",
105+
" spatial=2,\n",
106+
" hid_blocks=(2, 2),\n",
107+
" pixel_shuffle=False,\n",
108+
")\n",
109+
"model = AE(encoder=encoder, decoder=decoder)"
110+
]
111+
},
112+
{
113+
"cell_type": "code",
114+
"execution_count": null,
115+
"id": "6",
116+
"metadata": {},
117+
"outputs": [],
118+
"source": [
119+
"import lightning as L\n",
120+
"\n",
121+
"device = \"mps\" # \"cpu\"\n",
122+
"trainer = L.Trainer(max_epochs=5, accelerator=device, log_every_n_steps=10)\n",
123+
"trainer.fit(model, datamodule.train_dataloader(), datamodule.val_dataloader())"
124+
]
125+
},
126+
{
127+
"cell_type": "code",
128+
"execution_count": null,
129+
"id": "7",
130+
"metadata": {},
131+
"outputs": [],
132+
"source": [
133+
"import matplotlib.pyplot as plt\n",
134+
"\n",
135+
"for idx, batch in enumerate(datamodule.test_dataloader()):\n",
136+
" inputs = batch.input_fields.to(device)\n",
137+
" outputs, latents = model.forward_with_latent(batch)\n",
138+
" print(\"Input shape:\", inputs.shape)\n",
139+
" print(\"Output shape:\", outputs.shape)\n",
140+
" print(\"Latent shape:\", latents.shape)\n",
141+
" fig, axs = plt.subplots(1, 4, figsize=(8, 4))\n",
142+
" axs[0].imshow(inputs[0, 0, :, :, 0].cpu().numpy(), cmap=\"viridis\")\n",
143+
" axs[0].set_title(\"Input\")\n",
144+
" axs[1].imshow(outputs[0, 0, :, :, 0].detach().cpu().numpy(), cmap=\"viridis\")\n",
145+
" axs[1].set_title(\"Reconstruction\")\n",
146+
" axs[2].imshow(\n",
147+
" outputs[0, 0, :, :, 0].detach().cpu().numpy()\n",
148+
" - inputs[0, 0, :, :, 0].cpu().numpy(),\n",
149+
" cmap=\"viridis\",\n",
150+
" )\n",
151+
" axs[2].set_title(\"Difference\")\n",
152+
" axs[3].imshow(latents[0, 0, :, :, 0].detach().cpu().numpy(), cmap=\"viridis\")\n",
153+
" axs[3].set_title(f\"Latent dim {0}\")\n",
154+
" plt.show()\n",
155+
" if idx >= 3:\n",
156+
" break"
157+
]
158+
}
159+
],
160+
"metadata": {
161+
"kernelspec": {
162+
"display_name": ".venv",
163+
"language": "python",
164+
"name": "python3"
165+
},
166+
"language_info": {
167+
"codemirror_mode": {
168+
"name": "ipython",
169+
"version": 3
170+
},
171+
"file_extension": ".py",
172+
"mimetype": "text/x-python",
173+
"name": "python",
174+
"nbconvert_exporter": "python",
175+
"pygments_lexer": "ipython3",
176+
"version": "3.12.12"
177+
}
178+
},
179+
"nbformat": 4,
180+
"nbformat_minor": 5
181+
}

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@ authors = [
99
requires-python = ">=3.11,<3.13"
1010
dependencies = [
1111
"autoemulate>=1.2.0",
12+
"azula>=0.7.0",
1213
"einops>=0.8.1",
1314
"h5py>=3.15.1",
15+
"jaxtyping>=0.3.3",
1416
"lightning>=2.5.6",
1517
"neuraloperator>=2.0.0",
1618
"the-well>=1.1.0",
@@ -19,6 +21,7 @@ dependencies = [
1921

2022
[project.optional-dependencies]
2123
dev = [
24+
"beartype>=0.22.8",
2225
"ipykernel>=7.1.0",
2326
"pytest>=9.0.1",
2427
"pytest-cov>=7.0.0",

src/auto_cast/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,6 @@
1-
def main() -> None: # noqa: D103
2-
print("Hello from auto-cast!")
1+
import os
2+
3+
if os.getenv("RUNTIME_TYPECHECKING", "True").lower() in ["1", "true"]:
4+
from beartype.claw import beartype_this_package
5+
6+
beartype_this_package()

0 commit comments

Comments
 (0)