Skip to content

Commit dc84c87

Browse files
authored
Merge pull request #94 from alan-turing-institute/91-initial-logging-config
Initial logging config (#91)
2 parents 08ca1da + 18fb67a commit dc84c87

16 files changed

Lines changed: 611 additions & 68 deletions

README.md

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,27 @@ uv run evaluate_processor \
3333

3434
Evaluation writes a CSV of aggregate metrics to `--csv-path` (defaults to
3535
`<work-dir>/evaluation_metrics.csv`) and, when `--batch-index` is provided,
36-
stores rollout animations for the specified test batches.
36+
stores rollout animations for the specified test batches.
37+
38+
## Experiment Tracking with Weights & Biases
39+
40+
AutoCast now ships with an optional [Weights & Biases](https://wandb.ai/) integration that is
41+
fully driven by the Hydra config under `configs/logging/wandb.yaml`.
42+
43+
- Enable logging for CLI workflows by overriding `logging.wandb.enabled=true` and
44+
optionally providing `project`, `name`, or `tags` overrides:
45+
46+
```bash
47+
uv run train_processor \
48+
--config-path=configs \
49+
--override logging.wandb.enabled=true \
50+
--override logging.wandb.project=autocast-experiments \
51+
--override logging.wandb.name=processor-baseline
52+
```
53+
54+
- The autoencoder/processor training CLIs pass the configured `WandbLogger` directly into Lightning so that metrics, checkpoints, and artifacts are synchronized automatically.
55+
- The evaluation CLI reports aggregate test metrics to the same run when logging is enabled, making it easy to compare training and evaluation outputs in one dashboard.
56+
- All notebooks contain a dedicated cell that instantiates a `wandb_logger` via `autocast.logging.create_wandb_logger`. Toggle the `enabled` flag in that cell to control tracking when experimenting interactively.
57+
58+
When `enabled` remains `false` (the default), the logger is skipped entirely, so the stack can
59+
be used without a W&B account.

configs/config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ defaults:
22
- data: reaction_diffusion
33
- model: ae
44
- trainer: default
5+
- logging: wandb
56
- _self_
67

78
seed: 42

configs/logging/wandb.yaml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
wandb:
2+
enabled: false
3+
project: autocast
4+
entity: null
5+
name: null
6+
group: null
7+
job_type: ${experiment_name}
8+
tags: []
9+
notes: null
10+
mode: online
11+
resume: null
12+
id: null
13+
log_model: false
14+
save_dir: null
15+
settings: {}
16+
config: {}
17+
watch:
18+
log: null
19+
log_freq: 100

configs/processor.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ defaults:
44
- decoder: channels_last
55
- processor: fno
66
- trainer: default
7+
- logging: wandb
78
- _self_
89

910
seed: 42

notebooks/00_01_exploration_diffusion_reaction.ipynb

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,14 @@
8787
"metadata": {},
8888
"outputs": [],
8989
"source": [
90-
"combined_data[\"test\"].keys()"
90+
"from autocast.logging import create_wandb_logger, maybe_watch_model\n",
91+
"from autocast.logging.wandb import create_notebook_logger\n",
92+
"\n",
93+
"logger, watch = create_notebook_logger(\n",
94+
" project=\"autocast-notebooks\",\n",
95+
" name=f\"00_01_exploration_{simulation_name}\",\n",
96+
" tags=[\"notebook\", simulation_name]\n",
97+
")"
9198
]
9299
},
93100
{
@@ -193,7 +200,7 @@
193200
" hid_blocks=(2, 2, 2),\n",
194201
" spatial=2,\n",
195202
" periodic=False,\n",
196-
")\n",
203+
" )\n",
197204
"\n",
198205
"if processor_name == \"flow_matching\":\n",
199206
" processor = FlowMatchingProcessor(\n",
@@ -203,7 +210,7 @@
203210
" n_channels_out=n_channels,\n",
204211
" stride=n_steps_output,\n",
205212
" flow_ode_steps=4,\n",
206-
" )\n",
213+
" )\n",
207214
"else:\n",
208215
" from autocast.processors.diffusion import DiffusionProcessor\n",
209216
"\n",
@@ -213,7 +220,7 @@
213220
" n_steps_output=n_steps_output,\n",
214221
" n_channels_out=n_channels,\n",
215222
" stride=n_steps_output,\n",
216-
" )\n",
223+
" )\n",
217224
"\n",
218225
"encoder = IdentityEncoder()\n",
219226
"decoder = IdentityDecoder()\n",
@@ -224,7 +231,8 @@
224231
" train_processor_only=True,\n",
225232
" # learning_rate=1e-5,\n",
226233
" learning_rate=1e-4,\n",
227-
")"
234+
" )\n",
235+
"maybe_watch_model(logger, model, watch)"
228236
]
229237
},
230238
{
@@ -256,7 +264,7 @@
256264
"\n",
257265
"device = \"mps\" # \"cpu\"\n",
258266
"# device = \"cpu\"\n",
259-
"trainer = L.Trainer(max_epochs=4, accelerator=device, log_every_n_steps=10)\n",
267+
"trainer = L.Trainer(max_epochs=4, accelerator=device, log_every_n_steps=10, logger=logger)\n",
260268
"trainer.fit(model, datamodule.train_dataloader(), datamodule.val_dataloader())\n",
261269
"trainer.save_checkpoint(f\"./{simulation_name}_{processor_name}_model.ckpt\")"
262270
]

notebooks/00_exploration.ipynb

Lines changed: 48 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -66,17 +66,41 @@
6666
]
6767
},
6868
{
69-
"cell_type": "markdown",
69+
"cell_type": "code",
70+
"execution_count": null,
7071
"id": "3",
7172
"metadata": {},
73+
"outputs": [],
74+
"source": [
75+
"from autocast.logging import create_wandb_logger, maybe_watch_model\n",
76+
"\n",
77+
"logging_cfg = {\n",
78+
" \"wandb\": {\n",
79+
" \"enabled\": True, # Set to False to disable wandb for this run.\n",
80+
" \"project\": \"autocast-notebooks\",\n",
81+
" \"name\": \"00_exploration\",\n",
82+
" \"tags\": [\"notebook\", \"00-exploration\"],\n",
83+
" },\n",
84+
"}\n",
85+
"wandb_logger, wandb_watch = create_wandb_logger(\n",
86+
" logging_cfg,\n",
87+
" experiment_name=\"00_exploration\",\n",
88+
" job_type=\"notebook\",\n",
89+
")"
90+
]
91+
},
92+
{
93+
"cell_type": "markdown",
94+
"id": "4",
95+
"metadata": {},
7296
"source": [
7397
"### Read combined data into datamodule\n"
7498
]
7599
},
76100
{
77101
"cell_type": "code",
78102
"execution_count": null,
79-
"id": "4",
103+
"id": "5",
80104
"metadata": {},
81105
"outputs": [],
82106
"source": [
@@ -97,7 +121,7 @@
97121
},
98122
{
99123
"cell_type": "markdown",
100-
"id": "5",
124+
"id": "6",
101125
"metadata": {},
102126
"source": [
103127
"### Example batch\n"
@@ -106,7 +130,7 @@
106130
{
107131
"cell_type": "code",
108132
"execution_count": null,
109-
"id": "6",
133+
"id": "7",
110134
"metadata": {},
111135
"outputs": [],
112136
"source": [
@@ -118,7 +142,7 @@
118142
{
119143
"cell_type": "code",
120144
"execution_count": null,
121-
"id": "7",
145+
"id": "8",
122146
"metadata": {},
123147
"outputs": [],
124148
"source": [
@@ -137,21 +161,22 @@
137161
" hidden_channels=64,\n",
138162
" stride=n_steps_output,\n",
139163
" max_rollout_steps=100,\n",
140-
")\n",
164+
" )\n",
141165
"encoder = PermuteConcat(with_constants=False)\n",
142166
"decoder = ChannelsLast(output_channels=n_channels, time_steps=n_steps_output)\n",
143167
"\n",
144168
"model = EncoderProcessorDecoder(\n",
145169
" encoder_decoder=EncoderDecoder(encoder=encoder, decoder=decoder),\n",
146170
" processor=processor,\n",
147171
" stride=stride,\n",
148-
")"
172+
" )\n",
173+
"maybe_watch_model(wandb_logger, model, wandb_watch)"
149174
]
150175
},
151176
{
152177
"cell_type": "code",
153178
"execution_count": null,
154-
"id": "8",
179+
"id": "9",
155180
"metadata": {},
156181
"outputs": [],
157182
"source": [
@@ -160,7 +185,7 @@
160185
},
161186
{
162187
"cell_type": "markdown",
163-
"id": "9",
188+
"id": "10",
164189
"metadata": {},
165190
"source": [
166191
"### Run trainer\n"
@@ -169,21 +194,21 @@
169194
{
170195
"cell_type": "code",
171196
"execution_count": null,
172-
"id": "10",
197+
"id": "11",
173198
"metadata": {},
174199
"outputs": [],
175200
"source": [
176201
"import lightning as L\n",
177202
"\n",
178203
"device = \"mps\" # \"cpu\"\n",
179204
"# device = \"cpu\"\n",
180-
"trainer = L.Trainer(max_epochs=1, accelerator=device, log_every_n_steps=10)\n",
205+
"trainer = L.Trainer(max_epochs=1, accelerator=device, log_every_n_steps=10, logger=wandb_logger)\n",
181206
"trainer.fit(model, datamodule.train_dataloader(), datamodule.val_dataloader())"
182207
]
183208
},
184209
{
185210
"cell_type": "markdown",
186-
"id": "11",
211+
"id": "12",
187212
"metadata": {},
188213
"source": [
189214
"### Run the evaluation\n"
@@ -192,7 +217,7 @@
192217
{
193218
"cell_type": "code",
194219
"execution_count": null,
195-
"id": "12",
220+
"id": "13",
196221
"metadata": {},
197222
"outputs": [],
198223
"source": [
@@ -201,7 +226,7 @@
201226
},
202227
{
203228
"cell_type": "markdown",
204-
"id": "13",
229+
"id": "14",
205230
"metadata": {},
206231
"source": [
207232
"### Example rollout\n"
@@ -210,7 +235,7 @@
210235
{
211236
"cell_type": "code",
212237
"execution_count": null,
213-
"id": "14",
238+
"id": "15",
214239
"metadata": {},
215240
"outputs": [],
216241
"source": [
@@ -221,7 +246,7 @@
221246
{
222247
"cell_type": "code",
223248
"execution_count": null,
224-
"id": "15",
249+
"id": "16",
225250
"metadata": {},
226251
"outputs": [],
227252
"source": [
@@ -234,7 +259,7 @@
234259
{
235260
"cell_type": "code",
236261
"execution_count": null,
237-
"id": "16",
262+
"id": "17",
238263
"metadata": {},
239264
"outputs": [],
240265
"source": [
@@ -245,7 +270,7 @@
245270
{
246271
"cell_type": "code",
247272
"execution_count": null,
248-
"id": "17",
273+
"id": "18",
249274
"metadata": {},
250275
"outputs": [],
251276
"source": [
@@ -259,7 +284,7 @@
259284
{
260285
"cell_type": "code",
261286
"execution_count": null,
262-
"id": "18",
287+
"id": "19",
263288
"metadata": {},
264289
"outputs": [],
265290
"source": [
@@ -269,7 +294,7 @@
269294
{
270295
"cell_type": "code",
271296
"execution_count": null,
272-
"id": "19",
297+
"id": "20",
273298
"metadata": {},
274299
"outputs": [],
275300
"source": [
@@ -280,7 +305,7 @@
280305
{
281306
"cell_type": "code",
282307
"execution_count": null,
283-
"id": "20",
308+
"id": "21",
284309
"metadata": {},
285310
"outputs": [],
286311
"source": [
@@ -293,7 +318,7 @@
293318
{
294319
"cell_type": "code",
295320
"execution_count": null,
296-
"id": "21",
321+
"id": "22",
297322
"metadata": {},
298323
"outputs": [],
299324
"source": [
@@ -311,7 +336,7 @@
311336
{
312337
"cell_type": "code",
313338
"execution_count": null,
314-
"id": "22",
339+
"id": "23",
315340
"metadata": {},
316341
"outputs": [],
317342
"source": []

0 commit comments

Comments
 (0)