
Commit 22c96d6

Added training and evaluation functions
1 parent 14ab49a · commit 22c96d6

2 files changed: +294 -12 lines changed


notebooks/unit4/lesson_31/Lesson_31_activity.ipynb

Lines changed: 135 additions & 12 deletions
@@ -232,17 +232,6 @@
 "    nn.MaxPool2d(2, 2),\n",
 "    nn.Dropout(0.5),\n",
 "    \n",
-"    # TODO: Add more convolutional blocks here\n",
-"    # Example (uncomment and modify):\n",
-"    # nn.Conv2d(32, 64, kernel_size=3, padding=1),\n",
-"    # nn.BatchNorm2d(64),\n",
-"    # nn.ReLU(),\n",
-"    # nn.Conv2d(64, 64, kernel_size=3, padding=1),\n",
-"    # nn.BatchNorm2d(64),\n",
-"    # nn.ReLU(),\n",
-"    # nn.MaxPool2d(2, 2),\n",
-"    # nn.Dropout(0.5),\n",
-"    \n",
 "    # Classifier\n",
 "    nn.Flatten(),\n",
 "    nn.Linear(32 * 16 * 16, 128),  # TODO: Update input size if you add more layers\n",
@@ -264,6 +253,101 @@
 "### 1.4. Train Modified Model"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "8e4102e5",
+"metadata": {},
+"outputs": [],
+"source": [
+"def train_model(\n",
+"    model: nn.Module,\n",
+"    train_loader: DataLoader,\n",
+"    val_loader: DataLoader,\n",
+"    criterion: nn.Module,\n",
+"    optimizer: optim.Optimizer,\n",
+"    epochs: int = 10,\n",
+"    print_every: int = 1\n",
+") -> dict[str, list[float]]:\n",
+"    '''Training loop for PyTorch classification model.\n",
+"    \n",
+"    Note: Assumes data is already on the correct device.\n",
+"    '''\n",
+"\n",
+"    history = {'train_loss': [], 'val_loss': [], 'train_accuracy': [], 'val_accuracy': []}\n",
+"\n",
+"    for epoch in range(epochs):\n",
+"\n",
+"        # Training phase\n",
+"        model.train()\n",
+"        running_loss = 0.0\n",
+"        correct = 0\n",
+"        total = 0\n",
+"\n",
+"        for images, labels in train_loader:\n",
+"\n",
+"            # Forward pass\n",
+"            optimizer.zero_grad()\n",
+"            outputs = model(images)\n",
+"            loss = criterion(outputs, labels)\n",
+"\n",
+"            # Backward pass\n",
+"            loss.backward()\n",
+"            optimizer.step()\n",
+"\n",
+"            # Track metrics\n",
+"            running_loss += loss.item()\n",
+"            _, predicted = torch.max(outputs.data, 1)\n",
+"            total += labels.size(0)\n",
+"            correct += (predicted == labels).sum().item()\n",
+"\n",
+"        # Calculate training metrics\n",
+"        train_loss = running_loss / len(train_loader)\n",
+"        train_accuracy = 100 * correct / total\n",
+"\n",
+"        # Validation phase\n",
+"        model.eval()\n",
+"        val_running_loss = 0.0\n",
+"        val_correct = 0\n",
+"        val_total = 0\n",
+"\n",
+"        with torch.no_grad():\n",
+"\n",
+"            for images, labels in val_loader:\n",
+"\n",
+"                outputs = model(images)\n",
+"                loss = criterion(outputs, labels)\n",
+"\n",
+"                val_running_loss += loss.item()\n",
+"                _, predicted = torch.max(outputs.data, 1)\n",
+"                val_total += labels.size(0)\n",
+"                val_correct += (predicted == labels).sum().item()\n",
+"\n",
+"        val_loss = val_running_loss / len(val_loader)\n",
+"        val_accuracy = 100 * val_correct / val_total\n",
+"\n",
+"        # Record metrics\n",
+"        history['train_loss'].append(train_loss)\n",
+"        history['val_loss'].append(val_loss)\n",
+"        history['train_accuracy'].append(train_accuracy)\n",
+"        history['val_accuracy'].append(val_accuracy)\n",
+"\n",
+"        # Print progress\n",
+"        if (epoch + 1) % print_every == 0 or epoch == 0:\n",
+"\n",
+"            print(\n",
+"                f'Epoch {epoch+1}/{epochs} - ' +\n",
+"                f'loss: {train_loss:.4f} - ' +\n",
+"                f'accuracy: {train_accuracy:.2f}% - ' +\n",
+"                f'val_loss: {val_loss:.4f} - ' +\n",
+"                f'val_accuracy: {val_accuracy:.2f}%'\n",
+"            )\n",
+"\n",
+"    print('\\nTraining complete.')\n",
+"\n",
+"    return history"
+]
+},
 {
 "cell_type": "code",
 "execution_count": null,
@@ -295,6 +379,45 @@
 "### 1.5. Evaluate and Visualize Results"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "08942db9",
+"metadata": {},
+"outputs": [],
+"source": [
+"def evaluate_model(\n",
+"    model: nn.Module,\n",
+"    test_loader: DataLoader\n",
+") -> tuple[float, np.ndarray, np.ndarray]:\n",
+"    '''Evaluate model on test set.\n",
+"    \n",
+"    Note: Assumes data is already on the correct device.\n",
+"    '''\n",
+"\n",
+"    model.eval()\n",
+"    correct = 0\n",
+"    total = 0\n",
+"    all_predictions = []\n",
+"    all_labels = []\n",
+"\n",
+"    with torch.no_grad():\n",
+"\n",
+"        for images, labels in test_loader:\n",
+"\n",
+"            outputs = model(images)\n",
+"            _, predicted = torch.max(outputs.data, 1)\n",
+"\n",
+"            total += labels.size(0)\n",
+"            correct += (predicted == labels).sum().item()\n",
+"\n",
+"            all_predictions.extend(predicted.cpu().numpy())\n",
+"            all_labels.extend(labels.cpu().numpy())\n",
+"\n",
+"    accuracy = 100 * correct / total\n",
+"    return accuracy, np.array(all_predictions), np.array(all_labels)"
+]
+},
 {
 "cell_type": "code",
 "execution_count": null,
@@ -495,7 +618,7 @@
 "# TODO: Update the first conv layer to accept 3 channels\n",
 "model_exp2 = nn.Sequential(\n",
 "    # Conv block: TODO: Change input channels from 1 to 3\n",
-"    nn.Conv2d(1, 32, kernel_size=3, padding=1),  # TODO: Change first argument to 3\n",
+"    nn.Conv2d(1, 32, kernel_size=3, padding=1),\n",
 "    nn.BatchNorm2d(32),\n",
 "    nn.ReLU(),\n",
 "    nn.Conv2d(32, 32, kernel_size=3, padding=1),\n",

notebooks/unit4/lesson_31/Lesson_31_demo.ipynb

Lines changed: 159 additions & 0 deletions
@@ -322,6 +322,101 @@
 "### 2.3. Train model"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "b5f22f21",
+"metadata": {},
+"outputs": [],
+"source": [
+"def train_model(\n",
+"    model: nn.Module,\n",
+"    train_loader: DataLoader,\n",
+"    val_loader: DataLoader,\n",
+"    criterion: nn.Module,\n",
+"    optimizer: optim.Optimizer,\n",
+"    epochs: int = 10,\n",
+"    print_every: int = 1\n",
+") -> dict[str, list[float]]:\n",
+"    '''Training loop for PyTorch classification model.\n",
+"    \n",
+"    Note: Assumes data is already on the correct device.\n",
+"    '''\n",
+"\n",
+"    history = {'train_loss': [], 'val_loss': [], 'train_accuracy': [], 'val_accuracy': []}\n",
+"\n",
+"    for epoch in range(epochs):\n",
+"\n",
+"        # Training phase\n",
+"        model.train()\n",
+"        running_loss = 0.0\n",
+"        correct = 0\n",
+"        total = 0\n",
+"\n",
+"        for images, labels in train_loader:\n",
+"\n",
+"            # Forward pass\n",
+"            optimizer.zero_grad()\n",
+"            outputs = model(images)\n",
+"            loss = criterion(outputs, labels)\n",
+"\n",
+"            # Backward pass\n",
+"            loss.backward()\n",
+"            optimizer.step()\n",
+"\n",
+"            # Track metrics\n",
+"            running_loss += loss.item()\n",
+"            _, predicted = torch.max(outputs.data, 1)\n",
+"            total += labels.size(0)\n",
+"            correct += (predicted == labels).sum().item()\n",
+"\n",
+"        # Calculate training metrics\n",
+"        train_loss = running_loss / len(train_loader)\n",
+"        train_accuracy = 100 * correct / total\n",
+"\n",
+"        # Validation phase\n",
+"        model.eval()\n",
+"        val_running_loss = 0.0\n",
+"        val_correct = 0\n",
+"        val_total = 0\n",
+"\n",
+"        with torch.no_grad():\n",
+"\n",
+"            for images, labels in val_loader:\n",
+"\n",
+"                outputs = model(images)\n",
+"                loss = criterion(outputs, labels)\n",
+"\n",
+"                val_running_loss += loss.item()\n",
+"                _, predicted = torch.max(outputs.data, 1)\n",
+"                val_total += labels.size(0)\n",
+"                val_correct += (predicted == labels).sum().item()\n",
+"\n",
+"        val_loss = val_running_loss / len(val_loader)\n",
+"        val_accuracy = 100 * val_correct / val_total\n",
+"\n",
+"        # Record metrics\n",
+"        history['train_loss'].append(train_loss)\n",
+"        history['val_loss'].append(val_loss)\n",
+"        history['train_accuracy'].append(train_accuracy)\n",
+"        history['val_accuracy'].append(val_accuracy)\n",
+"\n",
+"        # Print progress\n",
+"        if (epoch + 1) % print_every == 0 or epoch == 0:\n",
+"\n",
+"            print(\n",
+"                f'Epoch {epoch+1}/{epochs} - ' +\n",
+"                f'loss: {train_loss:.4f} - ' +\n",
+"                f'accuracy: {train_accuracy:.2f}% - ' +\n",
+"                f'val_loss: {val_loss:.4f} - ' +\n",
+"                f'val_accuracy: {val_accuracy:.2f}%'\n",
+"            )\n",
+"\n",
+"    print('\\nTraining complete.')\n",
+"\n",
+"    return history"
+]
+},
 {
 "cell_type": "code",
 "execution_count": null,
@@ -387,6 +482,45 @@
 "### 3.1. Calculate test accuracy"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "59bdd8c6",
+"metadata": {},
+"outputs": [],
+"source": [
+"def evaluate_model(\n",
+"    model: nn.Module,\n",
+"    test_loader: DataLoader\n",
+") -> tuple[float, np.ndarray, np.ndarray]:\n",
+"    '''Evaluate model on test set.\n",
+"    \n",
+"    Note: Assumes data is already on the correct device.\n",
+"    '''\n",
+"\n",
+"    model.eval()\n",
+"    correct = 0\n",
+"    total = 0\n",
+"    all_predictions = []\n",
+"    all_labels = []\n",
+"\n",
+"    with torch.no_grad():\n",
+"\n",
+"        for images, labels in test_loader:\n",
+"\n",
+"            outputs = model(images)\n",
+"            _, predicted = torch.max(outputs.data, 1)\n",
+"\n",
+"            total += labels.size(0)\n",
+"            correct += (predicted == labels).sum().item()\n",
+"\n",
+"            all_predictions.extend(predicted.cpu().numpy())\n",
+"            all_labels.extend(labels.cpu().numpy())\n",
+"\n",
+"    accuracy = 100 * correct / total\n",
+"    return accuracy, np.array(all_predictions), np.array(all_labels)"
+]
+},
 {
 "cell_type": "code",
 "execution_count": null,
@@ -581,6 +715,31 @@
 "plt.tight_layout()\n",
 "plt.show()"
 ]
+},
+{
+"cell_type": "markdown",
+"id": "e7652023",
+"metadata": {},
+"source": [
+"## 4. Save model for inference"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "6743f566",
+"metadata": {},
+"outputs": [],
+"source": [
+"# Define model save path\n",
+"model_dir = Path('../models')\n",
+"model_dir.mkdir(parents=True, exist_ok=True)\n",
+"model_path = model_dir / 'cifar10_cnn_model.pth'\n",
+"\n",
+"# Save the model state dict (recommended for inference)\n",
+"torch.save(model.state_dict(), model_path)\n",
+"print(f'Model saved to: {model_path.resolve()}')"
+]
 }
 ],
 "metadata": {
