From da2399b47997044d9956555bf229cdd453bc493d Mon Sep 17 00:00:00 2001 From: Burak Date: Tue, 9 Apr 2024 15:15:00 +0200 Subject: [PATCH 01/28] minor typo in custom_raster_dataset.ipynb --- docs/tutorials/custom_raster_dataset.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/custom_raster_dataset.ipynb b/docs/tutorials/custom_raster_dataset.ipynb index 7401e580edb..e4da8499114 100644 --- a/docs/tutorials/custom_raster_dataset.ipynb +++ b/docs/tutorials/custom_raster_dataset.ipynb @@ -345,7 +345,7 @@ "\n", "### `rgb_bands`\n", "\n", - "If your data is a multispectral iamge, you can define which bands correspond to the red, green, and blue channels. In the case of Sentinel-2, this corresponds to B04, B03, and B02, in that order.\n", + "If your data is a multispectral image, you can define which bands correspond to the red, green, and blue channels. In the case of Sentinel-2, this corresponds to B04, B03, and B02, in that order.\n", "\n", "Putting this all together into a single class, we get:" ] From 0f57ecf9d88fea403701123a159f3e1bd8707ebb Mon Sep 17 00:00:00 2001 From: Burak <68427259+burakekim@users.noreply.github.com> Date: Mon, 18 Nov 2024 15:43:34 +0100 Subject: [PATCH 02/28] xview2 dist shift initial commit --- docs/api/datasets.rst | 1 + torchgeo/datasets/__init__.py | 3 +- torchgeo/datasets/xview.py | 169 ++++++++++++++++++++++++++++++++++ 3 files changed, 172 insertions(+), 1 deletion(-) diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 20ce1bfcbac..8b5149ade43 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -462,6 +462,7 @@ xView2 ^^^^^^ .. autoclass:: XView2 +.. autoclass:: XView2DistShift ZueriCrop ^^^^^^^^^ diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 6a15fabdf76..e1daabf3ed9 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -136,7 +136,7 @@ from .vaihingen import Vaihingen2D from .vhr10 import VHR10 from .western_usa_live_fuel_moisture import WesternUSALiveFuelMoisture -from .xview import XView2 +from .xview import XView2, XView2DistShift from .zuericrop import ZueriCrop __all__ = ( @@ -258,6 +258,7 @@ 'VHR10', 'WesternUSALiveFuelMoisture', 'XView2', + 'XView2DistShift' 'ZueriCrop', # Base classes 'GeoDataset', diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index 5716c06f593..9854d18458f 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -270,3 +270,172 @@ def plot( plt.suptitle(suptitle) return fig + + +class XView2DistShift(XView2): + """ + A subclass of the XView2 dataset designed to reformat the original train/test splits + based on specific in-domain (ID) and out-of-domain (OOD) disasters. + + This class allows for the selection of particular disasters to be used as the + training set (in-domain) and test set (out-of-domain). The dataset can be split + according to the disaster names specified by the user, enabling the model to train + on one disaster type and evaluate on a different, out-of-domain disaster. The goal + is to test the generalization ability of models trained on one disaster to perform + on others. 
+ """ + + classes = ["background", "building"] + + # List of possible disaster names + valid_disasters = [ + 'hurricane-harvey', 'socal-fire', 'hurricane-matthew', 'mexico-earthquake', + 'guatemala-volcano', 'santa-rosa-wildfire', 'palu-tsunami', 'hurricane-florence', + 'hurricane-michael', 'midwest-flooding' + ] + + def __init__( + self, + root: str = "data", + split: str = "train", + id_ood_disaster: list[dict[str, str]] = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + checksum: bool = False, + ) -> None: + """Initialize the XView2DistShift dataset instance. + + Args: + root: Root directory where the dataset is located. + split: One of "train" or "test". + id_ood_disaster: List containing in-distribution and out-of-distribution disaster names. + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + AssertionError: If *split* is invalid. + ValueError: If a disaster name in `id_ood_disaster` is not one of the valid disasters. + DatasetNotFoundError: If dataset is not found. + """ + assert split in ["train", "test"], "Split must be either 'train' or 'test'." + # Validate that the disasters are valid + + if id_ood_disaster[0]['disaster_name'] not in self.valid_disasters or id_ood_disaster[1]['disaster_name'] not in self.valid_disasters: + raise ValueError(f"Invalid disaster names. Valid options are: {', '.join(self.valid_disasters)}") + + self.root = root + self.split = split + self.transforms = transforms + self.checksum = checksum + + self._verify() + + # Load all files and compute basenames and disasters only once + self.all_files = self._initialize_files(root) + + # Split logic by disaster and pre-post type + self.files = self._load_split_files_by_disaster_and_type(self.all_files, id_ood_disaster[0], id_ood_disaster[1]) + print(f"Loaded for disasters ID and OOD: {len(self.files['train'])} train, {len(self.files['test'])} test files.") + + def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: + """Get an item from the dataset at the given index.""" + file_info = ( + self.files["train"][index] + if self.split == "train" + else self.files["test"][index]) + + image = self._load_image(file_info["image"]).to("cuda") + mask = self._load_target(file_info["mask"]).long().to("cuda") + mask[mask == 2] = 1 + mask[(mask == 3) | (mask == 4)] = 0 + + sample = {"image": image, "mask": mask} + + if self.transforms: + sample = self.transforms(sample) + + return sample + + def __len__(self) -> int: + """Return the total number of samples in the dataset.""" + return ( + len(self.files["train"]) + if self.split == "train" + else len(self.files["test"]) + ) + + def _initialize_files(self, root: Path) -> List[Dict[str, str]]: + """Initialize the dataset by loading file paths and computing basenames with sample numbers.""" + all_files = [] + for split in self.metadata.keys(): + image_root = os.path.join(root, split, "images") + mask_root = os.path.join(root, split, "targets") + images = glob.glob(os.path.join(image_root, "*.png")) + + # Extract basenames while preserving the event-name and sample number + for img in images: + basename_parts = os.path.basename(img).split("_") + event_name = basename_parts[0] # e.g., mexico-earthquake + sample_number = basename_parts[1] # e.g., 00000001 + basename = ( + f"{event_name}_{sample_number}" # e.g., mexico-earthquake_00000001 + ) + + + file_info = { + "image": img, + "mask": 
os.path.join( + mask_root, f"{basename}_pre_disaster_target.png" + ), + "basename": basename, + } + all_files.append(file_info) + return all_files + + def _load_split_files_by_disaster_and_type( + self, files: List[Dict[str, str]], id_disaster: Dict[str, str], ood_disaster: Dict[str, str] + ) -> Dict[str, List[Dict[str, str]]]: + """ + Return the paths of the files for the train (ID) and test (OOD) sets based on the specified disaster name + and pre-post disaster type. + + Args: + files: List of file paths with their corresponding information. + id_disaster: Dictionary specifying in-domain (ID) disaster and type (e.g., {"disaster_name": "guatemala-volcano", "pre-post": "pre"}). + ood_disaster: Dictionary specifying out-of-domain (OOD) disaster and type (e.g., {"disaster_name": "mexico-earthquake", "pre-post": "post"}). + + Returns: + A dictionary containing 'train' (ID) and 'test' (OOD) file lists. + """ + train_files = [] + test_files = [] + disaster_list = [] + + for file_info in files: + basename = file_info["basename"] + disaster_name = basename.split("_")[0] # Extract disaster name from basename + pre_post = ("pre" if "pre_disaster" in file_info["image"] else "post") # Identify pre/post type + + disaster_list.append(disaster_name) + + # Filter for in-domain (ID) training set + if disaster_name == id_disaster["disaster_name"]: + if id_disaster.get("pre-post") == "both" or id_disaster["pre-post"] == pre_post: + image = ( + file_info["image"].replace("post_disaster", "pre_disaster") + if pre_post == "pre" + else file_info["image"] + ) + mask = ( + file_info["mask"].replace("post_disaster", "pre_disaster") + if pre_post == "pre" + else file_info["mask"] + ) + train_files.append(dict(image=image, mask=mask)) + + # Filter for out-of-domain (OOD) test set + if disaster_name == ood_disaster["disaster_name"]: + if ood_disaster.get("pre-post") == "both" or ood_disaster["pre-post"] == pre_post: + test_files.append(file_info) + + return {"train": train_files, "test": test_files, "disasters":disaster_list} \ No newline at end of file From 62919bfcf08f73d74ce9553d617319b57437fc49 Mon Sep 17 00:00:00 2001 From: Burak <68427259+burakekim@users.noreply.github.com> Date: Mon, 18 Nov 2024 15:46:50 +0100 Subject: [PATCH 03/28] xview2distshift dataset --- docs/api/datasets.rst | 1 + torchgeo/datasets/__init__.py | 3 +- torchgeo/datasets/xview.py | 169 ++++++++++++++++++++++++++++++++++ 3 files changed, 172 insertions(+), 1 deletion(-) diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 20ce1bfcbac..8b5149ade43 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -462,6 +462,7 @@ xView2 ^^^^^^ .. autoclass:: XView2 +.. 
autoclass:: XView2DistShift ZueriCrop ^^^^^^^^^ diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py index 6a15fabdf76..f84760ce865 100644 --- a/torchgeo/datasets/__init__.py +++ b/torchgeo/datasets/__init__.py @@ -136,7 +136,7 @@ from .vaihingen import Vaihingen2D from .vhr10 import VHR10 from .western_usa_live_fuel_moisture import WesternUSALiveFuelMoisture -from .xview import XView2 +from .xview import XView2, XView2DistShift from .zuericrop import ZueriCrop __all__ = ( @@ -258,6 +258,7 @@ 'VHR10', 'WesternUSALiveFuelMoisture', 'XView2', + 'XView2DistShift', 'ZueriCrop', # Base classes 'GeoDataset', diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index 5716c06f593..98fe3ceb534 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -270,3 +270,172 @@ def plot( plt.suptitle(suptitle) return fig + +class XView2DistShift(XView2): + """ + A subclass of the XView2 dataset designed to reformat the original train/test splits + based on specific in-domain (ID) and out-of-domain (OOD) disasters. + + This class allows for the selection of particular disasters to be used as the + training set (in-domain) and test set (out-of-domain). The dataset can be split + according to the disaster names specified by the user, enabling the model to train + on one disaster type and evaluate on a different, out-of-domain disaster. The goal + is to test the generalization ability of models trained on one disaster to perform + on others. + """ + + classes = ["background", "building"] + + # List of disaster names + valid_disasters = [ + 'hurricane-harvey', 'socal-fire', 'hurricane-matthew', 'mexico-earthquake', + 'guatemala-volcano', 'santa-rosa-wildfire', 'palu-tsunami', 'hurricane-florence', + 'hurricane-michael', 'midwest-flooding' + ] + + def __init__( + self, + root: Path = "data", + split: str = "train", + id_ood_disaster: List[Dict[str, str]] = [{"disaster_name": "hurricane-matthew", "pre-post": "post"}, {"disaster_name": "mexico-earthquake", "pre-post": "post"}], + transforms: Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]] = None, + checksum: bool = False, + **kwargs + ) -> None: + """Initialize the XView2DistShift dataset instance. + + Args: + root: Root directory where the dataset is located. + split: One of "train" or "test". + id_ood_disaster: List containing in-distribution and out-of-distribution disaster names. + transforms: a function/transform that takes input sample and its target as + entry and returns a transformed version + checksum: if True, check the MD5 of the downloaded files (may be slow) + + Raises: + AssertionError: If *split* is invalid. + ValueError: If a disaster name in `id_ood_disaster` is not one of the valid disasters. + DatasetNotFoundError: If dataset is not found. + """ + assert split in ["train", "test"], "Split must be either 'train' or 'test'." + # Validate that the disasters are valid + + if id_ood_disaster[0]['disaster_name'] not in self.valid_disasters or id_ood_disaster[1]['disaster_name'] not in self.valid_disasters: + raise ValueError(f"Invalid disaster names. 
Valid options are: {', '.join(self.valid_disasters)}") + + self.root = root + self.split = split + self.transforms = transforms + self.checksum = checksum + + self._verify() + + # Load all files and compute basenames and disasters only once + self.all_files = self._initialize_files(root) + + # Split logic by disaster and pre-post type + self.files = self._load_split_files_by_disaster_and_type(self.all_files, id_ood_disaster[0], id_ood_disaster[1]) + print(f"Loaded for disasters ID and OOD: {len(self.files['train'])} train, {len(self.files['test'])} test files.") + + def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: + """Get an item from the dataset at the given index.""" + file_info = ( + self.files["train"][index] + if self.split == "train" + else self.files["test"][index]) + + image = self._load_image(file_info["image"]).to("cuda") + mask = self._load_target(file_info["mask"]).long().to("cuda") + mask[mask == 2] = 1 + mask[(mask == 3) | (mask == 4)] = 0 + + sample = {"image": image, "mask": mask} + + if self.transforms: + sample = self.transforms(sample) + + return sample + + def __len__(self) -> int: + """Return the total number of samples in the dataset.""" + return ( + len(self.files["train"]) + if self.split == "train" + else len(self.files["test"]) + ) + + def _initialize_files(self, root: Path) -> List[Dict[str, str]]: + """Initialize the dataset by loading file paths and computing basenames with sample numbers.""" + all_files = [] + for split in self.metadata.keys(): + image_root = os.path.join(root, split, "images") + mask_root = os.path.join(root, split, "targets") + images = glob.glob(os.path.join(image_root, "*.png")) + + # Extract basenames while preserving the event-name and sample number + for img in images: + basename_parts = os.path.basename(img).split("_") + event_name = basename_parts[0] # e.g., mexico-earthquake + sample_number = basename_parts[1] # e.g., 00000001 + basename = ( + f"{event_name}_{sample_number}" # e.g., mexico-earthquake_00000001 + ) + + + file_info = { + "image": img, + "mask": os.path.join( + mask_root, f"{basename}_pre_disaster_target.png" + ), + "basename": basename, + } + all_files.append(file_info) + return all_files + + def _load_split_files_by_disaster_and_type( + self, files: List[Dict[str, str]], id_disaster: Dict[str, str], ood_disaster: Dict[str, str] + ) -> Dict[str, List[Dict[str, str]]]: + """ + Return the paths of the files for the train (ID) and test (OOD) sets based on the specified disaster name + and pre-post disaster type. + + Args: + files: List of file paths with their corresponding information. + id_disaster: Dictionary specifying in-domain (ID) disaster and type (e.g., {"disaster_name": "guatemala-volcano", "pre-post": "pre"}). + ood_disaster: Dictionary specifying out-of-domain (OOD) disaster and type (e.g., {"disaster_name": "mexico-earthquake", "pre-post": "post"}). + + Returns: + A dictionary containing 'train' (ID) and 'test' (OOD) file lists. 
+ """ + train_files = [] + test_files = [] + disaster_list = [] + + for file_info in files: + basename = file_info["basename"] + disaster_name = basename.split("_")[0] # Extract disaster name from basename + pre_post = ("pre" if "pre_disaster" in file_info["image"] else "post") # Identify pre/post type + + disaster_list.append(disaster_name) + + # Filter for in-domain (ID) training set + if disaster_name == id_disaster["disaster_name"]: + if id_disaster.get("pre-post") == "both" or id_disaster["pre-post"] == pre_post: + image = ( + file_info["image"].replace("post_disaster", "pre_disaster") + if pre_post == "pre" + else file_info["image"] + ) + mask = ( + file_info["mask"].replace("post_disaster", "pre_disaster") + if pre_post == "pre" + else file_info["mask"] + ) + train_files.append(dict(image=image, mask=mask)) + + # Filter for out-of-domain (OOD) test set + if disaster_name == ood_disaster["disaster_name"]: + if ood_disaster.get("pre-post") == "both" or ood_disaster["pre-post"] == pre_post: + test_files.append(file_info) + + return {"train": train_files, "test": test_files, "disasters":disaster_list} \ No newline at end of file From 5985f44c451b99329e8e62b50a90f7c2fb2f724e Mon Sep 17 00:00:00 2001 From: Burak <68427259+burakekim@users.noreply.github.com> Date: Mon, 18 Nov 2024 16:00:56 +0100 Subject: [PATCH 04/28] test xview2 --- tests/datasets/test_xview2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datasets/test_xview2.py b/tests/datasets/test_xview2.py index 7689acf5f78..35e02e27c28 100644 --- a/tests/datasets/test_xview2.py +++ b/tests/datasets/test_xview2.py @@ -12,7 +12,7 @@ from _pytest.fixtures import SubRequest from pytest import MonkeyPatch -from torchgeo.datasets import DatasetNotFoundError, XView2 +from torchgeo.datasets import DatasetNotFoundError, XView2, XView2DistShift class TestXView2: From a23344e69567d6f435e8f95018a6716ae0fe5aa4 Mon Sep 17 00:00:00 2001 From: Burak Date: Mon, 18 Nov 2024 17:03:23 +0100 Subject: [PATCH 05/28] formatting --- tests/datasets/test_xview2.py | 2 +- torchgeo/datasets/xview.py | 211 +++++++++++++++++++--------------- 2 files changed, 121 insertions(+), 92 deletions(-) diff --git a/tests/datasets/test_xview2.py b/tests/datasets/test_xview2.py index 35e02e27c28..dc8774d0933 100644 --- a/tests/datasets/test_xview2.py +++ b/tests/datasets/test_xview2.py @@ -14,7 +14,6 @@ from torchgeo.datasets import DatasetNotFoundError, XView2, XView2DistShift - class TestXView2: @pytest.fixture(params=['train', 'test']) def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> XView2: @@ -27,6 +26,7 @@ def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> XView2: 'md5': '373e61d55c1b294aa76b94dbbd81332b', 'directory': 'train', }, + 'test': { 'filename': 'test_images_labels_targets.tar.gz', 'md5': 'bc6de81c956a3bada38b5b4e246266a1', diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index 98fe3ceb534..bc9ce8d5992 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -19,6 +19,7 @@ from .utils import check_integrity, draw_semantic_segmentation_masks, extract_archive + class XView2(NonGeoDataset): """xView2 dataset. 
@@ -50,24 +51,24 @@ class XView2(NonGeoDataset): """ metadata = { - 'train': { - 'filename': 'train_images_labels_targets.tar.gz', - 'md5': 'a20ebbfb7eb3452785b63ad02ffd1e16', - 'directory': 'train', + "train": { + "filename": "train_images_labels_targets.tar.gz", + "md5": "a20ebbfb7eb3452785b63ad02ffd1e16", + "directory": "train", }, - 'test': { - 'filename': 'test_images_labels_targets.tar.gz', - 'md5': '1b39c47e05d1319c17cc8763cee6fe0c', - 'directory': 'test', + "test": { + "filename": "test_images_labels_targets.tar.gz", + "md5": "1b39c47e05d1319c17cc8763cee6fe0c", + "directory": "test", }, } - classes = ['background', 'no-damage', 'minor-damage', 'major-damage', 'destroyed'] - colormap = ['green', 'blue', 'orange', 'red'] + classes = ["background", "no-damage", "minor-damage", "major-damage", "destroyed"] + colormap = ["green", "blue", "orange", "red"] def __init__( self, - root: str = 'data', - split: str = 'train', + root: str = "data", + split: str = "train", transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = False, ) -> None: @@ -105,14 +106,14 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: data and label at that index """ files = self.files[index] - image1 = self._load_image(files['image1']) - image2 = self._load_image(files['image2']) - mask1 = self._load_target(files['mask1']) - mask2 = self._load_target(files['mask2']) + image1 = self._load_image(files["image1"]) + image2 = self._load_image(files["image2"]) + mask1 = self._load_target(files["mask1"]) + mask2 = self._load_target(files["mask2"]) image = torch.stack(tensors=[image1, image2], dim=0) mask = torch.stack(tensors=[mask1, mask2], dim=0) - sample = {'image': image, 'mask': mask} + sample = {"image": image, "mask": mask} if self.transforms is not None: sample = self.transforms(sample) @@ -138,17 +139,17 @@ def _load_files(self, root: str, split: str) -> list[dict[str, str]]: list of dicts containing paths for each pair of images and masks """ files = [] - directory = self.metadata[split]['directory'] - image_root = os.path.join(root, directory, 'images') - mask_root = os.path.join(root, directory, 'targets') - images = glob.glob(os.path.join(image_root, '*.png')) + directory = self.metadata[split]["directory"] + image_root = os.path.join(root, directory, "images") + mask_root = os.path.join(root, directory, "targets") + images = glob.glob(os.path.join(image_root, "*.png")) basenames = [os.path.basename(f) for f in images] - basenames = ['_'.join(f.split('_')[:-2]) for f in basenames] + basenames = ["_".join(f.split("_")[:-2]) for f in basenames] for name in sorted(set(basenames)): - image1 = os.path.join(image_root, f'{name}_pre_disaster.png') - image2 = os.path.join(image_root, f'{name}_post_disaster.png') - mask1 = os.path.join(mask_root, f'{name}_pre_disaster_target.png') - mask2 = os.path.join(mask_root, f'{name}_post_disaster_target.png') + image1 = os.path.join(image_root, f"{name}_pre_disaster.png") + image2 = os.path.join(image_root, f"{name}_post_disaster.png") + mask1 = os.path.join(mask_root, f"{name}_pre_disaster_target.png") + mask2 = os.path.join(mask_root, f"{name}_post_disaster_target.png") files.append(dict(image1=image1, image2=image2, mask1=mask1, mask2=mask2)) return files @@ -163,7 +164,7 @@ def _load_image(self, path: str) -> Tensor: """ filename = os.path.join(path) with Image.open(filename) as img: - array: np.typing.NDArray[np.int_] = np.array(img.convert('RGB')) + array: np.typing.NDArray[np.int_] = np.array(img.convert("RGB")) tensor = 
torch.from_numpy(array) # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) @@ -180,7 +181,7 @@ def _load_target(self, path: str) -> Tensor: """ filename = os.path.join(path) with Image.open(filename) as img: - array: np.typing.NDArray[np.int_] = np.array(img.convert('L')) + array: np.typing.NDArray[np.int_] = np.array(img.convert("L")) tensor = torch.from_numpy(array) tensor = tensor.to(torch.long) return tensor @@ -190,10 +191,10 @@ def _verify(self) -> None: # Check if the files already exist exists = [] for split_info in self.metadata.values(): - for directory in ['images', 'targets']: + for directory in ["images", "targets"]: exists.append( os.path.exists( - os.path.join(self.root, split_info['directory'], directory) + os.path.join(self.root, split_info["directory"], directory) ) ) @@ -203,10 +204,10 @@ def _verify(self) -> None: # Check if .tar.gz files already exists (if so then extract) exists = [] for split_info in self.metadata.values(): - filepath = os.path.join(self.root, split_info['filename']) + filepath = os.path.join(self.root, split_info["filename"]) if os.path.isfile(filepath): - if self.checksum and not check_integrity(filepath, split_info['md5']): - raise RuntimeError('Dataset found, but corrupted.') + if self.checksum and not check_integrity(filepath, split_info["md5"]): + raise RuntimeError("Dataset found, but corrupted.") exists.append(True) extract_archive(filepath) else: @@ -237,70 +238,78 @@ def plot( """ ncols = 2 image1 = draw_semantic_segmentation_masks( - sample['image'][0], sample['mask'][0], alpha=alpha, colors=self.colormap + sample["image"][0], sample["mask"][0], alpha=alpha, colors=self.colormap ) image2 = draw_semantic_segmentation_masks( - sample['image'][1], sample['mask'][1], alpha=alpha, colors=self.colormap + sample["image"][1], sample["mask"][1], alpha=alpha, colors=self.colormap ) - if 'prediction' in sample: # NOTE: this assumes predictions are made for post + if "prediction" in sample: # NOTE: this assumes predictions are made for post ncols += 1 image3 = draw_semantic_segmentation_masks( - sample['image'][1], - sample['prediction'], + sample["image"][1], + sample["prediction"], alpha=alpha, colors=self.colormap, ) fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) axs[0].imshow(image1) - axs[0].axis('off') + axs[0].axis("off") axs[1].imshow(image2) - axs[1].axis('off') + axs[1].axis("off") if ncols > 2: axs[2].imshow(image3) - axs[2].axis('off') + axs[2].axis("off") if show_titles: - axs[0].set_title('Pre disaster') - axs[1].set_title('Post disaster') + axs[0].set_title("Pre disaster") + axs[1].set_title("Post disaster") if ncols > 2: - axs[2].set_title('Predictions') + axs[2].set_title("Predictions") if suptitle is not None: plt.suptitle(suptitle) return fig - + + class XView2DistShift(XView2): - """ - A subclass of the XView2 dataset designed to reformat the original train/test splits - based on specific in-domain (ID) and out-of-domain (OOD) disasters. - - This class allows for the selection of particular disasters to be used as the - training set (in-domain) and test set (out-of-domain). The dataset can be split - according to the disaster names specified by the user, enabling the model to train - on one disaster type and evaluate on a different, out-of-domain disaster. The goal - is to test the generalization ability of models trained on one disaster to perform + """A subclass of the XView2 dataset designed to reformat the original train/test splits. 
+ + This class allows for the selection of particular disasters to be used as the + training set (in-domain) and test set (out-of-domain). The dataset can be split + according to the disaster names specified by the user, enabling the model to train + on one disaster type and evaluate on a different, out-of-domain disaster. The goal + is to test the generalization ability of models trained on one disaster to perform on others. """ - + classes = ["background", "building"] - + # List of disaster names valid_disasters = [ - 'hurricane-harvey', 'socal-fire', 'hurricane-matthew', 'mexico-earthquake', - 'guatemala-volcano', 'santa-rosa-wildfire', 'palu-tsunami', 'hurricane-florence', - 'hurricane-michael', 'midwest-flooding' + "hurricane-harvey", + "socal-fire", + "hurricane-matthew", + "mexico-earthquake", + "guatemala-volcano", + "santa-rosa-wildfire", + "palu-tsunami", + "hurricane-florence", + "hurricane-michael", + "midwest-flooding", ] - + def __init__( self, - root: Path = "data", + root: str = "data", split: str = "train", - id_ood_disaster: List[Dict[str, str]] = [{"disaster_name": "hurricane-matthew", "pre-post": "post"}, {"disaster_name": "mexico-earthquake", "pre-post": "post"}], - transforms: Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]] = None, + id_ood_disaster: list[dict[str, str]] = [ + {"disaster_name": "hurricane-matthew", "pre-post": "post"}, + {"disaster_name": "mexico-earthquake", "pre-post": "post"}, + ], + transforms: Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] = None, checksum: bool = False, - **kwargs ) -> None: """Initialize the XView2DistShift dataset instance. @@ -311,7 +320,7 @@ def __init__( transforms: a function/transform that takes input sample and its target as entry and returns a transformed version checksum: if True, check the MD5 of the downloaded files (may be slow) - + Raises: AssertionError: If *split* is invalid. ValueError: If a disaster name in `id_ood_disaster` is not one of the valid disasters. @@ -319,9 +328,14 @@ def __init__( """ assert split in ["train", "test"], "Split must be either 'train' or 'test'." # Validate that the disasters are valid - - if id_ood_disaster[0]['disaster_name'] not in self.valid_disasters or id_ood_disaster[1]['disaster_name'] not in self.valid_disasters: - raise ValueError(f"Invalid disaster names. Valid options are: {', '.join(self.valid_disasters)}") + + if ( + id_ood_disaster[0]["disaster_name"] not in self.valid_disasters + or id_ood_disaster[1]["disaster_name"] not in self.valid_disasters + ): + raise ValueError( + f"Invalid disaster names. Valid options are: {', '.join(self.valid_disasters)}" + ) self.root = root self.split = split @@ -332,23 +346,28 @@ def __init__( # Load all files and compute basenames and disasters only once self.all_files = self._initialize_files(root) - + # Split logic by disaster and pre-post type - self.files = self._load_split_files_by_disaster_and_type(self.all_files, id_ood_disaster[0], id_ood_disaster[1]) - print(f"Loaded for disasters ID and OOD: {len(self.files['train'])} train, {len(self.files['test'])} test files.") + self.files = self._load_split_files_by_disaster_and_type( + self.all_files, id_ood_disaster[0], id_ood_disaster[1] + ) + print( + f"Loaded for disasters ID and OOD: {len(self.files['train'])} train, {len(self.files['test'])} test files." 
+ ) - def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: + def __getitem__(self, index: int) -> dict[str, torch.Tensor]: """Get an item from the dataset at the given index.""" file_info = ( self.files["train"][index] if self.split == "train" - else self.files["test"][index]) + else self.files["test"][index] + ) image = self._load_image(file_info["image"]).to("cuda") mask = self._load_target(file_info["mask"]).long().to("cuda") mask[mask == 2] = 1 mask[(mask == 3) | (mask == 4)] = 0 - + sample = {"image": image, "mask": mask} if self.transforms: @@ -364,14 +383,14 @@ def __len__(self) -> int: else len(self.files["test"]) ) - def _initialize_files(self, root: Path) -> List[Dict[str, str]]: + def _initialize_files(self, root: str) -> list[dict[str, str]]: """Initialize the dataset by loading file paths and computing basenames with sample numbers.""" all_files = [] for split in self.metadata.keys(): image_root = os.path.join(root, split, "images") mask_root = os.path.join(root, split, "targets") images = glob.glob(os.path.join(image_root, "*.png")) - + # Extract basenames while preserving the event-name and sample number for img in images: basename_parts = os.path.basename(img).split("_") @@ -381,7 +400,6 @@ def _initialize_files(self, root: Path) -> List[Dict[str, str]]: f"{event_name}_{sample_number}" # e.g., mexico-earthquake_00000001 ) - file_info = { "image": img, "mask": os.path.join( @@ -393,11 +411,12 @@ def _initialize_files(self, root: Path) -> List[Dict[str, str]]: return all_files def _load_split_files_by_disaster_and_type( - self, files: List[Dict[str, str]], id_disaster: Dict[str, str], ood_disaster: Dict[str, str] - ) -> Dict[str, List[Dict[str, str]]]: - """ - Return the paths of the files for the train (ID) and test (OOD) sets based on the specified disaster name - and pre-post disaster type. + self, + files: list[dict[str, str]], + id_disaster: dict[str, str], + ood_disaster: dict[str, str], + ) -> dict[str, list[dict[str, str]]]: + """Return the filepaths for the train (ID) and test (OOD) sets based on disaster name and pre-post disaster type. Args: files: List of file paths with their corresponding information. 
@@ -410,17 +429,24 @@ def _load_split_files_by_disaster_and_type( train_files = [] test_files = [] disaster_list = [] - + for file_info in files: basename = file_info["basename"] - disaster_name = basename.split("_")[0] # Extract disaster name from basename - pre_post = ("pre" if "pre_disaster" in file_info["image"] else "post") # Identify pre/post type - - disaster_list.append(disaster_name) - + disaster_name = basename.split("_")[ + 0 + ] # Extract disaster name from basename + pre_post = ( + "pre" if "pre_disaster" in file_info["image"] else "post" + ) # Identify pre/post type + + disaster_list.append(disaster_name) + # Filter for in-domain (ID) training set if disaster_name == id_disaster["disaster_name"]: - if id_disaster.get("pre-post") == "both" or id_disaster["pre-post"] == pre_post: + if ( + id_disaster.get("pre-post") == "both" + or id_disaster["pre-post"] == pre_post + ): image = ( file_info["image"].replace("post_disaster", "pre_disaster") if pre_post == "pre" @@ -435,7 +461,10 @@ def _load_split_files_by_disaster_and_type( # Filter for out-of-domain (OOD) test set if disaster_name == ood_disaster["disaster_name"]: - if ood_disaster.get("pre-post") == "both" or ood_disaster["pre-post"] == pre_post: + if ( + ood_disaster.get("pre-post") == "both" + or ood_disaster["pre-post"] == pre_post + ): test_files.append(file_info) - return {"train": train_files, "test": test_files, "disasters":disaster_list} \ No newline at end of file + return {"train": train_files, "test": test_files, "disasters": disaster_list} From 8239d32f80dd882f80f50553fd4638ef9d1c4bae Mon Sep 17 00:00:00 2001 From: Burak Date: Sat, 1 Feb 2025 16:41:18 +0100 Subject: [PATCH 06/28] " to ' --- torchgeo/datasets/xview.py | 178 ++++++++++++++++++------------------- 1 file changed, 89 insertions(+), 89 deletions(-) diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index 222790248cb..976ca219aad 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -61,10 +61,10 @@ class XView2(NonGeoDataset): 'md5': 'a20ebbfb7eb3452785b63ad02ffd1e16', 'directory': 'train', }, - "test": { - "filename": "test_images_labels_targets.tar.gz", - "md5": "1b39c47e05d1319c17cc8763cee6fe0c", - "directory": "test", + 'test': { + 'filename': 'test_images_labels_targets.tar.gz', + 'md5': '1b39c47e05d1319c17cc8763cee6fe0c', + 'directory': 'test', }, } classes = ('background', 'no-damage', 'minor-damage', 'major-damage', 'destroyed') @@ -111,14 +111,14 @@ def __getitem__(self, index: int) -> dict[str, Tensor]: data and label at that index """ files = self.files[index] - image1 = self._load_image(files["image1"]) - image2 = self._load_image(files["image2"]) - mask1 = self._load_target(files["mask1"]) - mask2 = self._load_target(files["mask2"]) + image1 = self._load_image(files['image1']) + image2 = self._load_image(files['image2']) + mask1 = self._load_target(files['mask1']) + mask2 = self._load_target(files['mask2']) image = torch.stack(tensors=[image1, image2], dim=0) mask = torch.stack(tensors=[mask1, mask2], dim=0) - sample = {"image": image, "mask": mask} + sample = {'image': image, 'mask': mask} if self.transforms is not None: sample = self.transforms(sample) @@ -138,23 +138,23 @@ def _load_files(self, root: Path, split: str) -> list[dict[str, str]]: Args: root: root dir of dataset - split: subset of dataset, one of [train, test] + split: subset of dataset, one of ['train', 'test'] Returns: list of dicts containing paths for each pair of images and masks """ files = [] - directory = 
self.metadata[split]["directory"] - image_root = os.path.join(root, directory, "images") - mask_root = os.path.join(root, directory, "targets") - images = glob.glob(os.path.join(image_root, "*.png")) + directory = self.metadata[split]['directory'] + image_root = os.path.join(root, directory, 'images') + mask_root = os.path.join(root, directory, 'targets') + images = glob.glob(os.path.join(image_root, '*.png')) basenames = [os.path.basename(f) for f in images] - basenames = ["_".join(f.split("_")[:-2]) for f in basenames] + basenames = ['_'.join(f.split('_')[:-2]) for f in basenames] for name in sorted(set(basenames)): - image1 = os.path.join(image_root, f"{name}_pre_disaster.png") - image2 = os.path.join(image_root, f"{name}_post_disaster.png") - mask1 = os.path.join(mask_root, f"{name}_pre_disaster_target.png") - mask2 = os.path.join(mask_root, f"{name}_post_disaster_target.png") + image1 = os.path.join(image_root, f'{name}_pre_disaster.png') + image2 = os.path.join(image_root, f'{name}_post_disaster.png') + mask1 = os.path.join(mask_root, f'{name}_pre_disaster_target.png') + mask2 = os.path.join(mask_root, f'{name}_post_disaster_target.png') files.append(dict(image1=image1, image2=image2, mask1=mask1, mask2=mask2)) return files @@ -169,7 +169,7 @@ def _load_image(self, path: Path) -> Tensor: """ filename = os.path.join(path) with Image.open(filename) as img: - array: np.typing.NDArray[np.int_] = np.array(img.convert("RGB")) + array: np.typing.NDArray[np.int_] = np.array(img.convert('RGB')) tensor = torch.from_numpy(array) # Convert from HxWxC to CxHxW tensor = tensor.permute((2, 0, 1)) @@ -186,7 +186,7 @@ def _load_target(self, path: Path) -> Tensor: """ filename = os.path.join(path) with Image.open(filename) as img: - array: np.typing.NDArray[np.int_] = np.array(img.convert("L")) + array: np.typing.NDArray[np.int_] = np.array(img.convert('L')) tensor = torch.from_numpy(array) tensor = tensor.to(torch.long) return tensor @@ -196,10 +196,10 @@ def _verify(self) -> None: # Check if the files already exist exists = [] for split_info in self.metadata.values(): - for directory in ["images", "targets"]: + for directory in ['images', 'targets']: exists.append( os.path.exists( - os.path.join(self.root, split_info["directory"], directory) + os.path.join(self.root, split_info['directory'], directory) ) ) @@ -209,10 +209,10 @@ def _verify(self) -> None: # Check if .tar.gz files already exists (if so then extract) exists = [] for split_info in self.metadata.values(): - filepath = os.path.join(self.root, split_info["filename"]) + filepath = os.path.join(self.root, split_info['filename']) if os.path.isfile(filepath): - if self.checksum and not check_integrity(filepath, split_info["md5"]): - raise RuntimeError("Dataset found, but corrupted.") + if self.checksum and not check_integrity(filepath, split_info['md5']): + raise RuntimeError('Dataset found, but corrupted.') exists.append(True) extract_archive(filepath) else: @@ -254,29 +254,29 @@ def plot( alpha=alpha, colors=list(self.colormap), ) - if "prediction" in sample: # NOTE: this assumes predictions are made for post + if 'prediction' in sample: # NOTE: this assumes predictions are made for post ncols += 1 image3 = draw_semantic_segmentation_masks( - sample["image"][1], - sample["prediction"], + sample['image'][1], + sample['prediction'], alpha=alpha, colors=list(self.colormap), ) fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) axs[0].imshow(image1) - axs[0].axis("off") + axs[0].axis('off') axs[1].imshow(image2) - axs[1].axis("off") + 
axs[1].axis('off') if ncols > 2: axs[2].imshow(image3) - axs[2].axis("off") + axs[2].axis('off') if show_titles: - axs[0].set_title("Pre disaster") - axs[1].set_title("Post disaster") + axs[0].set_title('Pre disaster') + axs[1].set_title('Post disaster') if ncols > 2: - axs[2].set_title("Predictions") + axs[2].set_title('Predictions') if suptitle is not None: plt.suptitle(suptitle) @@ -295,29 +295,29 @@ class XView2DistShift(XView2): on others. """ - classes = ["background", "building"] + classes = ['background', 'building'] # List of disaster names valid_disasters = [ - "hurricane-harvey", - "socal-fire", - "hurricane-matthew", - "mexico-earthquake", - "guatemala-volcano", - "santa-rosa-wildfire", - "palu-tsunami", - "hurricane-florence", - "hurricane-michael", - "midwest-flooding", + 'hurricane-harvey', + 'socal-fire', + 'hurricane-matthew', + 'mexico-earthquake', + 'guatemala-volcano', + 'santa-rosa-wildfire', + 'palu-tsunami', + 'hurricane-florence', + 'hurricane-michael', + 'midwest-flooding', ] def __init__( self, - root: str = "data", - split: str = "train", + root: str = 'data', + split: str = 'train', id_ood_disaster: list[dict[str, str]] = [ - {"disaster_name": "hurricane-matthew", "pre-post": "post"}, - {"disaster_name": "mexico-earthquake", "pre-post": "post"}, + {'disaster_name': 'hurricane-matthew', 'pre-post': 'post'}, + {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'}, ], transforms: Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] = None, checksum: bool = False, @@ -337,12 +337,12 @@ def __init__( ValueError: If a disaster name in `id_ood_disaster` is not one of the valid disasters. DatasetNotFoundError: If dataset is not found. """ - assert split in ["train", "test"], "Split must be either 'train' or 'test'." + assert split in ['train', 'test'], "Split must be either 'train' or 'test'." # Validate that the disasters are valid if ( - id_ood_disaster[0]["disaster_name"] not in self.valid_disasters - or id_ood_disaster[1]["disaster_name"] not in self.valid_disasters + id_ood_disaster[0]['disaster_name'] not in self.valid_disasters + or id_ood_disaster[1]['disaster_name'] not in self.valid_disasters ): raise ValueError( f"Invalid disaster names. 
Valid options are: {', '.join(self.valid_disasters)}" @@ -369,17 +369,17 @@ def __init__( def __getitem__(self, index: int) -> dict[str, torch.Tensor]: """Get an item from the dataset at the given index.""" file_info = ( - self.files["train"][index] - if self.split == "train" - else self.files["test"][index] + self.files['train'][index] + if self.split == 'train' + else self.files['test'][index] ) - image = self._load_image(file_info["image"]).to("cuda") - mask = self._load_target(file_info["mask"]).long().to("cuda") + image = self._load_image(file_info['image']).to('cuda') + mask = self._load_target(file_info['mask']).long().to('cuda') mask[mask == 2] = 1 mask[(mask == 3) | (mask == 4)] = 0 - sample = {"image": image, "mask": mask} + sample = {'image': image, 'mask': mask} if self.transforms: sample = self.transforms(sample) @@ -389,34 +389,34 @@ def __getitem__(self, index: int) -> dict[str, torch.Tensor]: def __len__(self) -> int: """Return the total number of samples in the dataset.""" return ( - len(self.files["train"]) - if self.split == "train" - else len(self.files["test"]) + len(self.files['train']) + if self.split == 'train' + else len(self.files['test']) ) def _initialize_files(self, root: str) -> list[dict[str, str]]: """Initialize the dataset by loading file paths and computing basenames with sample numbers.""" all_files = [] for split in self.metadata.keys(): - image_root = os.path.join(root, split, "images") - mask_root = os.path.join(root, split, "targets") - images = glob.glob(os.path.join(image_root, "*.png")) + image_root = os.path.join(root, split, 'images') + mask_root = os.path.join(root, split, 'targets') + images = glob.glob(os.path.join(image_root, '*.png')) # Extract basenames while preserving the event-name and sample number for img in images: - basename_parts = os.path.basename(img).split("_") + basename_parts = os.path.basename(img).split('_') event_name = basename_parts[0] # e.g., mexico-earthquake sample_number = basename_parts[1] # e.g., 00000001 basename = ( - f"{event_name}_{sample_number}" # e.g., mexico-earthquake_00000001 + f'{event_name}_{sample_number}' # e.g., mexico-earthquake_00000001 ) file_info = { - "image": img, - "mask": os.path.join( - mask_root, f"{basename}_pre_disaster_target.png" + 'image': img, + 'mask': os.path.join( + mask_root, f'{basename}_pre_disaster_target.png' ), - "basename": basename, + 'basename': basename, } all_files.append(file_info) return all_files @@ -431,8 +431,8 @@ def _load_split_files_by_disaster_and_type( Args: files: List of file paths with their corresponding information. - id_disaster: Dictionary specifying in-domain (ID) disaster and type (e.g., {"disaster_name": "guatemala-volcano", "pre-post": "pre"}). - ood_disaster: Dictionary specifying out-of-domain (OOD) disaster and type (e.g., {"disaster_name": "mexico-earthquake", "pre-post": "post"}). + id_disaster: Dictionary specifying in-domain (ID) disaster and type (e.g., {'disaster_name': 'guatemala-volcano', 'pre-post': 'pre'}). + ood_disaster: Dictionary specifying out-of-domain (OOD) disaster and type (e.g., {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'}). Returns: A dictionary containing 'train' (ID) and 'test' (OOD) file lists. 
@@ -442,40 +442,40 @@ def _load_split_files_by_disaster_and_type( disaster_list = [] for file_info in files: - basename = file_info["basename"] - disaster_name = basename.split("_")[ + basename = file_info['basename'] + disaster_name = basename.split('_')[ 0 ] # Extract disaster name from basename pre_post = ( - "pre" if "pre_disaster" in file_info["image"] else "post" + 'pre' if 'pre_disaster' in file_info['image'] else 'post' ) # Identify pre/post type disaster_list.append(disaster_name) # Filter for in-domain (ID) training set - if disaster_name == id_disaster["disaster_name"]: + if disaster_name == id_disaster['disaster_name']: if ( - id_disaster.get("pre-post") == "both" - or id_disaster["pre-post"] == pre_post + id_disaster.get('pre-post') == 'both' + or id_disaster['pre-post'] == pre_post ): image = ( - file_info["image"].replace("post_disaster", "pre_disaster") - if pre_post == "pre" - else file_info["image"] + file_info['image'].replace('post_disaster', 'pre_disaster') + if pre_post == 'pre' + else file_info['image'] ) mask = ( - file_info["mask"].replace("post_disaster", "pre_disaster") - if pre_post == "pre" - else file_info["mask"] + file_info['mask'].replace('post_disaster', 'pre_disaster') + if pre_post == 'pre' + else file_info['mask'] ) train_files.append(dict(image=image, mask=mask)) # Filter for out-of-domain (OOD) test set - if disaster_name == ood_disaster["disaster_name"]: + if disaster_name == ood_disaster['disaster_name']: if ( - ood_disaster.get("pre-post") == "both" - or ood_disaster["pre-post"] == pre_post + ood_disaster.get('pre-post') == 'both' + or ood_disaster['pre-post'] == pre_post ): test_files.append(file_info) - return {"train": train_files, "test": test_files, "disasters": disaster_list} + return {'train': train_files, 'test': test_files, 'disasters': disaster_list} From 0f3ceb7931601dd6e2af9824a7f7bbdeb68217e6 Mon Sep 17 00:00:00 2001 From: burakekim Date: Sat, 1 Feb 2025 19:41:22 +0000 Subject: [PATCH 07/28] no cuda yes docstring --- .gitignore | 1 + torchgeo/datasets/xview.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index 180c27c47b2..8357c23b3f8 100644 --- a/.gitignore +++ b/.gitignore @@ -148,3 +148,4 @@ dmypy.json # Pyre type checker .pyre/ +xbdood.ipynb diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index 976ca219aad..12a30d16991 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -374,10 +374,12 @@ def __getitem__(self, index: int) -> dict[str, torch.Tensor]: else self.files['test'][index] ) - image = self._load_image(file_info['image']).to('cuda') - mask = self._load_target(file_info['mask']).long().to('cuda') - mask[mask == 2] = 1 - mask[(mask == 3) | (mask == 4)] = 0 + image = self._load_image(file_info['image']) + mask = self._load_target(file_info['mask']).long() + + # Reformulate as building segmentation task + mask[mask == 2] = 1 # Map damage class 2 to 1 + mask[(mask == 3) | (mask == 4)] = 0 # Map 3 and 4 damage classes to background sample = {'image': image, 'mask': mask} @@ -402,7 +404,7 @@ def _initialize_files(self, root: str) -> list[dict[str, str]]: mask_root = os.path.join(root, split, 'targets') images = glob.glob(os.path.join(image_root, '*.png')) - # Extract basenames while preserving the event-name and sample number + # Extract basenames while preserving the disaster-name and sample number for img in images: basename_parts = os.path.basename(img).split('_') event_name = basename_parts[0] # e.g., mexico-earthquake @@ 
-470,7 +472,7 @@ def _load_split_files_by_disaster_and_type( ) train_files.append(dict(image=image, mask=mask)) - # Filter for out-of-domain (OOD) test set + # Filter for out-of-distribution (OOD) test set if disaster_name == ood_disaster['disaster_name']: if ( ood_disaster.get('pre-post') == 'both' From a74c99fa2d211e07c647124ae54679a88ffaafc9 Mon Sep 17 00:00:00 2001 From: burakekim Date: Sat, 1 Feb 2025 19:49:35 +0000 Subject: [PATCH 08/28] id ood length method and polishing --- torchgeo/datasets/xview.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index 12a30d16991..f6ba2092792 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -362,10 +362,10 @@ def __init__( self.files = self._load_split_files_by_disaster_and_type( self.all_files, id_ood_disaster[0], id_ood_disaster[1] ) - print( - f"Loaded for disasters ID and OOD: {len(self.files['train'])} train, {len(self.files['test'])} test files." - ) + train_size, test_size = self.get_id_ood_sizes() + print(f"ID sample len: {train_size}, OOD sample len: {test_size}") + def __getitem__(self, index: int) -> dict[str, torch.Tensor]: """Get an item from the dataset at the given index.""" file_info = ( @@ -396,6 +396,12 @@ def __len__(self) -> int: else len(self.files['test']) ) + + def get_id_ood_sizes(self) -> tuple[int, int]: + """Return the number of samples in the train and test splits.""" + return (len(self.files['train']), len(self.files['test'])) + + def _initialize_files(self, root: str) -> list[dict[str, str]]: """Initialize the dataset by loading file paths and computing basenames with sample numbers.""" all_files = [] @@ -454,7 +460,7 @@ def _load_split_files_by_disaster_and_type( disaster_list.append(disaster_name) - # Filter for in-domain (ID) training set + # Filter for in-distribution (ID) training set if disaster_name == id_disaster['disaster_name']: if ( id_disaster.get('pre-post') == 'both' From b66c7923619fcb08df511530d44e27d5e0f0813f Mon Sep 17 00:00:00 2001 From: burakekim Date: Fri, 18 Apr 2025 20:40:40 +0000 Subject: [PATCH 09/28] idk what this is --- tests/datasets/test_xview.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/datasets/test_xview.py b/tests/datasets/test_xview.py index e873cd567d5..e5d8b6634e5 100644 --- a/tests/datasets/test_xview.py +++ b/tests/datasets/test_xview.py @@ -12,7 +12,8 @@ from _pytest.fixtures import SubRequest from pytest import MonkeyPatch -from torchgeo.datasets import DatasetNotFoundError, XView2, XView2DistShift +from torchgeo.datasets import DatasetNotFoundError, XView2 + class TestXView2: @pytest.fixture(params=['train', 'test']) From 5c2a6e67c3f037363318106f3d8a3c2bba0c0ec1 Mon Sep 17 00:00:00 2001 From: burakekim Date: Fri, 18 Apr 2025 20:48:53 +0000 Subject: [PATCH 10/28] ruff fixes --- torchgeo/datasets/xview.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index f6ba2092792..71071ea2368 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -295,10 +295,11 @@ class XView2DistShift(XView2): on others. 
""" - classes = ['background', 'building'] - # List of disaster names - valid_disasters = [ + classes: ClassVar[list[str]] = ['background', 'building'] + + + valid_disasters: ClassVar[list[str]] = [ 'hurricane-harvey', 'socal-fire', 'hurricane-matthew', @@ -319,7 +320,7 @@ def __init__( {'disaster_name': 'hurricane-matthew', 'pre-post': 'post'}, {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'}, ], - transforms: Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] = None, + transforms: Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] | None = None, checksum: bool = False, ) -> None: """Initialize the XView2DistShift dataset instance. From ae33510690ef440bdb023ae147fc5b25f69aee19 Mon Sep 17 00:00:00 2001 From: burakekim Date: Fri, 18 Apr 2025 21:13:27 +0000 Subject: [PATCH 11/28] make mypy happy --- torchgeo/datasets/xview.py | 73 +++++++++++--------------------------- 1 file changed, 20 insertions(+), 53 deletions(-) diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index 71071ea2368..0eef89e161d 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -295,9 +295,7 @@ class XView2DistShift(XView2): on others. """ - - classes: ClassVar[list[str]] = ['background', 'building'] - + binary_classes: ClassVar[tuple[str, str]] = ('background', 'building') valid_disasters: ClassVar[list[str]] = [ 'hurricane-harvey', @@ -339,7 +337,6 @@ def __init__( DatasetNotFoundError: If dataset is not found. """ assert split in ['train', 'test'], "Split must be either 'train' or 'test'." - # Validate that the disasters are valid if ( id_ood_disaster[0]['disaster_name'] not in self.valid_disasters @@ -360,30 +357,29 @@ def __init__( self.all_files = self._initialize_files(root) # Split logic by disaster and pre-post type - self.files = self._load_split_files_by_disaster_and_type( + self.split_files: dict[str, list[dict[str, str]]] = self._load_split_files_by_disaster_and_type( self.all_files, id_ood_disaster[0], id_ood_disaster[1] ) train_size, test_size = self.get_id_ood_sizes() print(f"ID sample len: {train_size}, OOD sample len: {test_size}") - + def __getitem__(self, index: int) -> dict[str, torch.Tensor]: """Get an item from the dataset at the given index.""" file_info = ( - self.files['train'][index] + self.split_files['train'][index] if self.split == 'train' - else self.files['test'][index] + else self.split_files['test'][index] ) image = self._load_image(file_info['image']) mask = self._load_target(file_info['mask']).long() - # Reformulate as building segmentation task - mask[mask == 2] = 1 # Map damage class 2 to 1 - mask[(mask == 3) | (mask == 4)] = 0 # Map 3 and 4 damage classes to background + # Reformulate as building segmentation task + mask[mask == 2] = 1 # minor-damage → building + mask[(mask == 3) | (mask == 4)] = 0 # major/destroyed → background sample = {'image': image, 'mask': mask} - if self.transforms: sample = self.transforms(sample) @@ -392,16 +388,14 @@ def __getitem__(self, index: int) -> dict[str, torch.Tensor]: def __len__(self) -> int: """Return the total number of samples in the dataset.""" return ( - len(self.files['train']) + len(self.split_files['train']) if self.split == 'train' - else len(self.files['test']) + else len(self.split_files['test']) ) - def get_id_ood_sizes(self) -> tuple[int, int]: """Return the number of samples in the train and test splits.""" - return (len(self.files['train']), len(self.files['test'])) - + return (len(self.split_files['train']), len(self.split_files['test'])) def 
_initialize_files(self, root: str) -> list[dict[str, str]]: """Initialize the dataset by loading file paths and computing basenames with sample numbers.""" @@ -416,15 +410,11 @@ def _initialize_files(self, root: str) -> list[dict[str, str]]: basename_parts = os.path.basename(img).split('_') event_name = basename_parts[0] # e.g., mexico-earthquake sample_number = basename_parts[1] # e.g., 00000001 - basename = ( - f'{event_name}_{sample_number}' # e.g., mexico-earthquake_00000001 - ) + basename = f'{event_name}_{sample_number}' file_info = { 'image': img, - 'mask': os.path.join( - mask_root, f'{basename}_pre_disaster_target.png' - ), + 'mask': os.path.join(mask_root, f'{basename}_pre_disaster_target.png'), 'basename': basename, } all_files.append(file_info) @@ -448,43 +438,20 @@ def _load_split_files_by_disaster_and_type( """ train_files = [] test_files = [] - disaster_list = [] for file_info in files: basename = file_info['basename'] - disaster_name = basename.split('_')[ - 0 - ] # Extract disaster name from basename - pre_post = ( - 'pre' if 'pre_disaster' in file_info['image'] else 'post' - ) # Identify pre/post type - - disaster_list.append(disaster_name) + disaster_name = basename.split('_')[0] + pre_post = 'pre' if 'pre_disaster' in file_info['image'] else 'post' - # Filter for in-distribution (ID) training set if disaster_name == id_disaster['disaster_name']: - if ( - id_disaster.get('pre-post') == 'both' - or id_disaster['pre-post'] == pre_post - ): - image = ( - file_info['image'].replace('post_disaster', 'pre_disaster') - if pre_post == 'pre' - else file_info['image'] - ) - mask = ( - file_info['mask'].replace('post_disaster', 'pre_disaster') - if pre_post == 'pre' - else file_info['mask'] - ) + if id_disaster.get('pre-post') == 'both' or id_disaster['pre-post'] == pre_post: + image = file_info['image'].replace('post_disaster', 'pre_disaster') if pre_post == 'pre' else file_info['image'] + mask = file_info['mask'].replace('post_disaster', 'pre_disaster') if pre_post == 'pre' else file_info['mask'] train_files.append(dict(image=image, mask=mask)) - # Filter for out-of-distribution (OOD) test set if disaster_name == ood_disaster['disaster_name']: - if ( - ood_disaster.get('pre-post') == 'both' - or ood_disaster['pre-post'] == pre_post - ): + if ood_disaster.get('pre-post') == 'both' or ood_disaster['pre-post'] == pre_post: test_files.append(file_info) - return {'train': train_files, 'test': test_files, 'disasters': disaster_list} + return {'train': train_files, 'test': test_files} From 7d2dbbb438447df6c10ff5c4a38ae02826b49baf Mon Sep 17 00:00:00 2001 From: burakekim Date: Sat, 19 Apr 2025 19:27:41 +0000 Subject: [PATCH 12/28] make ruff happy --- tests/datasets/test_xview.py | 1 - torchgeo/datasets/xview.py | 39 +++++++++++++++++++++++++++--------- 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/tests/datasets/test_xview.py b/tests/datasets/test_xview.py index e5d8b6634e5..c54b597fadf 100644 --- a/tests/datasets/test_xview.py +++ b/tests/datasets/test_xview.py @@ -27,7 +27,6 @@ def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> XView2: 'md5': '373e61d55c1b294aa76b94dbbd81332b', 'directory': 'train', }, - 'test': { 'filename': 'test_images_labels_targets.tar.gz', 'md5': 'bc6de81c956a3bada38b5b4e246266a1', diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index 0eef89e161d..c059da6ebbd 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -318,7 +318,8 @@ def __init__( {'disaster_name': 'hurricane-matthew', 
'pre-post': 'post'}, {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'}, ], - transforms: Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] | None = None, + transforms: Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] + | None = None, checksum: bool = False, ) -> None: """Initialize the XView2DistShift dataset instance. @@ -343,7 +344,7 @@ def __init__( or id_ood_disaster[1]['disaster_name'] not in self.valid_disasters ): raise ValueError( - f"Invalid disaster names. Valid options are: {', '.join(self.valid_disasters)}" + f'Invalid disaster names. Valid options are: {", ".join(self.valid_disasters)}' ) self.root = root @@ -357,12 +358,14 @@ def __init__( self.all_files = self._initialize_files(root) # Split logic by disaster and pre-post type - self.split_files: dict[str, list[dict[str, str]]] = self._load_split_files_by_disaster_and_type( - self.all_files, id_ood_disaster[0], id_ood_disaster[1] + self.split_files: dict[str, list[dict[str, str]]] = ( + self._load_split_files_by_disaster_and_type( + self.all_files, id_ood_disaster[0], id_ood_disaster[1] + ) ) train_size, test_size = self.get_id_ood_sizes() - print(f"ID sample len: {train_size}, OOD sample len: {test_size}") + print(f'ID sample len: {train_size}, OOD sample len: {test_size}') def __getitem__(self, index: int) -> dict[str, torch.Tensor]: """Get an item from the dataset at the given index.""" @@ -414,7 +417,9 @@ def _initialize_files(self, root: str) -> list[dict[str, str]]: file_info = { 'image': img, - 'mask': os.path.join(mask_root, f'{basename}_pre_disaster_target.png'), + 'mask': os.path.join( + mask_root, f'{basename}_pre_disaster_target.png' + ), 'basename': basename, } all_files.append(file_info) @@ -445,13 +450,27 @@ def _load_split_files_by_disaster_and_type( pre_post = 'pre' if 'pre_disaster' in file_info['image'] else 'post' if disaster_name == id_disaster['disaster_name']: - if id_disaster.get('pre-post') == 'both' or id_disaster['pre-post'] == pre_post: - image = file_info['image'].replace('post_disaster', 'pre_disaster') if pre_post == 'pre' else file_info['image'] - mask = file_info['mask'].replace('post_disaster', 'pre_disaster') if pre_post == 'pre' else file_info['mask'] + if ( + id_disaster.get('pre-post') == 'both' + or id_disaster['pre-post'] == pre_post + ): + image = ( + file_info['image'].replace('post_disaster', 'pre_disaster') + if pre_post == 'pre' + else file_info['image'] + ) + mask = ( + file_info['mask'].replace('post_disaster', 'pre_disaster') + if pre_post == 'pre' + else file_info['mask'] + ) train_files.append(dict(image=image, mask=mask)) if disaster_name == ood_disaster['disaster_name']: - if ood_disaster.get('pre-post') == 'both' or ood_disaster['pre-post'] == pre_post: + if ( + ood_disaster.get('pre-post') == 'both' + or ood_disaster['pre-post'] == pre_post + ): test_files.append(file_info) return {'train': train_files, 'test': test_files} From f011b44156a69d0d015e146952e5459c8c5c10b3 Mon Sep 17 00:00:00 2001 From: burakekim Date: Sat, 19 Apr 2025 19:57:29 +0000 Subject: [PATCH 13/28] add tests --- tests/datasets/test_xview.py | 68 +++++++++++++++++++++++++++++++++++- 1 file changed, 67 insertions(+), 1 deletion(-) diff --git a/tests/datasets/test_xview.py b/tests/datasets/test_xview.py index c54b597fadf..48696e172ca 100644 --- a/tests/datasets/test_xview.py +++ b/tests/datasets/test_xview.py @@ -12,7 +12,7 @@ from _pytest.fixtures import SubRequest from pytest import MonkeyPatch -from torchgeo.datasets import DatasetNotFoundError, XView2 +from 
torchgeo.datasets import DatasetNotFoundError, XView2, XView2DistShift class TestXView2: @@ -92,3 +92,69 @@ def test_plot(self, dataset: XView2) -> None: x['prediction'] = x['mask'][0].clone() dataset.plot(x) plt.close() + + +class TestXView2DistShift: + @pytest.fixture(params=['train', 'test']) + def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> XView2DistShift: + monkeypatch.setattr( + XView2DistShift, + 'metadata', + { + 'train': { + 'filename': 'train_images_labels_targets.tar.gz', + 'md5': '373e61d55c1b294aa76b94dbbd81332b', + 'directory': 'train', + }, + 'test': { + 'filename': 'test_images_labels_targets.tar.gz', + 'md5': 'bc6de81c956a3bada38b5b4e246266a1', + 'directory': 'test', + }, + }, + ) + root = os.path.join('tests', 'data', 'xview2') + split = request.param + transforms = nn.Identity() + + return XView2DistShift( + root=root, + split=split, + id_ood_disaster=[ + {"disaster_name": "hurricane-matthew", "pre-post": "post"}, + {"disaster_name": "mexico-earthquake", "pre-post": "post"}, + ], + transforms=transforms, + checksum=True, + ) + + def test_getitem(self, dataset: XView2DistShift) -> None: + x = dataset[0] + assert isinstance(x, dict) + assert isinstance(x['image'], torch.Tensor) + assert isinstance(x['mask'], torch.Tensor) + assert set(torch.unique(x['mask']).tolist()).issubset({0, 1}) # binary mask + + def test_len(self, dataset: XView2DistShift) -> None: + assert len(dataset) > 0 + + def test_invalid_disaster(self) -> None: + with pytest.raises(ValueError, match='Invalid disaster names'): + XView2DistShift( + root='tests/data/xview2', + id_ood_disaster=[ + {'disaster_name': 'not-a-real-one', 'pre-post': 'post'}, + {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'}, + ], + ) + + def test_invalid_split(self) -> None: + with pytest.raises(AssertionError): + XView2DistShift( + root='tests/data/xview2', + split='bad', + id_ood_disaster=[ + {'disaster_name': 'hurricane-matthew', 'pre-post': 'post'}, + {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'}, + ], + ) \ No newline at end of file From 4bc5510ff98e982e684e4fd1bca805ab686a22a7 Mon Sep 17 00:00:00 2001 From: burakekim Date: Sat, 19 Apr 2025 20:41:22 +0000 Subject: [PATCH 14/28] more test --- tests/datasets/test_xview.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/datasets/test_xview.py b/tests/datasets/test_xview.py index 48696e172ca..c2c480a3e2b 100644 --- a/tests/datasets/test_xview.py +++ b/tests/datasets/test_xview.py @@ -121,8 +121,8 @@ def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> XView2DistSh root=root, split=split, id_ood_disaster=[ - {"disaster_name": "hurricane-matthew", "pre-post": "post"}, - {"disaster_name": "mexico-earthquake", "pre-post": "post"}, + {'disaster_name': 'hurricane-harvey', 'pre-post': 'post'}, + {'disaster_name': 'hurricane-harvey', 'pre-post': 'post'}, ], transforms=transforms, checksum=True, @@ -144,7 +144,7 @@ def test_invalid_disaster(self) -> None: root='tests/data/xview2', id_ood_disaster=[ {'disaster_name': 'not-a-real-one', 'pre-post': 'post'}, - {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'}, + {'disaster_name': 'hurricane-harvey', 'pre-post': 'post'}, ], ) @@ -157,4 +157,4 @@ def test_invalid_split(self) -> None: {'disaster_name': 'hurricane-matthew', 'pre-post': 'post'}, {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'}, ], - ) \ No newline at end of file + ) From 4cc0ac08b9e5c08f536206885a3d81998cd220e0 Mon Sep 17 00:00:00 2001 From: Burak 
<68427259+burakekim@users.noreply.github.com> Date: Sun, 20 Apr 2025 14:18:42 +0100 Subject: [PATCH 15/28] Update torchgeo/datasets/xview.py Adam for the win Co-authored-by: Adam J. Stewart --- torchgeo/datasets/xview.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index c059da6ebbd..721efe5a7fe 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -313,7 +313,7 @@ class XView2DistShift(XView2): def __init__( self, root: str = 'data', - split: str = 'train', + split: Literal['train', 'test'] = 'train', id_ood_disaster: list[dict[str, str]] = [ {'disaster_name': 'hurricane-matthew', 'pre-post': 'post'}, {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'}, From 6d4ac9e2c5b1c6cf28968dac5c16d564a3eb9bae Mon Sep 17 00:00:00 2001 From: Burak <68427259+burakekim@users.noreply.github.com> Date: Sun, 20 Apr 2025 14:19:04 +0100 Subject: [PATCH 16/28] Update torchgeo/datasets/xview.py Co-authored-by: Adam J. Stewart --- torchgeo/datasets/xview.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index 721efe5a7fe..6baeefc22a7 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -318,7 +318,7 @@ def __init__( {'disaster_name': 'hurricane-matthew', 'pre-post': 'post'}, {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'}, ], - transforms: Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = False, ) -> None: From 6e32b6be563065177dbaf79edd24d8e8e83893db Mon Sep 17 00:00:00 2001 From: Burak <68427259+burakekim@users.noreply.github.com> Date: Sun, 20 Apr 2025 14:19:51 +0100 Subject: [PATCH 17/28] Update torchgeo/datasets/xview.py Co-authored-by: Adam J. Stewart --- torchgeo/datasets/xview.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index 6baeefc22a7..88ee0aa4412 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -334,7 +334,7 @@ def __init__( Raises: AssertionError: If *split* is invalid. - ValueError: If a disaster name in `id_ood_disaster` is not one of the valid disasters. + ValueError: If a disaster name in *id_ood_disaster* is not one of the valid disasters. DatasetNotFoundError: If dataset is not found. """ assert split in ['train', 'test'], "Split must be either 'train' or 'test'." From dd5e626c38efcf108f887747cbf8d1b75d8cd1db Mon Sep 17 00:00:00 2001 From: burakekim Date: Sun, 20 Apr 2025 13:44:12 +0000 Subject: [PATCH 18/28] post review changes --- torchgeo/datasets/xview.py | 32 +++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index 88ee0aa4412..ff6af569082 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -6,7 +6,7 @@ import glob import os from collections.abc import Callable -from typing import ClassVar +from typing import ClassVar, Literal import matplotlib.pyplot as plt import numpy as np @@ -293,6 +293,8 @@ class XView2DistShift(XView2): on one disaster type and evaluate on a different, out-of-domain disaster. The goal is to test the generalization ability of models trained on one disaster to perform on others. + + .. 
versionadded:: 0.8 """ binary_classes: ClassVar[tuple[str, str]] = ('background', 'building') @@ -318,8 +320,7 @@ def __init__( {'disaster_name': 'hurricane-matthew', 'pre-post': 'post'}, {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'}, ], - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] - | None = None, + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = False, ) -> None: """Initialize the XView2DistShift dataset instance. @@ -368,7 +369,14 @@ def __init__( print(f'ID sample len: {train_size}, OOD sample len: {test_size}') def __getitem__(self, index: int) -> dict[str, torch.Tensor]: - """Get an item from the dataset at the given index.""" + """Return an index within the dataset. + + Args: + index: index to return + + Returns: + data and label at that index + """ file_info = ( self.split_files['train'][index] if self.split == 'train' @@ -389,7 +397,11 @@ def __getitem__(self, index: int) -> dict[str, torch.Tensor]: return sample def __len__(self) -> int: - """Return the total number of samples in the dataset.""" + """Return the number of data points in the dataset. + + Returns: + length of the dataset + """ return ( len(self.split_files['train']) if self.split == 'train' @@ -406,7 +418,7 @@ def _initialize_files(self, root: str) -> list[dict[str, str]]: for split in self.metadata.keys(): image_root = os.path.join(root, split, 'images') mask_root = os.path.join(root, split, 'targets') - images = glob.glob(os.path.join(image_root, '*.png')) + images = sorted(glob.glob(os.path.join(image_root, '*.png'))) # Extract basenames while preserving the disaster-name and sample number for img in images: @@ -431,12 +443,14 @@ def _load_split_files_by_disaster_and_type( id_disaster: dict[str, str], ood_disaster: dict[str, str], ) -> dict[str, list[dict[str, str]]]: - """Return the filepaths for the train (ID) and test (OOD) sets based on disaster name and pre-post disaster type. + """Return train and test filepaths based on disaster name and pre/post type. Args: files: List of file paths with their corresponding information. - id_disaster: Dictionary specifying in-domain (ID) disaster and type (e.g., {'disaster_name': 'guatemala-volcano', 'pre-post': 'pre'}). - ood_disaster: Dictionary specifying out-of-domain (OOD) disaster and type (e.g., {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'}). + id_disaster: Dictionary specifying in-domain (ID) disaster and type, + e.g., {'disaster_name': 'guatemala-volcano', 'pre-post': 'pre'}. + ood_disaster: Dictionary specifying out-of-domain (OOD) disaster and type, + e.g., {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'}. Returns: A dictionary containing 'train' (ID) and 'test' (OOD) file lists. 
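[Note: to make the ID/OOD mechanics settled in this patch concrete, here is a minimal usage sketch of XView2DistShift as it stands at this point in the series. The root path and the disaster pairing are illustrative assumptions (any two names from valid_disasters work), and the xBD archives are assumed to already be extracted under root; none of this is part of the diff itself.]

    from torchgeo.datasets import XView2DistShift

    # Train on one disaster (in-distribution), evaluate on another (out-of-distribution).
    pairing = [
        {'disaster_name': 'hurricane-matthew', 'pre-post': 'post'},
        {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'},
    ]
    train_ds = XView2DistShift(root='data/xview2', split='train', id_ood_disaster=pairing)
    test_ds = XView2DistShift(root='data/xview2', split='test', id_ood_disaster=pairing)

    sample = train_ds[0]            # dict with 'image' and 'mask' tensors
    print(sample['image'].shape)    # RGB image for one tile
    print(sample['mask'].unique())  # tensor([0, 1]): background vs. building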
From b73de3c9472c26fc7d5d2fb246a823ad786378a9 Mon Sep 17 00:00:00 2001 From: burakekim Date: Sun, 20 Apr 2025 13:53:24 +0000 Subject: [PATCH 19/28] smol fixes --- .gitignore | 3 +-- torchgeo/datasets/xview.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 8357c23b3f8..65c5afb46c2 100644 --- a/.gitignore +++ b/.gitignore @@ -147,5 +147,4 @@ venv.bak/ dmypy.json # Pyre type checker -.pyre/ -xbdood.ipynb +.pyre/ \ No newline at end of file diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index ff6af569082..675e557a016 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -320,7 +320,7 @@ def __init__( {'disaster_name': 'hurricane-matthew', 'pre-post': 'post'}, {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'}, ], - transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, + transforms: Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] | None = None, checksum: bool = False, ) -> None: """Initialize the XView2DistShift dataset instance. From 0150c32428213364c16a7e8948a94a5bf9727c02 Mon Sep 17 00:00:00 2001 From: burakekim Date: Sun, 20 Apr 2025 14:04:02 +0000 Subject: [PATCH 20/28] ruff again --- torchgeo/datasets/xview.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py index 675e557a016..c8802e7d7ab 100644 --- a/torchgeo/datasets/xview.py +++ b/torchgeo/datasets/xview.py @@ -320,7 +320,8 @@ def __init__( {'disaster_name': 'hurricane-matthew', 'pre-post': 'post'}, {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'}, ], - transforms: Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] | None = None, + transforms: Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] + | None = None, checksum: bool = False, ) -> None: """Initialize the XView2DistShift dataset instance. 
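[Note: the label remapping that __getitem__ applies is easy to sanity-check in isolation. A small sketch, assuming the standard five xBD classes (0 = background, 1 = no damage, 2 = minor damage, 3 = major damage, 4 = destroyed); collapsing the damage grades this way is what reduces the task to the two binary_classes declared above.]

    import torch

    mask = torch.tensor([0, 1, 2, 3, 4])
    mask[mask == 2] = 1                  # minor damage still counts as 'building'
    mask[(mask == 3) | (mask == 4)] = 0  # major damage / destroyed -> background
    print(mask)                          # tensor([0, 1, 1, 0, 0])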
From b7296e7ea6c1e2a21eb373bbc0c494cee792dd88 Mon Sep 17 00:00:00 2001 From: burakekim Date: Sun, 20 Apr 2025 14:21:31 +0000 Subject: [PATCH 21/28] mypy fixes --- tests/datasets/test_xview.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/datasets/test_xview.py b/tests/datasets/test_xview.py index c2c480a3e2b..a094b7dd593 100644 --- a/tests/datasets/test_xview.py +++ b/tests/datasets/test_xview.py @@ -16,7 +16,7 @@ class TestXView2: - @pytest.fixture(params=['train', 'test']) + @pytest.fixture(params=['train', 'test']) # type: ignore[misc] def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> XView2: monkeypatch.setattr( XView2, @@ -95,7 +95,7 @@ def test_plot(self, dataset: XView2) -> None: class TestXView2DistShift: - @pytest.fixture(params=['train', 'test']) + @pytest.fixture(params=['train', 'test']) # type: ignore[misc] def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> XView2DistShift: monkeypatch.setattr( XView2DistShift, @@ -152,7 +152,7 @@ def test_invalid_split(self) -> None: with pytest.raises(AssertionError): XView2DistShift( root='tests/data/xview2', - split='bad', + split='bad', # type: ignore[arg-type] id_ood_disaster=[ {'disaster_name': 'hurricane-matthew', 'pre-post': 'post'}, {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'}, From 83c82cfc1795a8a5740ec759688ad940d10f6c24 Mon Sep 17 00:00:00 2001 From: burakekim Date: Sun, 20 Apr 2025 20:13:54 +0000 Subject: [PATCH 22/28] mypy fix --- tests/datasets/test_xview.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/datasets/test_xview.py b/tests/datasets/test_xview.py index a094b7dd593..10b7e0f68d3 100644 --- a/tests/datasets/test_xview.py +++ b/tests/datasets/test_xview.py @@ -16,7 +16,7 @@ class TestXView2: - @pytest.fixture(params=['train', 'test']) # type: ignore[misc] + @pytest.fixture(params=['train', 'test']) # type: ignore[misc] def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> XView2: monkeypatch.setattr( XView2, @@ -95,7 +95,7 @@ def test_plot(self, dataset: XView2) -> None: class TestXView2DistShift: - @pytest.fixture(params=['train', 'test']) # type: ignore[misc] + @pytest.fixture(params=['train', 'test']) # type: ignore[misc] def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> XView2DistShift: monkeypatch.setattr( XView2DistShift, From a273d5b9dde87ca65fe3c4e7aa6c0705d8b2dd0b Mon Sep 17 00:00:00 2001 From: burakekim Date: Sun, 20 Apr 2025 20:16:17 +0000 Subject: [PATCH 23/28] ruff reformat --- tests/datasets/test_xview.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/datasets/test_xview.py b/tests/datasets/test_xview.py index 10b7e0f68d3..a094b7dd593 100644 --- a/tests/datasets/test_xview.py +++ b/tests/datasets/test_xview.py @@ -16,7 +16,7 @@ class TestXView2: - @pytest.fixture(params=['train', 'test']) # type: ignore[misc] + @pytest.fixture(params=['train', 'test']) # type: ignore[misc] def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> XView2: monkeypatch.setattr( XView2, @@ -95,7 +95,7 @@ def test_plot(self, dataset: XView2) -> None: class TestXView2DistShift: - @pytest.fixture(params=['train', 'test']) # type: ignore[misc] + @pytest.fixture(params=['train', 'test']) # type: ignore[misc] def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> XView2DistShift: monkeypatch.setattr( XView2DistShift, From 3a812ce501661a1958d990f8a0415f115d75d55a Mon Sep 17 00:00:00 2001 From: burakekim Date: Sun, 20 Apr 2025 20:20:21 +0000 
Subject: [PATCH 24/28] remove failing mypy type-ignore comments

---
 tests/datasets/test_xview.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/datasets/test_xview.py b/tests/datasets/test_xview.py
index a094b7dd593..8e6c910ba65 100644
--- a/tests/datasets/test_xview.py
+++ b/tests/datasets/test_xview.py
@@ -16,7 +16,7 @@


 class TestXView2:
-    @pytest.fixture(params=['train', 'test'])  # type: ignore[misc]
+    @pytest.fixture(params=['train', 'test'])
     def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> XView2:
         monkeypatch.setattr(
             XView2,
@@ -95,7 +95,7 @@ def test_plot(self, dataset: XView2) -> None:


 class TestXView2DistShift:
-    @pytest.fixture(params=['train', 'test'])  # type: ignore[misc]
+    @pytest.fixture(params=['train', 'test'])
     def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> XView2DistShift:
         monkeypatch.setattr(
             XView2DistShift,

From b04065957f1394174930b476c7a0e50b7708103e Mon Sep 17 00:00:00 2001
From: burakekim
Date: Wed, 23 Apr 2025 20:52:22 +0000
Subject: [PATCH 25/28] address Caleb's review

---
 torchgeo/datasets/xview.py | 12 ++----------
 1 file changed, 2 insertions(+), 10 deletions(-)

diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py
index c8802e7d7ab..3995bd0fe9a 100644
--- a/torchgeo/datasets/xview.py
+++ b/torchgeo/datasets/xview.py
@@ -378,11 +378,7 @@ def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
         Returns:
             data and label at that index
         """
-        file_info = (
-            self.split_files['train'][index]
-            if self.split == 'train'
-            else self.split_files['test'][index]
-        )
+        file_info = self.split_files[self.split][index]

         image = self._load_image(file_info['image'])
         mask = self._load_target(file_info['mask']).long()
@@ -403,11 +399,7 @@ def __len__(self) -> int:
         Returns:
             length of the dataset
         """
-        return (
-            len(self.split_files['train'])
-            if self.split == 'train'
-            else len(self.split_files['test'])
-        )
+        return len(self.split_files[self.split])

     def get_id_ood_sizes(self) -> tuple[int, int]:
         """Return the number of samples in the train and test splits."""

From 16d62937f874c056850dee6969aef15c391babbd Mon Sep 17 00:00:00 2001
From: burakekim
Date: Tue, 6 May 2025 15:10:17 +0000
Subject: [PATCH 26/28] covering edge scenarios per Adam's comments

---
 .gitignore                 |  2 +-
 torchgeo/datasets/xview.py | 41 ++++++++++++++++++++++++--------------
 2 files changed, 27 insertions(+), 16 deletions(-)

diff --git a/.gitignore b/.gitignore
index 65c5afb46c2..180c27c47b2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -147,4 +147,4 @@ venv.bak/
 dmypy.json

 # Pyre type checker
-.pyre/
\ No newline at end of file
+.pyre/
diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py
index 3995bd0fe9a..0d6b3f5ce7d 100644
--- a/torchgeo/datasets/xview.py
+++ b/torchgeo/datasets/xview.py
@@ -5,7 +5,7 @@

 import glob
 import os
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
 from typing import ClassVar, Literal

 import matplotlib.pyplot as plt
 import numpy as np
@@ -297,9 +297,9 @@ class XView2DistShift(XView2):

     .. 
versionadded:: 0.8 """ - binary_classes: ClassVar[tuple[str, str]] = ('background', 'building') + binary_classes = ('background', 'building') - valid_disasters: ClassVar[list[str]] = [ + valid_disasters = ( 'hurricane-harvey', 'socal-fire', 'hurricane-matthew', @@ -310,18 +310,17 @@ class XView2DistShift(XView2): 'hurricane-florence', 'hurricane-michael', 'midwest-flooding', - ] + ) def __init__( self, root: str = 'data', split: Literal['train', 'test'] = 'train', - id_ood_disaster: list[dict[str, str]] = [ + id_ood_disaster: Sequence[dict[str, str]] = ( {'disaster_name': 'hurricane-matthew', 'pre-post': 'post'}, {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'}, - ], - transforms: Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] - | None = None, + ), + transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, checksum: bool = False, ) -> None: """Initialize the XView2DistShift dataset instance. @@ -340,14 +339,26 @@ def __init__( DatasetNotFoundError: If dataset is not found. """ assert split in ['train', 'test'], "Split must be either 'train' or 'test'." + assert len(id_ood_disaster) == 2, ( + 'id_ood_disaster must contain exactly two items' + ) - if ( - id_ood_disaster[0]['disaster_name'] not in self.valid_disasters - or id_ood_disaster[1]['disaster_name'] not in self.valid_disasters - ): - raise ValueError( - f'Invalid disaster names. Valid options are: {", ".join(self.valid_disasters)}' - ) + for disaster in id_ood_disaster: + if 'disaster_name' not in disaster: + raise ValueError( + "Each disaster entry must contain a 'disaster_name' key." + ) + if disaster['disaster_name'] not in self.valid_disasters: + raise ValueError( + f'Invalid disaster name: {disaster["disaster_name"]}. ' + f'Valid options are: {", ".join(self.valid_disasters)}' + ) + + for disaster in id_ood_disaster: + if 'disaster_name' not in disaster or 'pre-post' not in disaster: + raise ValueError( + "Each disaster entry must contain 'disaster_name' and 'pre-post' keys." + ) self.root = root self.split = split From 8bb52f604a76eb038f00d45ea9ee2ce13d3a05f0 Mon Sep 17 00:00:00 2001 From: burakekim Date: Tue, 6 May 2025 19:16:19 +0000 Subject: [PATCH 27/28] fix valueerror match + prettier --- tests/datasets/test_xview.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datasets/test_xview.py b/tests/datasets/test_xview.py index 8e6c910ba65..9c28c65bfbe 100644 --- a/tests/datasets/test_xview.py +++ b/tests/datasets/test_xview.py @@ -139,7 +139,7 @@ def test_len(self, dataset: XView2DistShift) -> None: assert len(dataset) > 0 def test_invalid_disaster(self) -> None: - with pytest.raises(ValueError, match='Invalid disaster names'): + with pytest.raises(ValueError, match='Invalid disaster name'): XView2DistShift( root='tests/data/xview2', id_ood_disaster=[ From c333a8e80458d36f111cf31bed8b7467e20135a5 Mon Sep 17 00:00:00 2001 From: burakekim Date: Tue, 6 May 2025 19:31:24 +0000 Subject: [PATCH 28/28] improved test cov --- tests/datasets/test_xview.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/datasets/test_xview.py b/tests/datasets/test_xview.py index 9c28c65bfbe..baa89ae6431 100644 --- a/tests/datasets/test_xview.py +++ b/tests/datasets/test_xview.py @@ -148,6 +148,31 @@ def test_invalid_disaster(self) -> None: ], ) + def test_missing_disaster_name_key(self) -> None: + with pytest.raises( + ValueError, match="Each disaster entry must contain a 'disaster_name' key." 
+ ): + XView2DistShift( + root='tests/data/xview2', + id_ood_disaster=[ + {'pre-post': 'post'}, # missing 'disaster_name' + {'disaster_name': 'hurricane-harvey', 'pre-post': 'post'}, + ], + ) + + def test_missing_pre_post_key(self) -> None: + with pytest.raises( + ValueError, + match="Each disaster entry must contain 'disaster_name' and 'pre-post' keys.", + ): + XView2DistShift( + root='tests/data/xview2', + id_ood_disaster=[ + {'disaster_name': 'hurricane-harvey'}, # missing 'pre-post' + {'disaster_name': 'hurricane-harvey', 'pre-post': 'post'}, + ], + ) + def test_invalid_split(self) -> None: with pytest.raises(AssertionError): XView2DistShift(