Skip to content

Commit c536680

Browse files
authored
Merge pull request #205 from gperdrizet/dev
Updated training function to work for lazy loading or preloading of training data
2 parents 8608017 + 882f4ad commit c536680

File tree

1 file changed

+102
-46
lines changed

1 file changed

+102
-46
lines changed

notebooks/unit4/lesson_31/Lesson_31_activity.ipynb

Lines changed: 102 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@
8080
"metadata": {},
8181
"outputs": [],
8282
"source": [
83-
"batch_size = 10000 # Training images come in 5 batches of 10,000\n",
83+
"batch_size = 1000 # Training images come in 5 batches of 10,000\n",
8484
"learning_rate = 1e-3\n",
8585
"epochs = 30\n",
8686
"print_every = 5 # Print training progress every n epochs\n",
@@ -218,11 +218,11 @@
218218
"metadata": {},
219219
"outputs": [],
220220
"source": [
221-
"# TODO: Modify this architecture\n",
222221
"num_classes = 10\n",
223222
"\n",
224223
"model_exp1 = nn.Sequential(\n",
225-
" # Conv block: 1 -> 32 channels, 32 x 32 -> 16 x 16\n",
224+
"\n",
225+
" # Conv block: grayscale input\n",
226226
" nn.Conv2d(1, 32, kernel_size=3, padding=1),\n",
227227
" nn.BatchNorm2d(32),\n",
228228
" nn.ReLU(),\n",
@@ -234,10 +234,11 @@
234234
" \n",
235235
" # Classifier\n",
236236
" nn.Flatten(),\n",
237-
" nn.Linear(32 * 16 * 16, 128), # TODO: Update input size if you add more layers\n",
237+
" nn.Linear(32 * 16 * 16, 128),\n",
238238
" nn.ReLU(),\n",
239239
" nn.Dropout(0.5),\n",
240240
" nn.Linear(128, num_classes)\n",
241+
"\n",
241242
").to(device)\n",
242243
"\n",
243244
"trainable_params = sum(p.numel() for p in model_exp1.parameters() if p.requires_grad)\n",
@@ -267,11 +268,14 @@
267268
" criterion: nn.Module,\n",
268269
" optimizer: optim.Optimizer,\n",
269270
" epochs: int = 10,\n",
270-
" print_every: int = 1\n",
271+
" print_every: int = 1,\n",
272+
" device: torch.device = None\n",
271273
") -> dict[str, list[float]]:\n",
272274
" '''Training loop for PyTorch classification model.\n",
273275
" \n",
274-
" Note: Assumes data is already on the correct device.\n",
276+
" Args:\n",
277+
" device: If provided, moves batches to this device on-the-fly.\n",
278+
" If None, assumes data is already on the correct device.\n",
275279
" '''\n",
276280
"\n",
277281
" history = {'train_loss': [], 'val_loss': [], 'train_accuracy': [], 'val_accuracy': []}\n",
@@ -285,6 +289,10 @@
285289
" total = 0\n",
286290
"\n",
287291
" for images, labels in train_loader:\n",
292+
" \n",
293+
" # Move batch to device if specified\n",
294+
" if device is not None:\n",
295+
" images, labels = images.to(device), labels.to(device)\n",
288296
"\n",
289297
" # Forward pass\n",
290298
" optimizer.zero_grad()\n",
@@ -314,6 +322,10 @@
314322
" with torch.no_grad():\n",
315323
"\n",
316324
" for images, labels in val_loader:\n",
325+
" \n",
326+
" # Move batch to device if specified\n",
327+
" if device is not None:\n",
328+
" images, labels = images.to(device), labels.to(device)\n",
317329
"\n",
318330
" outputs = model(images)\n",
319331
" loss = criterion(outputs, labels)\n",
@@ -388,11 +400,14 @@
388400
"source": [
389401
"def evaluate_model(\n",
390402
" model: nn.Module,\n",
391-
" test_loader: DataLoader\n",
403+
" test_loader: DataLoader,\n",
404+
" device: torch.device = None\n",
392405
") -> tuple[float, np.ndarray, np.ndarray]:\n",
393406
" '''Evaluate model on test set.\n",
394407
" \n",
395-
" Note: Assumes data is already on the correct device.\n",
408+
" Args:\n",
409+
" device: If provided, moves batches to this device on-the-fly.\n",
410+
" If None, assumes data is already on the correct device.\n",
396411
" '''\n",
397412
"\n",
398413
" model.eval()\n",
@@ -404,6 +419,10 @@
404419
" with torch.no_grad():\n",
405420
"\n",
406421
" for images, labels in test_loader:\n",
422+
" \n",
423+
" # Move batch to device if specified\n",
424+
" if device is not None:\n",
425+
" images, labels = images.to(device), labels.to(device)\n",
407426
"\n",
408427
" outputs = model(images)\n",
409428
" _, predicted = torch.max(outputs.data, 1)\n",
@@ -495,10 +514,8 @@
495514
"source": [
496515
"# TODO: Modify this transform to use RGB instead of grayscale\n",
497516
"transform_exp2 = transforms.Compose([\n",
498-
" # TODO: Remove Grayscale transform\n",
499-
" transforms.Grayscale(num_output_channels=1), # Remove this line\n",
500517
" transforms.ToTensor(),\n",
501-
" transforms.Normalize((0.5,), (0.5,)) # TODO: Update normalization for 3 channels\n",
518+
" transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
502519
"])\n",
503520
"\n",
504521
"# Load training and test datasets with RGB\n",
@@ -617,7 +634,8 @@
617634
"source": [
618635
"# TODO: Update the first conv layer to accept 3 channels\n",
619636
"model_exp2 = nn.Sequential(\n",
620-
" # Conv block: TODO: Change input channels from 1 to 3\n",
637+
"\n",
638+
" # Conv block: RGB input\n",
621639
" nn.Conv2d(1, 32, kernel_size=3, padding=1),\n",
622640
" nn.BatchNorm2d(32),\n",
623641
" nn.ReLU(),\n",
@@ -633,6 +651,7 @@
633651
" nn.ReLU(),\n",
634652
" nn.Dropout(0.5),\n",
635653
" nn.Linear(128, num_classes)\n",
654+
"\n",
636655
").to(device)\n",
637656
"\n",
638657
"trainable_params = sum(p.numel() for p in model_exp2.parameters() if p.requires_grad)\n",
@@ -761,7 +780,7 @@
761780
"transform_train_exp3 = transforms.Compose([\n",
762781
" # TODO: Add augmentation transforms here (before ToTensor) \n",
763782
" transforms.ToTensor(),\n",
764-
" transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # Using RGB\n",
783+
" transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
765784
"])\n",
766785
"\n",
767786
"# Validation and test transforms (no augmentation)\n",
@@ -833,9 +852,7 @@
833852
"id": "ba0683e5",
834853
"metadata": {},
835854
"source": [
836-
"### 3.3. Create Data Loaders with Augmentation\n",
837-
"\n",
838-
"**Note**: For augmented data, we cannot preload to GPU because each epoch needs different augmentations. We'll use regular DataLoaders."
855+
"### 3.3. Create Data Loaders with Augmentation"
839856
]
840857
},
841858
{
@@ -845,32 +862,53 @@
845862
"metadata": {},
846863
"outputs": [],
847864
"source": [
848-
"# Create DataLoaders (cannot preload augmented data to GPU)\n",
865+
"# For data augmentation, we must NOT preload data to GPU as tensors.\n",
866+
"# Transforms need to be applied on-the-fly during each epoch so each \n",
867+
"# batch sees different augmented versions of the images.\n",
868+
"\n",
869+
"# Split training data into train and validation sets using Subset\n",
870+
"n_train = int(0.8 * len(train_dataset_exp3))\n",
871+
"n_val = len(train_dataset_exp3) - n_train\n",
872+
"indices = torch.randperm(len(train_dataset_exp3)).tolist()\n",
873+
"\n",
874+
"train_subset_exp3 = torch.utils.data.Subset(train_dataset_exp3, indices[:n_train])\n",
875+
"val_subset_exp3 = torch.utils.data.Subset(train_dataset_exp3, indices[n_train:])\n",
876+
"\n",
877+
"print(f'Training samples: {len(train_subset_exp3)}')\n",
878+
"print(f'Validation samples: {len(val_subset_exp3)}')\n",
879+
"print(f'Test samples: {len(test_dataset_exp3)}')"
880+
]
881+
},
882+
{
883+
"cell_type": "code",
884+
"execution_count": null,
885+
"id": "70daf63e",
886+
"metadata": {},
887+
"outputs": [],
888+
"source": [
889+
"# Create DataLoaders directly from Dataset/Subset objects\n",
890+
"# Transforms are applied on-the-fly when batches are loaded\n",
849891
"train_loader_exp3 = DataLoader(\n",
850-
" train_dataset_exp3,\n",
892+
" train_subset_exp3,\n",
851893
" batch_size=batch_size,\n",
852894
" shuffle=True\n",
853895
")\n",
854896
"\n",
855-
"# For validation/test, we can use the same approach as experiment 2\n",
856-
"X_test_exp3 = torch.stack([img for img, _ in test_dataset_exp3]).to(device)\n",
857-
"y_test_exp3 = torch.tensor([label for _, label in test_dataset_exp3]).to(device)\n",
858-
"\n",
859-
"# Create validation split from training data\n",
860-
"n_val = int(0.2 * len(train_dataset_exp3))\n",
861-
"n_train = len(train_dataset_exp3) - n_val\n",
862-
"\n",
863-
"train_subset_exp3, val_subset_exp3 = torch.utils.data.random_split(\n",
864-
" train_dataset_exp3,\n",
865-
" [n_train, n_val]\n",
897+
"val_loader_exp3 = DataLoader(\n",
898+
" val_subset_exp3,\n",
899+
" batch_size=batch_size,\n",
900+
" shuffle=False\n",
866901
")\n",
867902
"\n",
868-
"val_loader_exp3 = DataLoader(val_subset_exp3, batch_size=batch_size, shuffle=False)\n",
869-
"test_tensor_dataset_exp3 = torch.utils.data.TensorDataset(X_test_exp3, y_test_exp3)\n",
870-
"test_loader_exp3 = DataLoader(test_tensor_dataset_exp3, batch_size=batch_size, shuffle=False)\n",
903+
"test_loader_exp3 = DataLoader(\n",
904+
" test_dataset_exp3,\n",
905+
" batch_size=batch_size,\n",
906+
" shuffle=False\n",
907+
")\n",
871908
"\n",
872909
"print(f'Training batches: {len(train_loader_exp3)}')\n",
873-
"print(f'Validation batches: {len(val_loader_exp3)}')"
910+
"print(f'Validation batches: {len(val_loader_exp3)}')\n",
911+
"print(f'Test batches: {len(test_loader_exp3)}')"
874912
]
875913
},
876914
{
@@ -890,6 +928,7 @@
890928
"source": [
891929
"# Same architecture as Experiment 2 (RGB)\n",
892930
"model_exp3 = nn.Sequential(\n",
931+
"\n",
893932
" # Conv block: RGB input\n",
894933
" nn.Conv2d(3, 32, kernel_size=3, padding=1),\n",
895934
" nn.BatchNorm2d(32),\n",
@@ -906,6 +945,7 @@
906945
" nn.ReLU(),\n",
907946
" nn.Dropout(0.5),\n",
908947
" nn.Linear(128, num_classes)\n",
948+
"\n",
909949
").to(device)\n",
910950
"\n",
911951
"trainable_params = sum(p.numel() for p in model_exp3.parameters() if p.requires_grad)\n",
@@ -932,25 +972,19 @@
932972
"criterion_exp3 = nn.CrossEntropyLoss()\n",
933973
"optimizer_exp3 = optim.Adam(model_exp3.parameters(), lr=learning_rate)\n",
934974
"\n",
975+
"# Pass device to move batches on-the-fly (required for on-the-fly augmentation)\n",
935976
"history_exp3 = train_model(\n",
936977
" model=model_exp3,\n",
937978
" train_loader=train_loader_exp3,\n",
938979
" val_loader=val_loader_exp3,\n",
939980
" criterion=criterion_exp3,\n",
940981
" optimizer=optimizer_exp3,\n",
941982
" epochs=epochs,\n",
942-
" print_every=print_every\n",
983+
" print_every=print_every,\n",
984+
" device=device\n",
943985
")"
944986
]
945987
},
946-
{
947-
"cell_type": "markdown",
948-
"id": "0faf27d3",
949-
"metadata": {},
950-
"source": [
951-
"### 3.5. Train Model with Augmented Data"
952-
]
953-
},
954988
{
955989
"cell_type": "markdown",
956990
"id": "85ad8659",
@@ -986,8 +1020,10 @@
9861020
"plt.tight_layout()\n",
9871021
"plt.show()\n",
9881022
"\n",
989-
"# Test accuracy\n",
990-
"test_accuracy_exp3, predictions_exp3, true_labels_exp3 = evaluate_model(model_exp3, test_loader_exp3)\n",
1023+
"# Test accuracy (pass device for on-the-fly batch loading)\n",
1024+
"test_accuracy_exp3, predictions_exp3, true_labels_exp3 = evaluate_model(\n",
1025+
" model_exp3, test_loader_exp3, device=device\n",
1026+
")\n",
9911027
"print(f'\\nExperiment 3 Test Accuracy: {test_accuracy_exp3:.2f}%')"
9921028
]
9931029
},
@@ -1024,7 +1060,7 @@
10241060
"\n",
10251061
"| Experiment | Description | Test Accuracy | Notes |\n",
10261062
"|------------|-------------|---------------|-------|\n",
1027-
"| Baseline (demo) | Grayscale, simple architecture | ~45% | From demo notebook |\n",
1063+
"| Baseline (demo) | Grayscale, simple architecture | ~60% | From demo notebook |\n",
10281064
"| Experiment 1 | Modified architecture | _% | |\n",
10291065
"| Experiment 2 | RGB images | _% | |\n",
10301066
"| Experiment 3 | Image augmentation | _% | |\n",
@@ -1044,11 +1080,31 @@
10441080
"*Your reflections here:*\n",
10451081
"\n"
10461082
]
1083+
},
1084+
{
1085+
"cell_type": "markdown",
1086+
"id": "b8155ac6",
1087+
"metadata": {},
1088+
"source": []
10471089
}
10481090
],
10491091
"metadata": {
1092+
"kernelspec": {
1093+
"display_name": "Python 3",
1094+
"language": "python",
1095+
"name": "python3"
1096+
},
10501097
"language_info": {
1051-
"name": "python"
1098+
"codemirror_mode": {
1099+
"name": "ipython",
1100+
"version": 3
1101+
},
1102+
"file_extension": ".py",
1103+
"mimetype": "text/x-python",
1104+
"name": "python",
1105+
"nbconvert_exporter": "python",
1106+
"pygments_lexer": "ipython3",
1107+
"version": "3.10.12"
10521108
}
10531109
},
10541110
"nbformat": 4,

0 commit comments

Comments (0)