-
Notifications
You must be signed in to change notification settings - Fork 443
Custom disaster-based train/test splits for xView2 dataset #2416
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
da2399b
9791f12
0f57ecf
62919bf
5985f44
a23344e
8072f2b
dc97c66
459704c
8239d32
c45bb25
0f3ceb7
a74c99f
b66c792
5c2a6e6
ae33510
f587d9a
7d2dbbb
f011b44
4bc5510
4cc0ac0
6d4ac9e
6e32b6b
dd5e626
b73de3c
0150c32
b7296e7
83c82cf
a273d5b
3a812ce
34a5687
b040659
7b947c8
d442682
50cd637
16d6293
103d7cf
8bb52f6
f08cbbb
c333a8e
ec1b8ae
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -559,6 +559,7 @@ xView2 | |
^^^^^^ | ||
|
||
.. autoclass:: XView2 | ||
.. autoclass:: XView2DistShift | ||
|
||
ZueriCrop | ||
^^^^^^^^^ | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -12,7 +12,7 @@ | |||||
from _pytest.fixtures import SubRequest | ||||||
from pytest import MonkeyPatch | ||||||
|
||||||
from torchgeo.datasets import DatasetNotFoundError, XView2 | ||||||
from torchgeo.datasets import DatasetNotFoundError, XView2, XView2DistShift | ||||||
|
||||||
|
||||||
class TestXView2: | ||||||
|
@@ -92,3 +92,94 @@ def test_plot(self, dataset: XView2) -> None: | |||||
x['prediction'] = x['mask'][0].clone() | ||||||
dataset.plot(x) | ||||||
plt.close() | ||||||
|
||||||
|
||||||
class TestXView2DistShift:
    """Tests for the disaster-based distribution-shift splits of xView2."""

    @pytest.fixture(params=['train', 'test'])
    def dataset(
        self, monkeypatch: MonkeyPatch, request: SubRequest
    ) -> XView2DistShift:
        """Return a dataset instance backed by the small test fixtures."""
        monkeypatch.setattr(
            XView2DistShift,
            'metadata',
            {
                'train': {
                    'filename': 'train_images_labels_targets.tar.gz',
                    'md5': '373e61d55c1b294aa76b94dbbd81332b',
                    'directory': 'train',
                },
                'test': {
                    'filename': 'test_images_labels_targets.tar.gz',
                    'md5': 'bc6de81c956a3bada38b5b4e246266a1',
                    'directory': 'test',
                },
            },
        )
        root = os.path.join('tests', 'data', 'xview2')
        split = request.param
        transforms = nn.Identity()

        # NOTE(review): the same disaster is used for both the ID and OOD
        # entry — presumably because the test fixtures only contain
        # hurricane-harvey imagery; verify against tests/data/xview2.
        return XView2DistShift(
            root=root,
            split=split,
            id_ood_disaster=[
                {'disaster_name': 'hurricane-harvey', 'pre-post': 'post'},
                {'disaster_name': 'hurricane-harvey', 'pre-post': 'post'},
            ],
            transforms=transforms,
            checksum=True,
        )

    def test_getitem(self, dataset: XView2DistShift) -> None:
        """Samples must be dicts of tensors with a binary building mask."""
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x['image'], torch.Tensor)
        assert isinstance(x['mask'], torch.Tensor)
        assert set(torch.unique(x['mask']).tolist()).issubset({0, 1})  # binary mask

    def test_len(self, dataset: XView2DistShift) -> None:
        """Length must be positive and agree with the per-split bookkeeping."""
        n_train, n_test = dataset.get_id_ood_sizes()
        assert len(dataset) > 0
        # The reported length must match the size recorded for the active split.
        assert len(dataset) == (n_train if dataset.split == 'train' else n_test)

    def test_invalid_disaster(self) -> None:
        """An unknown disaster name must raise a ValueError."""
        with pytest.raises(ValueError, match='Invalid disaster name'):
            XView2DistShift(
                # os.path.join keeps the path valid on Windows as well.
                root=os.path.join('tests', 'data', 'xview2'),
                id_ood_disaster=[
                    {'disaster_name': 'not-a-real-one', 'pre-post': 'post'},
                    {'disaster_name': 'hurricane-harvey', 'pre-post': 'post'},
                ],
            )

    def test_missing_disaster_name_key(self) -> None:
        """An entry without 'disaster_name' must raise a ValueError."""
        with pytest.raises(
            ValueError, match="Each disaster entry must contain a 'disaster_name' key."
        ):
            XView2DistShift(
                root=os.path.join('tests', 'data', 'xview2'),
                id_ood_disaster=[
                    {'pre-post': 'post'},  # missing 'disaster_name'
                    {'disaster_name': 'hurricane-harvey', 'pre-post': 'post'},
                ],
            )

    def test_missing_pre_post_key(self) -> None:
        """An entry without 'pre-post' must raise a ValueError."""
        with pytest.raises(
            ValueError,
            match="Each disaster entry must contain 'disaster_name' and 'pre-post' keys.",
        ):
            XView2DistShift(
                root=os.path.join('tests', 'data', 'xview2'),
                id_ood_disaster=[
                    {'disaster_name': 'hurricane-harvey'},  # missing 'pre-post'
                    {'disaster_name': 'hurricane-harvey', 'pre-post': 'post'},
                ],
            )

    def test_invalid_split(self) -> None:
        """A split other than 'train'/'test' must fail the assertion."""
        with pytest.raises(AssertionError):
            XView2DistShift(
                root=os.path.join('tests', 'data', 'xview2'),
                split='bad',  # type: ignore[arg-type]
                id_ood_disaster=[
                    {'disaster_name': 'hurricane-matthew', 'pre-post': 'post'},
                    {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'},
                ],
            )
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,8 +5,8 @@ | |
|
||
import glob | ||
import os | ||
from collections.abc import Callable | ||
from typing import ClassVar | ||
from collections.abc import Callable, Sequence | ||
from typing import ClassVar, Literal | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
@@ -138,7 +138,7 @@ def _load_files(self, root: Path, split: str) -> list[dict[str, str]]: | |
|
||
Args: | ||
root: root dir of dataset | ||
split: subset of dataset, one of [train, test] | ||
split: subset of dataset, one of ['train', 'test'] | ||
|
||
Returns: | ||
list of dicts containing paths for each pair of images and masks | ||
|
@@ -282,3 +282,213 @@ def plot( | |
plt.suptitle(suptitle) | ||
|
||
return fig | ||
|
||
|
||
class XView2DistShift(XView2):
    """A subclass of the XView2 dataset designed to reformat the original train/test splits.

    This class allows for the selection of particular disasters to be used as the
    training set (in-domain) and test set (out-of-domain). The dataset can be split
    according to the disaster names specified by the user, enabling the model to train
    on one disaster type and evaluate on a different, out-of-domain disaster. The goal
    is to test the generalization ability of models trained on one disaster to perform
    on others.

    .. versionadded:: 0.8
    """

    # Damage levels are collapsed into a binary building-segmentation task.
    binary_classes = ('background', 'building')

    # Disaster names accepted in *id_ood_disaster* entries.
    valid_disasters = (
        'hurricane-harvey',
        'socal-fire',
        'hurricane-matthew',
        'mexico-earthquake',
        'guatemala-volcano',
        'santa-rosa-wildfire',
        'palu-tsunami',
        'hurricane-florence',
        'hurricane-michael',
        'midwest-flooding',
    )

    def __init__(
        self,
        root: str = 'data',
        split: Literal['train', 'test'] = 'train',
        id_ood_disaster: Sequence[dict[str, str]] = (
            {'disaster_name': 'hurricane-matthew', 'pre-post': 'post'},
            {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'},
        ),
        transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
        checksum: bool = False,
    ) -> None:
        """Initialize the XView2DistShift dataset instance.

        Args:
            root: Root directory where the dataset is located.
            split: One of "train" or "test".
            id_ood_disaster: Two-item sequence selecting the splits. The first
                entry defines the in-distribution (train) disaster, the second
                the out-of-distribution (test) disaster. Each entry must
                contain both a ``'disaster_name'`` key (one of
                :attr:`valid_disasters`) and a ``'pre-post'`` key
                (``'pre'``, ``'post'`` or ``'both'``); any other keys are
                ignored.
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            AssertionError: If *split* is invalid or *id_ood_disaster* does not
                contain exactly two entries.
            ValueError: If an entry of *id_ood_disaster* is missing the
                'disaster_name' or 'pre-post' key, or names a disaster that is
                not one of the valid disasters.
            DatasetNotFoundError: If dataset is not found.
        """
        assert split in ['train', 'test'], "Split must be either 'train' or 'test'."
        assert len(id_ood_disaster) == 2, (
            'id_ood_disaster must contain exactly two items'
        )

        # Validate both entries in a single pass (previously two loops with a
        # duplicated 'disaster_name' check). Error messages are preserved:
        # the tests match them verbatim.
        for disaster in id_ood_disaster:
            if 'disaster_name' not in disaster:
                raise ValueError(
                    "Each disaster entry must contain a 'disaster_name' key."
                )
            if disaster['disaster_name'] not in self.valid_disasters:
                raise ValueError(
                    f'Invalid disaster name: {disaster["disaster_name"]}. '
                    f'Valid options are: {", ".join(self.valid_disasters)}'
                )
            if 'pre-post' not in disaster:
                raise ValueError(
                    "Each disaster entry must contain 'disaster_name' and 'pre-post' keys."
                )

        self.root = root
        self.split = split
        self.transforms = transforms
        self.checksum = checksum

        self._verify()

        # Load all files and compute basenames and disasters only once.
        self.all_files = self._initialize_files(root)

        # Partition files into ID ('train') and OOD ('test') lists by
        # disaster name and pre/post type.
        self.split_files: dict[str, list[dict[str, str]]] = (
            self._load_split_files_by_disaster_and_type(
                self.all_files, id_ood_disaster[0], id_ood_disaster[1]
            )
        )

    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        """Return an index within the dataset.

        Args:
            index: index to return

        Returns:
            data and label at that index
        """
        file_info = self.split_files[self.split][index]

        image = self._load_image(file_info['image'])
        mask = self._load_target(file_info['mask']).long()

        # Reformulate as a binary building-segmentation task.
        mask[mask == 2] = 1  # minor-damage → building
        mask[(mask == 3) | (mask == 4)] = 0  # major/destroyed → background

        sample = {'image': image, 'mask': mask}
        if self.transforms:
            sample = self.transforms(sample)

        return sample

    def __len__(self) -> int:
        """Return the number of data points in the dataset.

        Returns:
            length of the dataset
        """
        return len(self.split_files[self.split])

    def get_id_ood_sizes(self) -> tuple[int, int]:
        """Return the number of samples in the train and test splits."""
        return (len(self.split_files['train']), len(self.split_files['test']))

    def _initialize_files(self, root: str) -> list[dict[str, str]]:
        """Initialize the dataset by loading file paths and computing basenames with sample numbers.

        Args:
            root: Root directory of the dataset.

        Returns:
            One dict per image found under each metadata split, with
            'image', 'mask' and 'basename' keys.
        """
        all_files = []
        for split in self.metadata.keys():
            image_root = os.path.join(root, split, 'images')
            mask_root = os.path.join(root, split, 'targets')
            images = sorted(glob.glob(os.path.join(image_root, '*.png')))

            # Extract basenames while preserving the disaster-name and sample number.
            for img in images:
                basename_parts = os.path.basename(img).split('_')
                event_name = basename_parts[0]  # e.g., mexico-earthquake
                sample_number = basename_parts[1]  # e.g., 00000001
                basename = f'{event_name}_{sample_number}'

                # NOTE(review): the mask path always points at the
                # *pre*-disaster target, even for post-disaster images —
                # presumably intentional for building-footprint labels;
                # confirm against the xView2 labeling scheme.
                file_info = {
                    'image': img,
                    'mask': os.path.join(
                        mask_root, f'{basename}_pre_disaster_target.png'
                    ),
                    'basename': basename,
                }
                all_files.append(file_info)
        return all_files

    def _load_split_files_by_disaster_and_type(
        self,
        files: list[dict[str, str]],
        id_disaster: dict[str, str],
        ood_disaster: dict[str, str],
    ) -> dict[str, list[dict[str, str]]]:
        """Return train and test filepaths based on disaster name and pre/post type.

        Args:
            files: List of file paths with their corresponding information.
            id_disaster: Dictionary specifying in-domain (ID) disaster and type,
                e.g., {'disaster_name': 'guatemala-volcano', 'pre-post': 'pre'}.
            ood_disaster: Dictionary specifying out-of-domain (OOD) disaster and type,
                e.g., {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'}.

        Returns:
            A dictionary containing 'train' (ID) and 'test' (OOD) file lists.
        """
        train_files = []
        test_files = []

        for file_info in files:
            basename = file_info['basename']
            disaster_name = basename.split('_')[0]
            # The pre/post type is derived from the image filename.
            pre_post = 'pre' if 'pre_disaster' in file_info['image'] else 'post'

            if disaster_name == id_disaster['disaster_name']:
                if (
                    id_disaster.get('pre-post') == 'both'
                    or id_disaster['pre-post'] == pre_post
                ):
                    # For 'pre' samples, normalize any post-disaster path
                    # segments to their pre-disaster counterparts.
                    image = (
                        file_info['image'].replace('post_disaster', 'pre_disaster')
                        if pre_post == 'pre'
                        else file_info['image']
                    )
                    mask = (
                        file_info['mask'].replace('post_disaster', 'pre_disaster')
                        if pre_post == 'pre'
                        else file_info['mask']
                    )
                    train_files.append(dict(image=image, mask=mask))

            if disaster_name == ood_disaster['disaster_name']:
                if (
                    ood_disaster.get('pre-post') == 'both'
                    or ood_disaster['pre-post'] == pre_post
                ):
                    test_files.append(file_info)

        return {'train': train_files, 'test': test_files}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Might not even need to monkeypatch this if you remove
checksum=True
, I never bother with this