|
74 | 74 | ], |
75 | 75 | "source": [ |
76 | 76 | "import time\n", |
| 77 | + "\n", |
77 | 78 | "start_total = time.time()\n", |
78 | 79 | "\n", |
79 | 80 | "print(\"Importing required libraries...\")\n", |
80 | 81 | "\n", |
81 | 82 | "# Core modules\n", |
82 | | - "import os, sys, platform, random\n", |
83 | | - "import zipfile\n", |
| 83 | + "# Data and ML\n", |
| 84 | + "import csv\n", |
| 85 | + "import platform\n", |
| 86 | + "import sys\n", |
84 | 87 | "import warnings\n", |
| 88 | + "import zipfile\n", |
85 | 89 | "from pathlib import Path\n", |
86 | 90 | "\n", |
87 | | - "# Data and ML\n", |
88 | | - "import csv\n", |
89 | | - "import numpy as np\n", |
90 | | - "import torch\n", |
91 | 91 | "import lightning as L\n", |
92 | | - "import rasterio as rio\n", |
93 | 92 | "import matplotlib.pyplot as plt\n", |
| 93 | + "import numpy as np\n", |
| 94 | + "import rasterio as rio\n", |
| 95 | + "import torch\n", |
94 | 96 | "\n", |
95 | 97 | "# Remove warnings from not georeferenced dataset (for this example only)\n", |
96 | 98 | "from rasterio.errors import NotGeoreferencedWarning\n", |
| 99 | + "\n", |
97 | 100 | "warnings.filterwarnings(\"ignore\", category=NotGeoreferencedWarning)\n", |
98 | 101 | "\n", |
99 | 102 | "# Append root path to make module work from notebook (might differ in your environment)\n", |
|
151 | 154 | "source": [ |
152 | 155 | "# Define path to the archive and extract location\n", |
153 | 156 | "zip_path = Path(\"../data/waterloo_subset_512.zip\")\n", |
154 | | - "extract_dir = zip_path.with_suffix('') # removes .zip\n", |
| 157 | + "extract_dir = zip_path.with_suffix(\"\") # removes .zip\n", |
155 | 158 | "\n", |
156 | 159 | "# Unzip only if not already done\n", |
157 | 160 | "if not extract_dir.exists():\n", |
|
212 | 215 | "metadata": {}, |
213 | 216 | "outputs": [ |
214 | 217 | { |
215 | | - "name": "stderr", |
| 218 | + "name": "stdout", |
216 | 219 | "output_type": "stream", |
217 | 220 | "text": [ |
218 | 221 | "Remapping trn labels: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 339.72it/s]\n", |
|
222 | 225 | } |
223 | 226 | ], |
224 | 227 | "source": [ |
225 | | - "from tqdm import tqdm\n", |
226 | 228 | "import numpy as np\n", |
227 | 229 | "import rasterio as rio\n", |
| 230 | + "from tqdm import tqdm\n", |
228 | 231 | "\n", |
229 | 232 | "# Remap all labels in the dataset\n", |
230 | 233 | "for split in [\"trn\", \"val\", \"tst\"]:\n", |
231 | 234 | " lbl_dir = extract_dir / split / \"label\"\n", |
232 | 235 | " if not lbl_dir.exists():\n", |
233 | 236 | " continue\n", |
234 | 237 | "\n", |
235 | | - " for lbl_path in tqdm(sorted(lbl_dir.glob(\"*.tif\")), desc=f\"Remapping {split} labels\"):\n", |
| 238 | + " for lbl_path in tqdm(\n", |
| 239 | + " sorted(lbl_dir.glob(\"*.tif\")), desc=f\"Remapping {split} labels\",\n", |
| 240 | + " ):\n", |
236 | 241 | " with rio.open(lbl_path) as lbl_ds:\n", |
237 | 242 | " lbl = lbl_ds.read(1)\n", |
238 | 243 | "\n", |
|
284 | 289 | " lbl_dir = extract_dir / split / \"label\"\n", |
285 | 290 | " csv_path = extract_dir / f\"{split}.csv\"\n", |
286 | 291 | "\n", |
287 | | - " # Collect matching image–label pairs\n", |
| 292 | + " # Collect matching image-label pairs\n", |
288 | 293 | " rows = []\n", |
289 | 294 | " for img_path in sorted(img_dir.glob(\"*.tif\")):\n", |
290 | 295 | " lbl_path = lbl_dir / img_path.name\n", |
|
294 | 299 | " print(f\"No matching label found for {img_path.name}\")\n", |
295 | 300 | "\n", |
296 | 301 | " # Write CSV\n", |
297 | | - " with open(csv_path, \"w\", newline=\"\") as f:\n", |
| 302 | + " with csv_path.open(\"w\", newline=\"\") as f:\n", |
298 | 303 | " writer = csv.writer(f, delimiter=\";\")\n", |
299 | 304 | " writer.writerows(rows)\n", |
300 | 305 | "\n", |
|
336 | 341 | "from geo_deep_learning.datasets.csv_dataset import CSVDataset\n", |
337 | 342 | "\n", |
338 | 343 | "# Define dataset paths previously extracted from the ZIP\n", |
339 | | - "dataset_root = extract_dir \n", |
| 344 | + "dataset_root = extract_dir\n", |
340 | 345 | "\n", |
341 | 346 | "# Change mask dtype to match SoftCrossEntropyLoss\n", |
342 | 347 | "def _load_mask_int64(self, index: int):\n", |
|
467 | 472 | "metadata": {}, |
468 | 473 | "outputs": [ |
469 | 474 | { |
470 | | - "name": "stderr", |
| 475 | + "name": "stdout", |
471 | 476 | "output_type": "stream", |
472 | 477 | "text": [ |
473 | 478 | "/home/lromanin/miniforge3/envs/gdl_env_v09/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.\n" |
474 | 479 | ] |
475 | 480 | } |
476 | 481 | ], |
477 | 482 | "source": [ |
478 | | - "from geo_deep_learning.tasks_with_models.segmentation_unetplus import SegmentationUnetPlus\n", |
479 | | - "from segmentation_models_pytorch.losses import SoftCrossEntropyLoss\n", |
480 | 483 | "import torch\n", |
| 484 | + "from segmentation_models_pytorch.losses import SoftCrossEntropyLoss\n", |
481 | 485 | "from torch.optim import Adam\n", |
482 | 486 | "from torch.optim.lr_scheduler import ReduceLROnPlateau\n", |
483 | 487 | "\n", |
| 488 | + "from geo_deep_learning.tasks_with_models.segmentation_unetplus import (\n", |
| 489 | + " SegmentationUnetPlus,\n", |
| 490 | + ")\n", |
| 491 | + "\n", |
484 | 492 | "# Loss function instance (multi-class, 2 classes: background + buildings)\n", |
485 | | - "loss_fn = SoftCrossEntropyLoss(smooth_factor=0.1) \n", |
| 493 | + "loss_fn = SoftCrossEntropyLoss(smooth_factor=0.1)\n", |
486 | 494 | "\n", |
487 | 495 | "# Optimizer and scheduler configs\n", |
488 | 496 | "optimizer_class = Adam\n", |
|
511 | 519 | " scheduler_config={\n", |
512 | 520 | " \"interval\": \"epoch\",\n", |
513 | 521 | " \"frequency\": 1,\n", |
514 | | - " \"monitor\": \"val_loss\"\n", |
| 522 | + " \"monitor\": \"val_loss\",\n", |
515 | 523 | " },\n", |
516 | 524 | " class_labels=[\"background\", \"buildings\"],\n", |
517 | 525 | " class_colors=[\"#000000\", \"#FF0000\"],\n", |
|
543 | 551 | "metadata": {}, |
544 | 552 | "outputs": [ |
545 | 553 | { |
546 | | - "name": "stderr", |
| 554 | + "name": "stdout", |
547 | 555 | "output_type": "stream", |
548 | 556 | "text": [ |
549 | 557 | "GPU available: True (cuda), used: True\n", |
|
583 | 591 | "output_type": "display_data" |
584 | 592 | }, |
585 | 593 | { |
586 | | - "name": "stderr", |
| 594 | + "name": "stdout", |
587 | 595 | "output_type": "stream", |
588 | 596 | "text": [ |
589 | 597 | "/home/lromanin/miniforge3/envs/gdl_env_v09/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.\n", |
|
746 | 754 | "output_type": "display_data" |
747 | 755 | }, |
748 | 756 | { |
749 | | - "name": "stderr", |
| 757 | + "name": "stdout", |
750 | 758 | "output_type": "stream", |
751 | 759 | "text": [ |
752 | 760 | "`Trainer.fit` stopped: `max_epochs=10` reached.\n" |
|
756 | 764 | "source": [ |
757 | 765 | "import pandas as pd\n", |
758 | 766 | "from lightning.pytorch import Trainer\n", |
759 | | - "from lightning.pytorch.loggers import MLFlowLogger\n", |
760 | 767 | "from lightning.pytorch.callbacks import TQDMProgressBar\n", |
| 768 | + "from lightning.pytorch.loggers import MLFlowLogger\n", |
761 | 769 | "\n", |
762 | 770 | "logger = MLFlowLogger(\n", |
763 | 771 | " experiment_name=\"unet_segmentation\",\n", |
764 | | - " tracking_uri=\"file:./mlruns\"\n", |
| 772 | + " tracking_uri=\"file:./mlruns\",\n", |
765 | 773 | ")\n", |
766 | 774 | "\n", |
767 | 775 | "# Define trainer\n", |
|
804 | 812 | "<Axes: >" |
805 | 813 | ] |
806 | 814 | }, |
807 | | - "execution_count": 9, |
| 815 | + "execution_count": null, |
808 | 816 | "metadata": {}, |
809 | 817 | "output_type": "execute_result" |
810 | 818 | }, |
|
852 | 860 | "metadata": {}, |
853 | 861 | "outputs": [ |
854 | 862 | { |
855 | | - "name": "stderr", |
| 863 | + "name": "stdout", |
856 | 864 | "output_type": "stream", |
857 | 865 | "text": [ |
858 | 866 | "Restoring states from the checkpoint path at ./mlruns/406009257167130993/c7dc0307d39346e4ad7d3d45b19dcff8/checkpoints/epoch=9-step=80.ckpt\n", |
|
908 | 916 | " 'test_loss': 0.2789146900177002}]" |
909 | 917 | ] |
910 | 918 | }, |
911 | | - "execution_count": 10, |
| 919 | + "execution_count": null, |
912 | 920 | "metadata": {}, |
913 | 921 | "output_type": "execute_result" |
914 | 922 | } |
915 | 923 | ], |
916 | 924 | "source": [ |
917 | 925 | "# Run evaluation on the test set (defined by the csv_datamodule)\n", |
918 | | - "trainer.test(model, datamodule=dm, ckpt_path=trainer.checkpoint_callback.best_model_path)" |
| 926 | + "trainer.test(\n", |
| 927 | + " model,\n", |
| 928 | + " datamodule=dm,\n", |
| 929 | + " ckpt_path=trainer.checkpoint_callback.best_model_path,\n", |
| 930 | + ")" |
919 | 931 | ] |
920 | 932 | }, |
921 | 933 | { |
|
992 | 1004 | ], |
993 | 1005 | "source": [ |
994 | 1006 | "from pathlib import Path\n", |
995 | | - "from PIL import Image\n", |
| 1007 | + "\n", |
996 | 1008 | "import matplotlib.pyplot as plt\n", |
| 1009 | + "from PIL import Image\n", |
997 | 1010 | "\n", |
998 | 1011 | "# Identify MLflow experiment/run ids\n", |
999 | 1012 | "print(\"Experiment:\", logger.experiment_id, \"Run:\", logger.run_id)\n", |
|
0 commit comments