|
87 | 87 | "metadata": {}, |
88 | 88 | "outputs": [], |
89 | 89 | "source": [ |
90 | | - "combined_data[\"test\"].keys()" |
| 90 | + "from autocast.logging import create_wandb_logger, maybe_watch_model\n", |
| 91 | + "from autocast.logging.wandb import create_notebook_logger\n", |
| 92 | + "\n", |
| 93 | + "logger, watch = create_notebook_logger(\n", |
| 94 | + " project=\"autocast-notebooks\",\n", |
| 95 | + " name=f\"00_01_exploration_{simulation_name}\",\n", |
| 96 | + " tags=[\"notebook\", simulation_name]\n", |
| 97 | + ")" |
91 | 98 | ] |
92 | 99 | }, |
93 | 100 | { |
|
193 | 200 | " hid_blocks=(2, 2, 2),\n", |
194 | 201 | " spatial=2,\n", |
195 | 202 | " periodic=False,\n", |
196 | | - ")\n", |
| 203 | + " )\n", |
197 | 204 | "\n", |
198 | 205 | "if processor_name == \"flow_matching\":\n", |
199 | 206 | " processor = FlowMatchingProcessor(\n", |
|
203 | 210 | " n_channels_out=n_channels,\n", |
204 | 211 | " stride=n_steps_output,\n", |
205 | 212 | " flow_ode_steps=4,\n", |
206 | | - " )\n", |
| 213 | + " )\n", |
207 | 214 | "else:\n", |
208 | 215 | " from autocast.processors.diffusion import DiffusionProcessor\n", |
209 | 216 | "\n", |
|
213 | 220 | " n_steps_output=n_steps_output,\n", |
214 | 221 | " n_channels_out=n_channels,\n", |
215 | 222 | " stride=n_steps_output,\n", |
216 | | - " )\n", |
| 223 | + " )\n", |
217 | 224 | "\n", |
218 | 225 | "encoder = IdentityEncoder()\n", |
219 | 226 | "decoder = IdentityDecoder()\n", |
|
224 | 231 | " train_processor_only=True,\n", |
225 | 232 | " # learning_rate=1e-5,\n", |
226 | 233 | " learning_rate=1e-4,\n", |
227 | | - ")" |
| 234 | + " )\n", |
| 235 | + "maybe_watch_model(logger, model, watch)" |
228 | 236 | ] |
229 | 237 | }, |
230 | 238 | { |
|
256 | 264 | "\n", |
257 | 265 | "device = \"mps\" # \"cpu\"\n", |
258 | 266 | "# device = \"cpu\"\n", |
259 | | - "trainer = L.Trainer(max_epochs=4, accelerator=device, log_every_n_steps=10)\n", |
| 267 | + "trainer = L.Trainer(max_epochs=4, accelerator=device, log_every_n_steps=10, logger=logger)\n", |
260 | 268 | "trainer.fit(model, datamodule.train_dataloader(), datamodule.val_dataloader())\n", |
261 | 269 | "trainer.save_checkpoint(f\"./{simulation_name}_{processor_name}_model.ckpt\")" |
262 | 270 | ] |
|
0 commit comments