-
Notifications
You must be signed in to change notification settings - Fork 506
Custom disaster-based train/test splits for xView2 dataset #2416
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 20 commits
da2399b
9791f12
0f57ecf
62919bf
5985f44
a23344e
8072f2b
dc97c66
459704c
8239d32
c45bb25
0f3ceb7
a74c99f
b66c792
5c2a6e6
ae33510
f587d9a
7d2dbbb
f011b44
4bc5510
4cc0ac0
6d4ac9e
6e32b6b
dd5e626
b73de3c
0150c32
b7296e7
83c82cf
a273d5b
3a812ce
34a5687
b040659
7b947c8
d442682
50cd637
16d6293
103d7cf
8bb52f6
f08cbbb
c333a8e
ec1b8ae
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -148,3 +148,4 @@ dmypy.json | |
|
|
||
| # Pyre type checker | ||
| .pyre/ | ||
| xbdood.ipynb | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -559,6 +559,7 @@ xView2 | |
| ^^^^^^ | ||
|
|
||
| .. autoclass:: XView2 | ||
| .. autoclass:: XView2DistShift | ||
|
|
||
| ZueriCrop | ||
| ^^^^^^^^^ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,7 +12,7 @@ | |
| from _pytest.fixtures import SubRequest | ||
| from pytest import MonkeyPatch | ||
|
|
||
| from torchgeo.datasets import DatasetNotFoundError, XView2 | ||
| from torchgeo.datasets import DatasetNotFoundError, XView2, XView2DistShift | ||
|
|
||
|
|
||
| class TestXView2: | ||
|
|
@@ -92,3 +92,69 @@ def test_plot(self, dataset: XView2) -> None: | |
| x['prediction'] = x['mask'][0].clone() | ||
| dataset.plot(x) | ||
| plt.close() | ||
|
|
||
|
|
||
class TestXView2DistShift:
    """Tests for the disaster-based train/test splitting of ``XView2DistShift``."""

    @pytest.fixture(params=['train', 'test'])
    def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> XView2DistShift:
        """Build an ``XView2DistShift`` over the bundled test data, once per split.

        The fixture is parametrized over ``'train'`` and ``'test'``, so every
        test that requests it runs for both splits.
        """
        # Swap the class-level archive metadata for the entries below.
        # Presumably these filenames/md5s describe the small archives under
        # tests/data/xview2 so that ``checksum=True`` can succeed -- TODO confirm.
        # NOTE(review): a maintainer remarked that this monkeypatch might not be
        # needed at all if something is removed (the comment is truncated in the
        # review page, so the exact suggestion is unknown).
        monkeypatch.setattr(
            XView2DistShift,
            'metadata',
            {
                'train': {
                    'filename': 'train_images_labels_targets.tar.gz',
                    'md5': '373e61d55c1b294aa76b94dbbd81332b',
                    'directory': 'train',
                },
                'test': {
                    'filename': 'test_images_labels_targets.tar.gz',
                    'md5': 'bc6de81c956a3bada38b5b4e246266a1',
                    'directory': 'test',
                },
            },
        )
        root = os.path.join('tests', 'data', 'xview2')
        split = request.param
        transforms = nn.Identity()

        return XView2DistShift(
            root=root,
            split=split,
            # First entry is the in-distribution (train) disaster, second the
            # out-of-distribution (test) disaster; here both are the same event.
            # NOTE(review): a maintainer asked whether this duplication is
            # required for some reason.
            id_ood_disaster=[
                {'disaster_name': 'hurricane-harvey', 'pre-post': 'post'},
                {'disaster_name': 'hurricane-harvey', 'pre-post': 'post'},
            ],
            transforms=transforms,
            checksum=True,
        )

    def test_getitem(self, dataset: XView2DistShift) -> None:
        """The first sample is a dict of tensors whose mask holds only 0 and 1."""
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x['image'], torch.Tensor)
        assert isinstance(x['mask'], torch.Tensor)
        assert set(torch.unique(x['mask']).tolist()).issubset({0, 1})  # binary mask

    def test_len(self, dataset: XView2DistShift) -> None:
        """The selected split is not empty."""
        # NOTE(review): a maintainer asked for an exact expected length here
        # (with an if-statement if train and test differ) instead of ``> 0``.
        assert len(dataset) > 0

    def test_invalid_disaster(self) -> None:
        """An unknown disaster name is rejected with ``ValueError``."""
        with pytest.raises(ValueError, match='Invalid disaster names'):
            XView2DistShift(
                root='tests/data/xview2',
                id_ood_disaster=[
                    {'disaster_name': 'not-a-real-one', 'pre-post': 'post'},
                    {'disaster_name': 'hurricane-harvey', 'pre-post': 'post'},
                ],
            )

    def test_invalid_split(self) -> None:
        """A split other than 'train'/'test' trips the constructor's assertion."""
        # Both disaster names are valid, so only the split can be at fault.
        with pytest.raises(AssertionError):
            XView2DistShift(
                root='tests/data/xview2',
                split='bad',
                id_ood_disaster=[
                    {'disaster_name': 'hurricane-matthew', 'pre-post': 'post'},
                    {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'},
                ],
            )
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -138,7 +138,7 @@ def _load_files(self, root: Path, split: str) -> list[dict[str, str]]: | |
|
|
||
| Args: | ||
| root: root dir of dataset | ||
| split: subset of dataset, one of [train, test] | ||
| split: subset of dataset, one of ['train', 'test'] | ||
|
|
||
| Returns: | ||
| list of dicts containing paths for each pair of images and masks | ||
|
|
@@ -282,3 +282,195 @@ def plot( | |
| plt.suptitle(suptitle) | ||
|
|
||
| return fig | ||
|
|
||
|
|
||
class XView2DistShift(XView2):
    """A subclass of the XView2 dataset that re-defines the train/test splits by disaster.

    Instead of the original xView2 partition, the caller selects one disaster
    to serve as the training set (in-distribution, ID) and one to serve as the
    test set (out-of-distribution, OOD). This makes it possible to train a
    model on one disaster and evaluate how well it generalizes to a different,
    unseen disaster.

    Samples are exposed as a binary building-segmentation task (see
    :attr:`binary_classes`).
    """

    binary_classes: ClassVar[tuple[str, str]] = ('background', 'building')

    # Disaster (event) names accepted in ``id_ood_disaster[*]['disaster_name']``.
    valid_disasters: ClassVar[list[str]] = [
        'hurricane-harvey',
        'socal-fire',
        'hurricane-matthew',
        'mexico-earthquake',
        'guatemala-volcano',
        'santa-rosa-wildfire',
        'palu-tsunami',
        'hurricane-florence',
        'hurricane-michael',
        'midwest-flooding',
    ]

    # Values accepted in ``id_ood_disaster[*]['pre-post']``.
    valid_pre_post: ClassVar[tuple[str, ...]] = ('pre', 'post', 'both')

    def __init__(
        self,
        root: str = 'data',
        split: str = 'train',
        id_ood_disaster: list[dict[str, str]] | None = None,
        transforms: Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]]
        | None = None,
        checksum: bool = False,
    ) -> None:
        """Initialize the XView2DistShift dataset instance.

        Args:
            root: Root directory where the dataset is located.
            split: One of "train" or "test". "train" serves the in-distribution
                (ID) disaster, "test" the out-of-distribution (OOD) disaster.
            id_ood_disaster: Exactly two dictionaries, the ID disaster first and
                the OOD disaster second. Each dictionary needs a
                ``'disaster_name'`` key (one of :attr:`valid_disasters`) and a
                ``'pre-post'`` key (``'pre'``, ``'post'`` or ``'both'``) that
                selects pre-disaster images, post-disaster images, or both.
                Defaults to post-disaster 'hurricane-matthew' (ID) and
                post-disaster 'mexico-earthquake' (OOD).
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            AssertionError: If *split* is invalid.
            ValueError: If *id_ood_disaster* does not hold exactly two entries,
                if a disaster name is not one of the valid disasters, or if a
                ``'pre-post'`` value is not 'pre', 'post' or 'both'.
            DatasetNotFoundError: If dataset is not found.
        """
        assert split in ['train', 'test'], "Split must be either 'train' or 'test'."

        # A ``None`` sentinel avoids sharing one mutable default list between
        # all instances.
        if id_ood_disaster is None:
            id_ood_disaster = [
                {'disaster_name': 'hurricane-matthew', 'pre-post': 'post'},
                {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'},
            ]

        # NOTE(review): a maintainer pointed out that extra entries used to be
        # silently ignored and suggested four explicit parameters instead of a
        # list of dicts. The list form is kept for backward compatibility, but
        # its length is now enforced.
        if len(id_ood_disaster) != 2:
            raise ValueError(
                'id_ood_disaster must contain exactly two entries '
                f'(ID first, OOD second), got {len(id_ood_disaster)}.'
            )
        id_disaster, ood_disaster = id_ood_disaster

        if any(
            entry.get('disaster_name') not in self.valid_disasters
            for entry in (id_disaster, ood_disaster)
        ):
            raise ValueError(
                f'Invalid disaster names. Valid options are: {", ".join(self.valid_disasters)}'
            )

        # Fail early: an unknown value would otherwise yield an empty split and
        # a missing key would surface as a KeyError deep inside the split loop.
        if any(
            entry.get('pre-post') not in self.valid_pre_post
            for entry in (id_disaster, ood_disaster)
        ):
            raise ValueError(
                "Invalid 'pre-post' value. "
                f'Valid options are: {", ".join(self.valid_pre_post)}'
            )

        self.root = root
        self.split = split
        self.transforms = transforms
        self.checksum = checksum

        self._verify()

        # Load all files and compute basenames and disasters only once
        self.all_files = self._initialize_files(root)

        # Split logic by disaster and pre-post type
        self.split_files: dict[str, list[dict[str, str]]] = (
            self._load_split_files_by_disaster_and_type(
                self.all_files, id_disaster, ood_disaster
            )
        )

        train_size, test_size = self.get_id_ood_sizes()
        print(f'ID sample len: {train_size}, OOD sample len: {test_size}')

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        """Return the sample at *index* from the active split.

        Args:
            index: Position within the split selected at construction time.

        Returns:
            A dict with keys ``'image'`` and ``'mask'``, passed through
            ``transforms`` when one was given.
        """
        file_info = self.split_files[self.split][index]

        image = self._load_image(file_info['image'])
        mask = self._load_target(file_info['mask']).long()

        # Reformulate as building segmentation task
        mask[mask == 2] = 1  # minor-damage -> building
        mask[(mask == 3) | (mask == 4)] = 0  # major/destroyed -> background

        sample = {'image': image, 'mask': mask}
        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def __len__(self) -> int:
        """Return the number of samples in the active split."""
        return len(self.split_files[self.split])

    def get_id_ood_sizes(self) -> tuple[int, int]:
        """Return the number of samples in the train (ID) and test (OOD) splits."""
        return (len(self.split_files['train']), len(self.split_files['test']))

    def _initialize_files(self, root: str) -> list[dict[str, str]]:
        """Collect every image under *root* together with its mask path and basename.

        Args:
            root: Root directory of the dataset.

        Returns:
            One dict per ``*.png`` image found in ``<root>/<split>/images`` for
            every split listed in ``self.metadata``, with keys ``'image'``
            (image path), ``'mask'`` (target path) and ``'basename'``
            (``<disaster-name>_<sample-number>``).
        """
        all_files: list[dict[str, str]] = []
        for split in self.metadata:
            image_root = os.path.join(root, split, 'images')
            mask_root = os.path.join(root, split, 'targets')

            # glob order is filesystem dependent; sort so that a given index
            # always refers to the same sample.
            images = sorted(glob.glob(os.path.join(image_root, '*.png')))

            # Extract basenames while preserving the disaster-name and sample number
            for img in images:
                basename_parts = os.path.basename(img).split('_')
                event_name = basename_parts[0]  # e.g., mexico-earthquake
                sample_number = basename_parts[1]  # e.g., 00000001
                basename = f'{event_name}_{sample_number}'

                all_files.append(
                    {
                        'image': img,
                        # NOTE(review): every image, pre- or post-disaster, is
                        # paired with the *pre*-disaster target -- presumably
                        # intentional for building segmentation; confirm.
                        'mask': os.path.join(
                            mask_root, f'{basename}_pre_disaster_target.png'
                        ),
                        'basename': basename,
                    }
                )
        return all_files

    def _load_split_files_by_disaster_and_type(
        self,
        files: list[dict[str, str]],
        id_disaster: dict[str, str],
        ood_disaster: dict[str, str],
    ) -> dict[str, list[dict[str, str]]]:
        """Return the filepaths for the train (ID) and test (OOD) sets.

        Files are selected by disaster name and by pre-/post-disaster type.

        Args:
            files: List of file paths with their corresponding information.
            id_disaster: Dictionary specifying in-domain (ID) disaster and type
                (e.g., {'disaster_name': 'guatemala-volcano', 'pre-post': 'pre'}).
            ood_disaster: Dictionary specifying out-of-domain (OOD) disaster and type
                (e.g., {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'}).

        Returns:
            A dictionary containing 'train' (ID) and 'test' (OOD) file lists.
            Train entries hold 'image' and 'mask'; test entries are the input
            dicts themselves (which also carry 'basename').
        """
        train_files: list[dict[str, str]] = []
        test_files: list[dict[str, str]] = []

        for file_info in files:
            disaster_name = file_info['basename'].split('_')[0]

            # Inspect the file name only: a parent directory whose name happens
            # to contain 'pre_disaster' must not influence the classification.
            filename = os.path.basename(file_info['image'])
            pre_post = 'pre' if 'pre_disaster' in filename else 'post'

            if disaster_name == id_disaster['disaster_name'] and id_disaster[
                'pre-post'
            ] in ('both', pre_post):
                train_files.append(
                    {'image': file_info['image'], 'mask': file_info['mask']}
                )

            if disaster_name == ood_disaster['disaster_name'] and ood_disaster[
                'pre-post'
            ] in ('both', pre_post):
                test_files.append(file_info)

        return {'train': train_files, 'test': test_files}
Uh oh!
There was an error while loading. Please reload this page.