Commit d4354a6

valhassan committed:
Refactor imports and improve code formatting in 00_quickstart notebook for better readability and organization.

1 parent 6995164 · commit d4354a6

1 file changed: +42 −29 lines changed


notebooks/00_quickstart.ipynb

Lines changed: 42 additions & 29 deletions
@@ -74,26 +74,29 @@
 ],
 "source": [
 "import time\n",
+"\n",
 "start_total = time.time()\n",
 "\n",
 "print(\"Importing required libraries...\")\n",
 "\n",
 "# Core modules\n",
-"import os, sys, platform, random\n",
-"import zipfile\n",
+"# Data and ML\n",
+"import csv\n",
+"import platform\n",
+"import sys\n",
 "import warnings\n",
+"import zipfile\n",
 "from pathlib import Path\n",
 "\n",
-"# Data and ML\n",
-"import csv\n",
-"import numpy as np\n",
-"import torch\n",
 "import lightning as L\n",
-"import rasterio as rio\n",
 "import matplotlib.pyplot as plt\n",
+"import numpy as np\n",
+"import rasterio as rio\n",
+"import torch\n",
 "\n",
 "# Remove warnings from not georeferenced dataset (for this example only)\n",
 "from rasterio.errors import NotGeoreferencedWarning\n",
+"\n",
 "warnings.filterwarnings(\"ignore\", category=NotGeoreferencedWarning)\n",
 "\n",
 "# Append root path to make module work from notebook (might differ in your environment)\n",
@@ -151,7 +154,7 @@
 "source": [
 "# Define path to the archive and extract location\n",
 "zip_path = Path(\"../data/waterloo_subset_512.zip\")\n",
-"extract_dir = zip_path.with_suffix('') # removes .zip\n",
+"extract_dir = zip_path.with_suffix(\"\") # removes .zip\n",
 "\n",
 "# Unzip only if not already done\n",
 "if not extract_dir.exists():\n",
@@ -212,7 +215,7 @@
 "metadata": {},
 "outputs": [
 {
-"name": "stderr",
+"name": "stdout",
 "output_type": "stream",
 "text": [
 "Remapping trn labels: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 339.72it/s]\n",
@@ -222,17 +225,19 @@
 }
 ],
 "source": [
-"from tqdm import tqdm\n",
 "import numpy as np\n",
 "import rasterio as rio\n",
+"from tqdm import tqdm\n",
 "\n",
 "# Remap all labels in the dataset\n",
 "for split in [\"trn\", \"val\", \"tst\"]:\n",
 " lbl_dir = extract_dir / split / \"label\"\n",
 " if not lbl_dir.exists():\n",
 " continue\n",
 "\n",
-" for lbl_path in tqdm(sorted(lbl_dir.glob(\"*.tif\")), desc=f\"Remapping {split} labels\"):\n",
+" for lbl_path in tqdm(\n",
+" sorted(lbl_dir.glob(\"*.tif\")), desc=f\"Remapping {split} labels\",\n",
+" ):\n",
 " with rio.open(lbl_path) as lbl_ds:\n",
 " lbl = lbl_ds.read(1)\n",
 "\n",
@@ -284,7 +289,7 @@
 " lbl_dir = extract_dir / split / \"label\"\n",
 " csv_path = extract_dir / f\"{split}.csv\"\n",
 "\n",
-" # Collect matching image–label pairs\n",
+" # Collect matching image-label pairs\n",
 " rows = []\n",
 " for img_path in sorted(img_dir.glob(\"*.tif\")):\n",
 " lbl_path = lbl_dir / img_path.name\n",
@@ -294,7 +299,7 @@
 " print(f\"No matching label found for {img_path.name}\")\n",
 "\n",
 " # Write CSV\n",
-" with open(csv_path, \"w\", newline=\"\") as f:\n",
+" with csv_path.open(\"w\", newline=\"\") as f:\n",
 " writer = csv.writer(f, delimiter=\";\")\n",
 " writer.writerows(rows)\n",
 "\n",
@@ -336,7 +341,7 @@
 "from geo_deep_learning.datasets.csv_dataset import CSVDataset\n",
 "\n",
 "# Define dataset paths previously extracted from the ZIP\n",
-"dataset_root = extract_dir \n",
+"dataset_root = extract_dir\n",
 "\n",
 "# Change mask dtype to match SoftCrossEntropyLoss\n",
 "def _load_mask_int64(self, index: int):\n",
@@ -467,22 +472,25 @@
 "metadata": {},
 "outputs": [
 {
-"name": "stderr",
+"name": "stdout",
 "output_type": "stream",
 "text": [
 "/home/lromanin/miniforge3/envs/gdl_env_v09/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.\n"
 ]
 }
 ],
 "source": [
-"from geo_deep_learning.tasks_with_models.segmentation_unetplus import SegmentationUnetPlus\n",
-"from segmentation_models_pytorch.losses import SoftCrossEntropyLoss\n",
 "import torch\n",
+"from segmentation_models_pytorch.losses import SoftCrossEntropyLoss\n",
 "from torch.optim import Adam\n",
 "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
 "\n",
+"from geo_deep_learning.tasks_with_models.segmentation_unetplus import (\n",
+" SegmentationUnetPlus,\n",
+")\n",
+"\n",
 "# Loss function instance (multi-class, 2 classes: background + buildings)\n",
-"loss_fn = SoftCrossEntropyLoss(smooth_factor=0.1) \n",
+"loss_fn = SoftCrossEntropyLoss(smooth_factor=0.1)\n",
 "\n",
 "# Optimizer and scheduler configs\n",
@@ -511,7 +519,7 @@
 " scheduler_config={\n",
 " \"interval\": \"epoch\",\n",
 " \"frequency\": 1,\n",
-" \"monitor\": \"val_loss\"\n",
+" \"monitor\": \"val_loss\",\n",
 " },\n",
 " class_labels=[\"background\", \"buildings\"],\n",
 " class_colors=[\"#000000\", \"#FF0000\"],\n",
@@ -543,7 +551,7 @@
 "metadata": {},
 "outputs": [
 {
-"name": "stderr",
+"name": "stdout",
 "output_type": "stream",
 "text": [
 "GPU available: True (cuda), used: True\n",
@@ -583,7 +591,7 @@
 "output_type": "display_data"
 },
 {
-"name": "stderr",
+"name": "stdout",
 "output_type": "stream",
 "text": [
 "/home/lromanin/miniforge3/envs/gdl_env_v09/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.\n",
@@ -746,7 +754,7 @@
 "output_type": "display_data"
 },
 {
-"name": "stderr",
+"name": "stdout",
 "output_type": "stream",
 "text": [
 "`Trainer.fit` stopped: `max_epochs=10` reached.\n"
@@ -756,12 +764,12 @@
 "source": [
 "import pandas as pd\n",
 "from lightning.pytorch import Trainer\n",
-"from lightning.pytorch.loggers import MLFlowLogger\n",
 "from lightning.pytorch.callbacks import TQDMProgressBar\n",
+"from lightning.pytorch.loggers import MLFlowLogger\n",
 "\n",
 "logger = MLFlowLogger(\n",
 " experiment_name=\"unet_segmentation\",\n",
-" tracking_uri=\"file:./mlruns\"\n",
+" tracking_uri=\"file:./mlruns\",\n",
 ")\n",
 "\n",
 "# Define trainer\n",
@@ -804,7 +812,7 @@
 "<Axes: >"
 ]
 },
-"execution_count": 9,
+"execution_count": null,
 "metadata": {},
 "output_type": "execute_result"
 },
@@ -852,7 +860,7 @@
 "metadata": {},
 "outputs": [
 {
-"name": "stderr",
+"name": "stdout",
 "output_type": "stream",
 "text": [
 "Restoring states from the checkpoint path at ./mlruns/406009257167130993/c7dc0307d39346e4ad7d3d45b19dcff8/checkpoints/epoch=9-step=80.ckpt\n",
@@ -908,14 +916,18 @@
 " 'test_loss': 0.2789146900177002}]"
 ]
 },
-"execution_count": 10,
+"execution_count": null,
 "metadata": {},
 "output_type": "execute_result"
 }
 ],
 "source": [
 "# Run evaluation on the test set (defined by the csv_datamodule)\n",
-"trainer.test(model, datamodule=dm, ckpt_path=trainer.checkpoint_callback.best_model_path)"
+"trainer.test(\n",
+" model,\n",
+" datamodule=dm,\n",
+" ckpt_path=trainer.checkpoint_callback.best_model_path,\n",
+")"
 ]
 },
 {
@@ -992,8 +1004,9 @@
 ],
 "source": [
 "from pathlib import Path\n",
-"from PIL import Image\n",
+"\n",
 "import matplotlib.pyplot as plt\n",
+"from PIL import Image\n",
 "\n",
 "# Identify MLflow experiment/run ids\n",
 "print(\"Experiment:\", logger.experiment_id, \"Run:\", logger.run_id)\n",