diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst
index 8a3f501ae82..bb6149bbc7b 100644
--- a/docs/api/datasets.rst
+++ b/docs/api/datasets.rst
@@ -559,6 +559,7 @@ xView2
 ^^^^^^
 
 .. autoclass:: XView2
+.. autoclass:: XView2DistShift
 
 ZueriCrop
 ^^^^^^^^^
diff --git a/tests/datasets/test_xview.py b/tests/datasets/test_xview.py
index c54b597fadf..baa89ae6431 100644
--- a/tests/datasets/test_xview.py
+++ b/tests/datasets/test_xview.py
@@ -12,7 +12,7 @@
 from _pytest.fixtures import SubRequest
 from pytest import MonkeyPatch
 
-from torchgeo.datasets import DatasetNotFoundError, XView2
+from torchgeo.datasets import DatasetNotFoundError, XView2, XView2DistShift
 
 
 class TestXView2:
@@ -92,3 +92,94 @@ def test_plot(self, dataset: XView2) -> None:
         x['prediction'] = x['mask'][0].clone()
         dataset.plot(x)
         plt.close()
+
+
+class TestXView2DistShift:
+    @pytest.fixture(params=['train', 'test'])
+    def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> XView2DistShift:
+        monkeypatch.setattr(
+            XView2DistShift,
+            'metadata',
+            {
+                'train': {
+                    'filename': 'train_images_labels_targets.tar.gz',
+                    'md5': '373e61d55c1b294aa76b94dbbd81332b',
+                    'directory': 'train',
+                },
+                'test': {
+                    'filename': 'test_images_labels_targets.tar.gz',
+                    'md5': 'bc6de81c956a3bada38b5b4e246266a1',
+                    'directory': 'test',
+                },
+            },
+        )
+        root = os.path.join('tests', 'data', 'xview2')
+        split = request.param
+        transforms = nn.Identity()
+
+        return XView2DistShift(
+            root=root,
+            split=split,
+            id_ood_disaster=[
+                {'disaster_name': 'hurricane-harvey', 'pre-post': 'post'},
+                {'disaster_name': 'hurricane-harvey', 'pre-post': 'post'},
+            ],
+            transforms=transforms,
+            checksum=True,
+        )
+
+    def test_getitem(self, dataset: XView2DistShift) -> None:
+        x = dataset[0]
+        assert isinstance(x, dict)
+        assert isinstance(x['image'], torch.Tensor)
+        assert isinstance(x['mask'], torch.Tensor)
+        assert set(torch.unique(x['mask']).tolist()).issubset({0, 1})  # binary mask
+
+    def test_len(self, dataset: XView2DistShift) -> None:
+        assert len(dataset) > 0
+
+    def test_invalid_disaster(self) -> None:
+        with pytest.raises(ValueError, match='Invalid disaster name'):
+            XView2DistShift(
+                root=os.path.join('tests', 'data', 'xview2'),
+                id_ood_disaster=[
+                    {'disaster_name': 'not-a-real-one', 'pre-post': 'post'},
+                    {'disaster_name': 'hurricane-harvey', 'pre-post': 'post'},
+                ],
+            )
+
+    def test_missing_disaster_name_key(self) -> None:
+        with pytest.raises(
+            ValueError, match="Each disaster entry must contain a 'disaster_name' key."
+        ):
+            XView2DistShift(
+                root=os.path.join('tests', 'data', 'xview2'),
+                id_ood_disaster=[
+                    {'pre-post': 'post'},  # missing 'disaster_name'
+                    {'disaster_name': 'hurricane-harvey', 'pre-post': 'post'},
+                ],
+            )
+
+    def test_missing_pre_post_key(self) -> None:
+        with pytest.raises(
+            ValueError,
+            match="Each disaster entry must contain 'disaster_name' and 'pre-post' keys.",
+        ):
+            XView2DistShift(
+                root=os.path.join('tests', 'data', 'xview2'),
+                id_ood_disaster=[
+                    {'disaster_name': 'hurricane-harvey'},  # missing 'pre-post'
+                    {'disaster_name': 'hurricane-harvey', 'pre-post': 'post'},
+                ],
+            )
+
+    def test_invalid_split(self) -> None:
+        with pytest.raises(AssertionError):
+            XView2DistShift(
+                root=os.path.join('tests', 'data', 'xview2'),
+                split='bad',  # type: ignore[arg-type]
+                id_ood_disaster=[
+                    {'disaster_name': 'hurricane-matthew', 'pre-post': 'post'},
+                    {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'},
+                ],
+            )
diff --git a/torchgeo/datasets/__init__.py b/torchgeo/datasets/__init__.py
index b1cf39fbfd0..5cbfbb0a18f 100644
--- a/torchgeo/datasets/__init__.py
+++ b/torchgeo/datasets/__init__.py
@@ -176,7 +176,7 @@
 from .vaihingen import Vaihingen2D
 from .vhr10 import VHR10
 from .western_usa_live_fuel_moisture import WesternUSALiveFuelMoisture
-from .xview import XView2
+from .xview import XView2, XView2DistShift
 from .zuericrop import ZueriCrop
 
 __all__ = (
@@ -347,6 +347,7 @@
     'VectorDataset',
     'WesternUSALiveFuelMoisture',
     'XView2',
+    'XView2DistShift',
     'ZueriCrop',
     'concat_samples',
     'merge_samples',
diff --git a/torchgeo/datasets/xview.py b/torchgeo/datasets/xview.py
index a7f6a36456a..0d6b3f5ce7d 100644
--- a/torchgeo/datasets/xview.py
+++ b/torchgeo/datasets/xview.py
@@ -5,8 +5,8 @@
 
 import glob
 import os
-from collections.abc import Callable
-from typing import ClassVar
+from collections.abc import Callable, Sequence
+from typing import ClassVar, Literal
 
 import matplotlib.pyplot as plt
 import numpy as np
@@ -138,7 +138,7 @@ def _load_files(self, root: Path, split: str) -> list[dict[str, str]]:
 
         Args:
             root: root dir of dataset
-            split: subset of dataset, one of [train, test]
+            split: subset of dataset, one of ['train', 'test']
 
         Returns:
             list of dicts containing paths for each pair of images and masks
@@ -282,3 +282,198 @@
         plt.suptitle(suptitle)
 
         return fig
+
+
+class XView2DistShift(XView2):
+    """XView2 variant with train/test splits redefined by disaster.
+
+    This variant lets the user choose which disaster forms the training set
+    (in-distribution, ID) and which forms the test set (out-of-distribution,
+    OOD): the 'train' split serves ID samples and the 'test' split serves OOD
+    samples. Training on one disaster and evaluating on another measures how
+    well a model generalizes across disaster types. Masks are collapsed to a
+    binary building/background segmentation task.
+
+    .. versionadded:: 0.8
+    """
+
+    binary_classes = ('background', 'building')
+
+    valid_disasters = (
+        'hurricane-harvey',
+        'socal-fire',
+        'hurricane-matthew',
+        'mexico-earthquake',
+        'guatemala-volcano',
+        'santa-rosa-wildfire',
+        'palu-tsunami',
+        'hurricane-florence',
+        'hurricane-michael',
+        'midwest-flooding',
+    )
+
+    def __init__(
+        self,
+        root: str = 'data',
+        split: Literal['train', 'test'] = 'train',
+        id_ood_disaster: Sequence[dict[str, str]] = (
+            {'disaster_name': 'hurricane-matthew', 'pre-post': 'post'},
+            {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'},
+        ),
+        transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
+        checksum: bool = False,
+    ) -> None:
+        """Initialize the XView2DistShift dataset instance.
+
+        Args:
+            root: Root directory where the dataset is located.
+            split: One of 'train' or 'test'.
+            id_ood_disaster: Sequence of two dicts selecting the in-distribution
+                and out-of-distribution disasters. Each dict requires a
+                'disaster_name' key (one of *valid_disasters*) and a 'pre-post'
+                key ('pre', 'post', or 'both').
+            transforms: a function/transform that takes input sample and its target as
+                entry and returns a transformed version
+            checksum: if True, check the MD5 of the downloaded files (may be slow)
+
+        Raises:
+            AssertionError: If *split* is invalid.
+            ValueError: If a disaster entry is missing a required key or names an
+                unknown disaster.
+            DatasetNotFoundError: If dataset is not found.
+        """
+        assert split in ('train', 'test'), "Split must be either 'train' or 'test'."
+        assert len(id_ood_disaster) == 2, (
+            'id_ood_disaster must contain exactly two items'
+        )
+
+        for disaster in id_ood_disaster:
+            if 'disaster_name' not in disaster:
+                raise ValueError(
+                    "Each disaster entry must contain a 'disaster_name' key."
+                )
+            if 'pre-post' not in disaster:
+                raise ValueError(
+                    "Each disaster entry must contain 'disaster_name' and 'pre-post' keys."
+                )
+            if disaster['disaster_name'] not in self.valid_disasters:
+                raise ValueError(
+                    f'Invalid disaster name: {disaster["disaster_name"]}. '
+                    f'Valid options are: {", ".join(self.valid_disasters)}'
+                )
+
+        self.root = root
+        self.split = split
+        self.transforms = transforms
+        self.checksum = checksum
+
+        self._verify()
+
+        # Load all files and compute basenames and disasters only once
+        self.all_files = self._initialize_files(root)
+
+        # Split files into train (ID) and test (OOD) by disaster and pre/post type
+        self.split_files: dict[str, list[dict[str, str]]] = (
+            self._load_split_files_by_disaster_and_type(
+                self.all_files, id_ood_disaster[0], id_ood_disaster[1]
+            )
+        )
+
+    def __getitem__(self, index: int) -> dict[str, Tensor]:
+        """Return an index within the dataset.
+
+        Args:
+            index: index to return
+
+        Returns:
+            data and label at that index
+        """
+        file_info = self.split_files[self.split][index]
+
+        image = self._load_image(file_info['image'])
+        mask = self._load_target(file_info['mask']).long()
+
+        # Reformulate as a binary building segmentation task
+        mask[mask == 2] = 1  # minor-damage -> building
+        mask[(mask == 3) | (mask == 4)] = 0  # major-damage/destroyed -> background
+
+        sample = {'image': image, 'mask': mask}
+        if self.transforms:
+            sample = self.transforms(sample)
+
+        return sample
+
+    def __len__(self) -> int:
+        """Return the number of data points in the dataset.
+
+        Returns:
+            length of the dataset
+        """
+        return len(self.split_files[self.split])
+
+    def get_id_ood_sizes(self) -> tuple[int, int]:
+        """Return the number of samples in the train (ID) and test (OOD) splits."""
+        return (len(self.split_files['train']), len(self.split_files['test']))
+
+    def _initialize_files(self, root: str) -> list[dict[str, str]]:
+        """Load all image paths and pair each image with its matching target."""
+        all_files = []
+        for split in self.metadata:
+            image_root = os.path.join(root, split, 'images')
+            mask_root = os.path.join(root, split, 'targets')
+            images = sorted(glob.glob(os.path.join(image_root, '*.png')))
+
+            # Extract basenames while preserving the disaster name and sample number
+            for img in images:
+                basename_parts = os.path.basename(img).split('_')
+                event_name = basename_parts[0]  # e.g., mexico-earthquake
+                sample_number = basename_parts[1]  # e.g., 00000001
+                basename = f'{event_name}_{sample_number}'
+
+                file_info = {
+                    'image': img,
+                    # Pair pre images with pre targets and post images with post
+                    # targets so that damage classes are preserved
+                    'mask': os.path.join(
+                        mask_root,
+                        os.path.basename(img).replace('.png', '_target.png'),
+                    ),
+                    'basename': basename,
+                }
+                all_files.append(file_info)
+        return all_files
+
+    def _load_split_files_by_disaster_and_type(
+        self,
+        files: list[dict[str, str]],
+        id_disaster: dict[str, str],
+        ood_disaster: dict[str, str],
+    ) -> dict[str, list[dict[str, str]]]:
+        """Return train and test filepaths based on disaster name and pre/post type.
+
+        Args:
+            files: List of file paths with their corresponding information.
+            id_disaster: Dictionary specifying the in-domain (ID) disaster and type,
+                e.g., {'disaster_name': 'guatemala-volcano', 'pre-post': 'pre'}.
+            ood_disaster: Dictionary specifying the out-of-domain (OOD) disaster and
+                type, e.g., {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'}.
+
+        Returns:
+            A dictionary containing 'train' (ID) and 'test' (OOD) file lists.
+        """
+        train_files = []
+        test_files = []
+
+        for file_info in files:
+            disaster_name = file_info['basename'].split('_')[0]
+            pre_post = 'pre' if 'pre_disaster' in file_info['image'] else 'post'
+
+            if disaster_name == id_disaster['disaster_name']:
+                if id_disaster['pre-post'] in ('both', pre_post):
+                    train_files.append(file_info)
+
+            if disaster_name == ood_disaster['disaster_name']:
+                if ood_disaster['pre-post'] in ('both', pre_post):
+                    test_files.append(file_info)
+
+        return {'train': train_files, 'test': test_files}
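
For reviewers, here is a minimal usage sketch (not part of the diff) of how the new class is meant to be driven. The `root` path is a placeholder and assumes the xView2 archives are already extracted; the disaster choices mirror the constructor defaults.

```python
from torch.utils.data import DataLoader

from torchgeo.datasets import XView2DistShift

# One shared split definition: train on hurricane-matthew (ID),
# evaluate on mexico-earthquake (OOD). 'pre-post' selects which
# imagery to use: 'pre', 'post', or 'both'.
id_ood = [
    {'disaster_name': 'hurricane-matthew', 'pre-post': 'post'},
    {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'},
]

# 'split' selects which side of the ID/OOD partition a dataset serves
train_ds = XView2DistShift(root='data', split='train', id_ood_disaster=id_ood)
test_ds = XView2DistShift(root='data', split='test', id_ood_disaster=id_ood)

n_id, n_ood = train_ds.get_id_ood_sizes()
print(f'ID samples: {n_id}, OOD samples: {n_ood}')

# Masks are binary after remapping: 0 = background, 1 = building
batch = next(iter(DataLoader(train_ds, batch_size=8, shuffle=True)))
print(batch['image'].shape, batch['mask'].unique())
```

Note that both instances are constructed with the same `id_ood_disaster` pair; only `split` changes which half of the partition `__getitem__` and `__len__` expose.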