|
232 | 232 | " nn.MaxPool2d(2, 2),\n", |
233 | 233 | " nn.Dropout(0.5),\n", |
234 | 234 | " \n", |
235 | | - " # TODO: Add more convolutional blocks here\n", |
236 | | - " # Example (uncomment and modify):\n", |
237 | | - " # nn.Conv2d(32, 64, kernel_size=3, padding=1),\n", |
238 | | - " # nn.BatchNorm2d(64),\n", |
239 | | - " # nn.ReLU(),\n", |
240 | | - " # nn.Conv2d(64, 64, kernel_size=3, padding=1),\n", |
241 | | - " # nn.BatchNorm2d(64),\n", |
242 | | - " # nn.ReLU(),\n", |
243 | | - " # nn.MaxPool2d(2, 2),\n", |
244 | | - " # nn.Dropout(0.5),\n", |
245 | | - " \n", |
246 | 235 | " # Classifier\n", |
247 | 236 | " nn.Flatten(),\n", |
248 | 237 | " nn.Linear(32 * 16 * 16, 128), # TODO: Update input size if you add more layers\n", |
|
264 | 253 | "### 1.4. Train Modified Model" |
265 | 254 | ] |
266 | 255 | }, |
| 256 | + { |
| 257 | + "cell_type": "code", |
| 258 | + "execution_count": null, |
| 259 | + "id": "8e4102e5", |
| 260 | + "metadata": {}, |
| 261 | + "outputs": [], |
| 262 | + "source": [ |
def _run_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    optimizer: 'optim.Optimizer | None' = None
) -> tuple[float, float]:
    '''Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    When ``optimizer`` is given the model is put in training mode and its
    weights are updated after every batch; otherwise the model is put in
    evaluation mode and gradients are disabled.

    The loss is the mean of the per-batch ``criterion`` values and the
    accuracy is a percentage in [0, 100]. Both are ``nan`` when the loader
    yields no samples.
    '''

    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    batches = 0
    correct = 0
    total = 0

    with torch.set_grad_enabled(training):

        for images, labels in loader:

            if training:
                optimizer.zero_grad()

            outputs = model(images)
            loss = criterion(outputs, labels)

            if training:
                loss.backward()
                optimizer.step()

            # Track metrics. Batches are counted here instead of using
            # len(loader) so loaders without __len__ also work.
            loss_sum += loss.item()
            batches += 1
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # An empty loader yields nan metrics instead of a ZeroDivisionError.
    mean_loss = loss_sum / batches if batches else float('nan')
    accuracy = 100 * correct / total if total else float('nan')

    return mean_loss, accuracy


def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    epochs: int = 10,
    print_every: int = 1
) -> dict[str, list[float]]:
    '''Training loop for PyTorch classification model.

    Each epoch runs one optimisation pass over ``train_loader`` followed by
    one gradient-free pass over ``val_loader``.

    Note: Assumes data is already on the correct device.

    Args:
        model: Network producing one score per class along dimension 1.
        train_loader: Yields ``(images, labels)`` batches used for training.
        val_loader: Yields ``(images, labels)`` batches used for validation.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer that updates the parameters of ``model``.
        epochs: Number of passes over ``train_loader``.
        print_every: Print progress every this many epochs (the first epoch
            is always printed). Values below 1 disable progress printing.

    Returns:
        Dict with keys ``'train_loss'``, ``'val_loss'``, ``'train_accuracy'``
        and ``'val_accuracy'``, each holding one value per epoch. Losses are
        the mean of the per-batch losses; accuracies are percentages. A
        metric is ``nan`` for an epoch whose loader yielded no samples.
    '''

    history = {'train_loss': [], 'val_loss': [], 'train_accuracy': [], 'val_accuracy': []}

    for epoch in range(epochs):

        # Training phase, then validation phase
        train_loss, train_accuracy = _run_epoch(model, train_loader, criterion, optimizer)
        val_loss, val_accuracy = _run_epoch(model, val_loader, criterion)

        # Record metrics
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_accuracy'].append(train_accuracy)
        history['val_accuracy'].append(val_accuracy)

        # Print progress; the print_every guard avoids a modulo by zero
        if print_every >= 1 and ((epoch + 1) % print_every == 0 or epoch == 0):

            print(
                f'Epoch {epoch+1}/{epochs} - '
                f'loss: {train_loss:.4f} - '
                f'accuracy: {train_accuracy:.2f}% - '
                f'val_loss: {val_loss:.4f} - '
                f'val_accuracy: {val_accuracy:.2f}%'
            )

    print('\nTraining complete.')

    return history
| 349 | + ] |
| 350 | + }, |
267 | 351 | { |
268 | 352 | "cell_type": "code", |
269 | 353 | "execution_count": null, |
|
295 | 379 | "### 1.5. Evaluate and Visualize Results" |
296 | 380 | ] |
297 | 381 | }, |
| 382 | + { |
| 383 | + "cell_type": "code", |
| 384 | + "execution_count": null, |
| 385 | + "id": "08942db9", |
| 386 | + "metadata": {}, |
| 387 | + "outputs": [], |
| 388 | + "source": [ |
def evaluate_model(
    model: nn.Module,
    test_loader: DataLoader
) -> tuple[float, np.ndarray, np.ndarray]:
    '''Evaluate model on test set.

    Note: Assumes data is already on the correct device.

    Args:
        model: Network producing one score per class along dimension 1.
            It is switched to evaluation mode and left in that mode.
        test_loader: Yields ``(images, labels)`` batches.

    Returns:
        Tuple ``(accuracy, predictions, labels)`` where ``accuracy`` is a
        percentage and the two arrays hold the predicted and true class for
        every sample, in loader order. If the loader yields no samples the
        accuracy is ``nan`` and both arrays are empty.
    '''

    model.eval()
    correct = 0
    total = 0
    prediction_batches = []
    label_batches = []

    with torch.no_grad():

        for images, labels in test_loader:

            predicted = model(images).argmax(dim=1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Keep whole batches on the CPU; they are concatenated once below
            prediction_batches.append(predicted.cpu())
            label_batches.append(labels.cpu())

    # No samples: report nan rather than dividing by zero
    if total == 0:
        return float('nan'), np.array([], dtype=np.int64), np.array([], dtype=np.int64)

    accuracy = 100 * correct / total
    all_predictions = torch.cat(prediction_batches).numpy()
    all_labels = torch.cat(label_batches).numpy()

    return accuracy, all_predictions, all_labels
| 419 | + ] |
| 420 | + }, |
298 | 421 | { |
299 | 422 | "cell_type": "code", |
300 | 423 | "execution_count": null, |
|
495 | 618 | "# TODO: Update the first conv layer to accept 3 channels\n", |
496 | 619 | "model_exp2 = nn.Sequential(\n", |
497 | 620 | " # Conv block: TODO: Change input channels from 1 to 3\n", |
498 | | - " nn.Conv2d(1, 32, kernel_size=3, padding=1), # TODO: Change first argument to 3\n", |
| 621 | + " nn.Conv2d(1, 32, kernel_size=3, padding=1),\n", |
499 | 622 | " nn.BatchNorm2d(32),\n", |
500 | 623 | " nn.ReLU(),\n", |
501 | 624 | " nn.Conv2d(32, 32, kernel_size=3, padding=1),\n", |
|
0 commit comments