|
45 | 45 | "source": [ |
46 | 46 | "import matplotlib.pyplot as plt\n", |
47 | 47 | "import torch\n", |
| 48 | + "import copy\n", |
48 | 49 | "import lightning as L\n", |
49 | 50 | "from torchvision.models import resnet18\n", |
50 | 51 | "from torchvision.transforms import v2\n", |
|
121 | 122 | }, |
122 | 123 | { |
123 | 124 | "cell_type": "code", |
124 | | - "execution_count": 7, |
| 125 | + "execution_count": null, |
125 | 126 | "id": "794004c5-588c-4590-ae96-c6d9e52109ff", |
126 | 127 | "metadata": {}, |
127 | 128 | "outputs": [], |
128 | 129 | "source": [ |
129 | 130 | "EPOCHS = 2 # number of epochs to train\n", |
130 | | - "FAST_DEV_RUN = 5 # Quick prototype, comment line for full epochs training" |
| 131 | + "FAST_DEV_RUN = None # Quick prototype: set to an int N to run only N batches; keep None to train for the full EPOCHS" |
131 | 132 | ] |
132 | 133 | }, |
133 | 134 | { |
|
278 | 279 | "outputs": [], |
279 | 280 | "source": [ |
280 | 281 | "# Train first model (regular training) using our backbone\n", |
281 | | - "model_regular = LitMNIST(backbone=resnet)" |
| 282 | + "model_regular = LitMNIST(backbone=copy.deepcopy(resnet))" |
282 | 283 | ] |
283 | 284 | }, |
284 | 285 | { |
|
295 | 296 | "TPU available: False, using: 0 TPU cores\n", |
296 | 297 | "IPU available: False, using: 0 IPUs\n", |
297 | 298 | "HPU available: False, using: 0 HPUs\n", |
298 | | - "Running in `fast_dev_run` mode: will run the requested loop using 3 batch(es). Logging and checkpointing is suppressed.\n" |
| 299 | + "Running in `fast_dev_run` mode: will run the requested loop using 10 batch(es). Logging and checkpointing is suppressed.\n" |
299 | 300 | ] |
300 | 301 | } |
301 | 302 | ], |
|
337 | 338 | "name": "stdout", |
338 | 339 | "output_type": "stream", |
339 | 340 | "text": [ |
340 | | - "Epoch 0: 100%|██████████| 3/3 [00:50<00:00, 0.06it/s, loss_step=274.0, val_loss_step=260.0, cindex_step=0.615, val_loss_epoch=261.0, cindex_epoch=0.618, loss_epoch=300.0]" |
| 341 | + "Epoch 0: 100%|██████████| 10/10 [06:00<00:00, 0.03it/s, loss_step=242.0, val_loss_step=260.0, cindex_step=0.585, val_loss_epoch=260.0, cindex_epoch=0.579, loss_epoch=265.0]" |
341 | 342 | ] |
342 | 343 | }, |
343 | 344 | { |
344 | 345 | "name": "stderr", |
345 | 346 | "output_type": "stream", |
346 | 347 | "text": [ |
347 | | - "`Trainer.fit` stopped: `max_steps=3` reached.\n" |
| 348 | + "`Trainer.fit` stopped: `max_steps=10` reached.\n" |
348 | 349 | ] |
349 | 350 | }, |
350 | 351 | { |
351 | 352 | "name": "stdout", |
352 | 353 | "output_type": "stream", |
353 | 354 | "text": [ |
354 | | - "Epoch 0: 100%|██████████| 3/3 [00:50<00:00, 0.06it/s, loss_step=274.0, val_loss_step=260.0, cindex_step=0.615, val_loss_epoch=261.0, cindex_epoch=0.618, loss_epoch=300.0]\n" |
| 355 | + "Epoch 0: 100%|██████████| 10/10 [06:00<00:00, 0.03it/s, loss_step=242.0, val_loss_step=260.0, cindex_step=0.585, val_loss_epoch=260.0, cindex_epoch=0.579, loss_epoch=265.0]\n" |
355 | 356 | ] |
356 | 357 | } |
357 | 358 | ], |
|
370 | 371 | "name": "stdout", |
371 | 372 | "output_type": "stream", |
372 | 373 | "text": [ |
373 | | - "Testing DataLoader 0: 100%|██████████| 3/3 [00:20<00:00, 0.15it/s]\n", |
| 374 | + "Testing DataLoader 0: 100%|██████████| 10/10 [03:09<00:00, 0.05it/s]\n", |
374 | 375 | "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n", |
375 | 376 | " Test metric DataLoader 0\n", |
376 | 377 | "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n", |
377 | | - " cindex_epoch 0.612917959690094\n", |
378 | | - " val_loss_epoch 34.167057037353516\n", |
| 378 | + " cindex_epoch 0.5841928124427795\n", |
| 379 | + " val_loss_epoch -90.17676544189453\n", |
379 | 380 | "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n" |
380 | 381 | ] |
381 | 382 | }, |
382 | 383 | { |
383 | 384 | "data": { |
384 | 385 | "text/plain": [ |
385 | | - "[{'val_loss_epoch': 34.167057037353516, 'cindex_epoch': 0.612917959690094}]" |
| 386 | + "[{'val_loss_epoch': -90.17676544189453, 'cindex_epoch': 0.5841928124427795}]" |
386 | 387 | ] |
387 | 388 | }, |
388 | 389 | "execution_count": 15, |
|
415 | 416 | "outputs": [], |
416 | 417 | "source": [ |
417 | 418 | "FACTOR = 10 # Number of batches to keep in memory; artificially increases our effective training batch size by a factor of 10 here\n", |
418 | | - "resnet_momentum = Momentum(resnet, neg_partial_log_likelihood, steps=FACTOR, rate=0.999)\n", |
| 419 | + "resnet_momentum = Momentum(copy.deepcopy(resnet), neg_partial_log_likelihood, steps=FACTOR, rate=0.999)\n", |
419 | 420 | "model_momentum = LitMomentum(backbone=resnet_momentum)\n", |
420 | 421 | "\n", |
421 | 422 | "# By using momentum, we can in theory reduce our batch size by a factor of FACTOR and still have the same effective sample size\n", |
|
438 | 439 | "TPU available: False, using: 0 TPU cores\n", |
439 | 440 | "IPU available: False, using: 0 IPUs\n", |
440 | 441 | "HPU available: False, using: 0 HPUs\n", |
441 | | - "Running in `fast_dev_run` mode: will run the requested loop using 3 batch(es). Logging and checkpointing is suppressed.\n", |
| 442 | + "Running in `fast_dev_run` mode: will run the requested loop using 10 batch(es). Logging and checkpointing is suppressed.\n", |
442 | 443 | "\n", |
443 | 444 | " | Name | Type | Params\n", |
444 | 445 | "-----------------------------------\n", |
|
454 | 455 | "name": "stdout", |
455 | 456 | "output_type": "stream", |
456 | 457 | "text": [ |
457 | | - "Epoch 0: 100%|██████████| 3/3 [00:15<00:00, 0.19it/s, loss_step=44.70, val_loss_step=74.10, cindex_step=0.614, val_loss_epoch=74.00, cindex_epoch=0.631, loss_epoch=23.30]" |
| 458 | + "Epoch 0: 100%|██████████| 10/10 [01:27<00:00, 0.11it/s, loss_step=65.60, val_loss_step=67.00, cindex_step=0.524, val_loss_epoch=66.40, cindex_epoch=0.519, loss_epoch=52.70]" |
458 | 459 | ] |
459 | 460 | }, |
460 | 461 | { |
461 | 462 | "name": "stderr", |
462 | 463 | "output_type": "stream", |
463 | 464 | "text": [ |
464 | | - "`Trainer.fit` stopped: `max_steps=3` reached.\n" |
| 465 | + "`Trainer.fit` stopped: `max_steps=10` reached.\n" |
465 | 466 | ] |
466 | 467 | }, |
467 | 468 | { |
468 | 469 | "name": "stdout", |
469 | 470 | "output_type": "stream", |
470 | 471 | "text": [ |
471 | | - "Epoch 0: 100%|██████████| 3/3 [00:15<00:00, 0.19it/s, loss_step=44.70, val_loss_step=74.10, cindex_step=0.614, val_loss_epoch=74.00, cindex_epoch=0.631, loss_epoch=23.30]\n" |
| 472 | + "Epoch 0: 100%|██████████| 10/10 [01:27<00:00, 0.11it/s, loss_step=65.60, val_loss_step=67.00, cindex_step=0.524, val_loss_epoch=66.40, cindex_epoch=0.519, loss_epoch=52.70]\n" |
472 | 473 | ] |
473 | 474 | } |
474 | 475 | ], |
|
497 | 498 | "name": "stdout", |
498 | 499 | "output_type": "stream", |
499 | 500 | "text": [ |
500 | | - "Testing DataLoader 0: 100%|██████████| 3/3 [00:04<00:00, 0.61it/s]\n", |
| 501 | + "Testing DataLoader 0: 100%|██████████| 10/10 [00:29<00:00, 0.34it/s]\n", |
501 | 502 | "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n", |
502 | 503 | " Test metric DataLoader 0\n", |
503 | 504 | "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n", |
504 | | - " cindex_epoch 0.609957754611969\n", |
505 | | - " val_loss_epoch 65.41207122802734\n", |
| 505 | + " cindex_epoch 0.521008312702179\n", |
| 506 | + " val_loss_epoch 66.2925796508789\n", |
506 | 507 | "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n" |
507 | 508 | ] |
508 | 509 | }, |
509 | 510 | { |
510 | 511 | "data": { |
511 | 512 | "text/plain": [ |
512 | | - "[{'val_loss_epoch': 65.41207122802734, 'cindex_epoch': 0.609957754611969}]" |
| 513 | + "[{'val_loss_epoch': 66.2925796508789, 'cindex_epoch': 0.521008312702179}]" |
513 | 514 | ] |
514 | 515 | }, |
515 | 516 | "execution_count": 18, |
|
580 | 581 | "name": "stdout", |
581 | 582 | "output_type": "stream", |
582 | 583 | "text": [ |
583 | | - "Cindex (regular) = 0.637220025062561\n", |
584 | | - "Cindex (momentum) = 0.6211331486701965\n", |
585 | | - "Compare (p-value) = 0.9429321885108948\n" |
| 584 | + "Cindex (regular) = 0.5925866365432739\n", |
| 585 | + "Cindex (momentum) = 0.5304314494132996\n", |
| 586 | + "Compare (p-value) = 0.9395550489425659\n" |
586 | 587 | ] |
587 | 588 | } |
588 | 589 | ], |
|
0 commit comments