Skip to content

Commit 55cac81

Browse files
authored
Merge pull request #84 from Novartis/83-update-notebook
fixed example
2 parents 6ed7d8c + 179b634 commit 55cac81

File tree

2 files changed: 25 additions and 23 deletions.

docs/notebooks/helpers_momentum.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -163,6 +163,7 @@ def train_dataloader(self):
163163
num_workers=self.num_workers,
164164
persistent_workers=True,
165165
shuffle=True,
166+
drop_last=True,
166167
)
167168

168169
def val_dataloader(self):

docs/notebooks/momentum.ipynb

Lines changed: 24 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@
4545
"source": [
4646
"import matplotlib.pyplot as plt\n",
4747
"import torch\n",
48+
"import copy\n",
4849
"import lightning as L\n",
4950
"from torchvision.models import resnet18\n",
5051
"from torchvision.transforms import v2\n",
@@ -121,13 +122,13 @@
121122
},
122123
{
123124
"cell_type": "code",
124-
"execution_count": 7,
125+
"execution_count": null,
125126
"id": "794004c5-588c-4590-ae96-c6d9e52109ff",
126127
"metadata": {},
127128
"outputs": [],
128129
"source": [
129130
"EPOCHS = 2 # number of epochs to train\n",
130-
"FAST_DEV_RUN = 5 # Quick prototype, comment line for full epochs training"
131+
"FAST_DEV_RUN = None # Quick prototype, comment line for full epochs training"
131132
]
132133
},
133134
{
@@ -278,7 +279,7 @@
278279
"outputs": [],
279280
"source": [
280281
"# Train first model (regular training) using our backbone\n",
281-
"model_regular = LitMNIST(backbone=resnet)"
282+
"model_regular = LitMNIST(backbone=copy.deepcopy(resnet))"
282283
]
283284
},
284285
{
@@ -295,7 +296,7 @@
295296
"TPU available: False, using: 0 TPU cores\n",
296297
"IPU available: False, using: 0 IPUs\n",
297298
"HPU available: False, using: 0 HPUs\n",
298-
"Running in `fast_dev_run` mode: will run the requested loop using 3 batch(es). Logging and checkpointing is suppressed.\n"
299+
"Running in `fast_dev_run` mode: will run the requested loop using 10 batch(es). Logging and checkpointing is suppressed.\n"
299300
]
300301
}
301302
],
@@ -337,21 +338,21 @@
337338
"name": "stdout",
338339
"output_type": "stream",
339340
"text": [
340-
"Epoch 0: 100%|██████████| 3/3 [00:50<00:00, 0.06it/s, loss_step=274.0, val_loss_step=260.0, cindex_step=0.615, val_loss_epoch=261.0, cindex_epoch=0.618, loss_epoch=300.0]"
341+
"Epoch 0: 100%|██████████| 10/10 [06:00<00:00, 0.03it/s, loss_step=242.0, val_loss_step=260.0, cindex_step=0.585, val_loss_epoch=260.0, cindex_epoch=0.579, loss_epoch=265.0]"
341342
]
342343
},
343344
{
344345
"name": "stderr",
345346
"output_type": "stream",
346347
"text": [
347-
"`Trainer.fit` stopped: `max_steps=3` reached.\n"
348+
"`Trainer.fit` stopped: `max_steps=10` reached.\n"
348349
]
349350
},
350351
{
351352
"name": "stdout",
352353
"output_type": "stream",
353354
"text": [
354-
"Epoch 0: 100%|██████████| 3/3 [00:50<00:00, 0.06it/s, loss_step=274.0, val_loss_step=260.0, cindex_step=0.615, val_loss_epoch=261.0, cindex_epoch=0.618, loss_epoch=300.0]\n"
355+
"Epoch 0: 100%|██████████| 10/10 [06:00<00:00, 0.03it/s, loss_step=242.0, val_loss_step=260.0, cindex_step=0.585, val_loss_epoch=260.0, cindex_epoch=0.579, loss_epoch=265.0]\n"
355356
]
356357
}
357358
],
@@ -370,19 +371,19 @@
370371
"name": "stdout",
371372
"output_type": "stream",
372373
"text": [
373-
"Testing DataLoader 0: 100%|██████████| 3/3 [00:20<00:00, 0.15it/s]\n",
374+
"Testing DataLoader 0: 100%|██████████| 10/10 [03:09<00:00, 0.05it/s]\n",
374375
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
375376
" Test metric DataLoader 0\n",
376377
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
377-
" cindex_epoch 0.612917959690094\n",
378-
" val_loss_epoch 34.167057037353516\n",
378+
" cindex_epoch 0.5841928124427795\n",
379+
" val_loss_epoch -90.17676544189453\n",
379380
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n"
380381
]
381382
},
382383
{
383384
"data": {
384385
"text/plain": [
385-
"[{'val_loss_epoch': 34.167057037353516, 'cindex_epoch': 0.612917959690094}]"
386+
"[{'val_loss_epoch': -90.17676544189453, 'cindex_epoch': 0.5841928124427795}]"
386387
]
387388
},
388389
"execution_count": 15,
@@ -415,7 +416,7 @@
415416
"outputs": [],
416417
"source": [
417418
"FACTOR = 10 # Number of batch to keep in memory. Increase our training batch size artificially by factor of 10 here\n",
418-
"resnet_momentum = Momentum(resnet, neg_partial_log_likelihood, steps=FACTOR, rate=0.999)\n",
419+
"resnet_momentum = Momentum(copy.deepcopy(resnet), neg_partial_log_likelihood, steps=FACTOR, rate=0.999)\n",
419420
"model_momentum = LitMomentum(backbone=resnet_momentum)\n",
420421
"\n",
421422
"# By using momentum, we can in theory reduce our batch size by factor and still have the same effective sample size\n",
@@ -438,7 +439,7 @@
438439
"TPU available: False, using: 0 TPU cores\n",
439440
"IPU available: False, using: 0 IPUs\n",
440441
"HPU available: False, using: 0 HPUs\n",
441-
"Running in `fast_dev_run` mode: will run the requested loop using 3 batch(es). Logging and checkpointing is suppressed.\n",
442+
"Running in `fast_dev_run` mode: will run the requested loop using 10 batch(es). Logging and checkpointing is suppressed.\n",
442443
"\n",
443444
" | Name | Type | Params\n",
444445
"-----------------------------------\n",
@@ -454,21 +455,21 @@
454455
"name": "stdout",
455456
"output_type": "stream",
456457
"text": [
457-
"Epoch 0: 100%|██████████| 3/3 [00:15<00:00, 0.19it/s, loss_step=44.70, val_loss_step=74.10, cindex_step=0.614, val_loss_epoch=74.00, cindex_epoch=0.631, loss_epoch=23.30]"
458+
"Epoch 0: 100%|██████████| 10/10 [01:27<00:00, 0.11it/s, loss_step=65.60, val_loss_step=67.00, cindex_step=0.524, val_loss_epoch=66.40, cindex_epoch=0.519, loss_epoch=52.70]"
458459
]
459460
},
460461
{
461462
"name": "stderr",
462463
"output_type": "stream",
463464
"text": [
464-
"`Trainer.fit` stopped: `max_steps=3` reached.\n"
465+
"`Trainer.fit` stopped: `max_steps=10` reached.\n"
465466
]
466467
},
467468
{
468469
"name": "stdout",
469470
"output_type": "stream",
470471
"text": [
471-
"Epoch 0: 100%|██████████| 3/3 [00:15<00:00, 0.19it/s, loss_step=44.70, val_loss_step=74.10, cindex_step=0.614, val_loss_epoch=74.00, cindex_epoch=0.631, loss_epoch=23.30]\n"
472+
"Epoch 0: 100%|██████████| 10/10 [01:27<00:00, 0.11it/s, loss_step=65.60, val_loss_step=67.00, cindex_step=0.524, val_loss_epoch=66.40, cindex_epoch=0.519, loss_epoch=52.70]\n"
472473
]
473474
}
474475
],
@@ -497,19 +498,19 @@
497498
"name": "stdout",
498499
"output_type": "stream",
499500
"text": [
500-
"Testing DataLoader 0: 100%|██████████| 3/3 [00:04<00:00, 0.61it/s]\n",
501+
"Testing DataLoader 0: 100%|██████████| 10/10 [00:29<00:00, 0.34it/s]\n",
501502
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
502503
" Test metric DataLoader 0\n",
503504
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
504-
" cindex_epoch 0.609957754611969\n",
505-
" val_loss_epoch 65.41207122802734\n",
505+
" cindex_epoch 0.521008312702179\n",
506+
" val_loss_epoch 66.2925796508789\n",
506507
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n"
507508
]
508509
},
509510
{
510511
"data": {
511512
"text/plain": [
512-
"[{'val_loss_epoch': 65.41207122802734, 'cindex_epoch': 0.609957754611969}]"
513+
"[{'val_loss_epoch': 66.2925796508789, 'cindex_epoch': 0.521008312702179}]"
513514
]
514515
},
515516
"execution_count": 18,
@@ -580,9 +581,9 @@
580581
"name": "stdout",
581582
"output_type": "stream",
582583
"text": [
583-
"Cindex (regular) = 0.637220025062561\n",
584-
"Cindex (momentum) = 0.6211331486701965\n",
585-
"Compare (p-value) = 0.9429321885108948\n"
584+
"Cindex (regular) = 0.5925866365432739\n",
585+
"Cindex (momentum) = 0.5304314494132996\n",
586+
"Compare (p-value) = 0.9395550489425659\n"
586587
]
587588
}
588589
],

0 commit comments

Comments (0)