Skip to content

Custom disaster-based train/test splits for xView2 dataset #2416

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 41 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
da2399b
minor typo in custom_raster_dataset.ipynb
burakekim Apr 9, 2024
9791f12
Merge branch 'main' of https://github.com/burakekim/torchgeo
burakekim Jul 17, 2024
0f57ecf
xview2 dist shift initial commit
burakekim Nov 18, 2024
62919bf
xview2distshift dataset
burakekim Nov 18, 2024
5985f44
test xview2
burakekim Nov 18, 2024
a23344e
formatting
burakekim Nov 18, 2024
8072f2b
Merge branch 'microsoft:main' into main
burakekim Jan 3, 2025
dc97c66
Merge branch 'main' into oodxbd
burakekim Feb 1, 2025
459704c
Merge branch 'main' into oodxbd
burakekim Feb 1, 2025
8239d32
" to '
burakekim Feb 1, 2025
c45bb25
Merge branch 'oodxbd' of https://github.com/burakekim/torchgeo into o…
burakekim Feb 1, 2025
0f3ceb7
no cuda yes docstring
burakekim Feb 1, 2025
a74c99f
id ood length method and polishing
burakekim Feb 1, 2025
b66c792
idk what this is
burakekim Apr 18, 2025
5c2a6e6
ruff fixes
burakekim Apr 18, 2025
ae33510
make mypy happy
burakekim Apr 18, 2025
f587d9a
Merge branch 'microsoft:main' into oodxbd
burakekim Apr 19, 2025
7d2dbbb
make ruff happy
burakekim Apr 19, 2025
f011b44
add tests
burakekim Apr 19, 2025
4bc5510
more test
burakekim Apr 19, 2025
4cc0ac0
Update torchgeo/datasets/xview.py
burakekim Apr 20, 2025
6d4ac9e
Update torchgeo/datasets/xview.py
burakekim Apr 20, 2025
6e32b6b
Update torchgeo/datasets/xview.py
burakekim Apr 20, 2025
dd5e626
post review changes
burakekim Apr 20, 2025
b73de3c
smol fixes
burakekim Apr 20, 2025
0150c32
ruff again
burakekim Apr 20, 2025
b7296e7
mypy fixes
burakekim Apr 20, 2025
83c82cf
mypy fix
burakekim Apr 20, 2025
a273d5b
ruff reformat
burakekim Apr 20, 2025
3a812ce
mypy ignore failing
burakekim Apr 20, 2025
34a5687
Merge branch 'main' into oodxbd
burakekim Apr 23, 2025
b040659
caleb review
burakekim Apr 23, 2025
7b947c8
Merge branch 'oodxbd' of https://github.com/burakekim/torchgeo into o…
burakekim Apr 23, 2025
d442682
Merge branch 'main' into oodxbd
burakekim Apr 23, 2025
50cd637
Merge branch 'main' into oodxbd
burakekim Apr 24, 2025
16d6293
covering edge scenarios per Adam's comments
burakekim May 6, 2025
103d7cf
Merge branch 'main' into oodxbd
burakekim May 6, 2025
8bb52f6
fix valueerror match + prettier
burakekim May 6, 2025
f08cbbb
Merge branch 'oodxbd' of https://github.com/burakekim/torchgeo into o…
burakekim May 6, 2025
c333a8e
improved test cov
burakekim May 6, 2025
ec1b8ae
Merge branch 'main' into oodxbd
burakekim May 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ xView2
^^^^^^

.. autoclass:: XView2
.. autoclass:: XView2DistShift

ZueriCrop
^^^^^^^^^
Expand Down
93 changes: 92 additions & 1 deletion tests/datasets/test_xview.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from _pytest.fixtures import SubRequest
from pytest import MonkeyPatch

from torchgeo.datasets import DatasetNotFoundError, XView2
from torchgeo.datasets import DatasetNotFoundError, XView2, XView2DistShift


class TestXView2:
Expand Down Expand Up @@ -92,3 +92,94 @@ def test_plot(self, dataset: XView2) -> None:
x['prediction'] = x['mask'][0].clone()
dataset.plot(x)
plt.close()


class TestXView2DistShift:
    @pytest.fixture(params=['train', 'test'])
    def dataset(
        self, monkeypatch: MonkeyPatch, request: SubRequest
    ) -> XView2DistShift:
        """Build a dataset from the small test copy of xView2 under tests/data.

        The metadata is monkeypatched so ``checksum=True`` validates against the
        MD5s of the test archives rather than the full dataset.
        """
        monkeypatch.setattr(
            XView2DistShift,
            'metadata',
            {
                'train': {
                    'filename': 'train_images_labels_targets.tar.gz',
                    'md5': '373e61d55c1b294aa76b94dbbd81332b',
                    'directory': 'train',
                },
                'test': {
                    'filename': 'test_images_labels_targets.tar.gz',
                    'md5': 'bc6de81c956a3bada38b5b4e246266a1',
                    'directory': 'test',
                },
            },
        )
        root = os.path.join('tests', 'data', 'xview2')
        split = request.param
        transforms = nn.Identity()

        # NOTE(review): the fixture data ships only one disaster, so the same
        # disaster is used for both the ID and OOD entries — confirm intended.
        return XView2DistShift(
            root=root,
            split=split,
            id_ood_disaster=[
                {'disaster_name': 'hurricane-harvey', 'pre-post': 'post'},
                {'disaster_name': 'hurricane-harvey', 'pre-post': 'post'},
            ],
            transforms=transforms,
            checksum=True,
        )

def test_getitem(self, dataset: XView2DistShift) -> None:
    """A sample is a dict of tensors whose mask is strictly binary."""
    sample = dataset[0]
    assert isinstance(sample, dict)
    assert isinstance(sample['image'], torch.Tensor)
    assert isinstance(sample['mask'], torch.Tensor)
    # Only background (0) and building (1) may appear after remapping.
    observed_classes = set(torch.unique(sample['mask']).tolist())
    assert observed_classes <= {0, 1}

def test_len(self, dataset: XView2DistShift) -> None:
    # TODO(review): pin the exact expected length per split (train vs. test)
    # once the fixture size is settled, instead of this non-empty check.
    assert len(dataset) > 0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's test a specific length to ensure it behaves as expected. Can use an if-statement if it's different for train and test.


def test_invalid_disaster(self) -> None:
    """An unknown disaster name must raise a ValueError."""
    with pytest.raises(ValueError, match='Invalid disaster name'):
        XView2DistShift(
            # os.path.join keeps the path valid on Windows too
            root=os.path.join('tests', 'data', 'xview2'),
            id_ood_disaster=[
                {'disaster_name': 'not-a-real-one', 'pre-post': 'post'},
                {'disaster_name': 'hurricane-harvey', 'pre-post': 'post'},
            ],
        )

def test_missing_disaster_name_key(self) -> None:
    """Omitting 'disaster_name' from an entry must raise a ValueError."""
    with pytest.raises(
        ValueError, match="Each disaster entry must contain a 'disaster_name' key."
    ):
        XView2DistShift(
            # os.path.join keeps the path valid on Windows too
            root=os.path.join('tests', 'data', 'xview2'),
            id_ood_disaster=[
                {'pre-post': 'post'},  # missing 'disaster_name'
                {'disaster_name': 'hurricane-harvey', 'pre-post': 'post'},
            ],
        )

def test_missing_pre_post_key(self) -> None:
    """Omitting 'pre-post' from an entry must raise a ValueError."""
    with pytest.raises(
        ValueError,
        match="Each disaster entry must contain 'disaster_name' and 'pre-post' keys.",
    ):
        XView2DistShift(
            # os.path.join keeps the path valid on Windows too
            root=os.path.join('tests', 'data', 'xview2'),
            id_ood_disaster=[
                {'disaster_name': 'hurricane-harvey'},  # missing 'pre-post'
                {'disaster_name': 'hurricane-harvey', 'pre-post': 'post'},
            ],
        )

def test_invalid_split(self) -> None:
    """A split other than 'train'/'test' must raise an AssertionError."""
    with pytest.raises(AssertionError):
        XView2DistShift(
            # os.path.join keeps the path valid on Windows too
            root=os.path.join('tests', 'data', 'xview2'),
            split='bad',  # type: ignore[arg-type]
            id_ood_disaster=[
                {'disaster_name': 'hurricane-matthew', 'pre-post': 'post'},
                {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'},
            ],
        )
3 changes: 2 additions & 1 deletion torchgeo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@
from .vaihingen import Vaihingen2D
from .vhr10 import VHR10
from .western_usa_live_fuel_moisture import WesternUSALiveFuelMoisture
from .xview import XView2
from .xview import XView2, XView2DistShift
from .zuericrop import ZueriCrop

__all__ = (
Expand Down Expand Up @@ -347,6 +347,7 @@
'VectorDataset',
'WesternUSALiveFuelMoisture',
'XView2',
'XView2DistShift',
'ZueriCrop',
'concat_samples',
'merge_samples',
Expand Down
216 changes: 213 additions & 3 deletions torchgeo/datasets/xview.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

import glob
import os
from collections.abc import Callable
from typing import ClassVar
from collections.abc import Callable, Sequence
from typing import ClassVar, Literal

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -138,7 +138,7 @@ def _load_files(self, root: Path, split: str) -> list[dict[str, str]]:

Args:
root: root dir of dataset
split: subset of dataset, one of [train, test]
split: subset of dataset, one of ['train', 'test']

Returns:
list of dicts containing paths for each pair of images and masks
Expand Down Expand Up @@ -282,3 +282,213 @@ def plot(
plt.suptitle(suptitle)

return fig


class XView2DistShift(XView2):
    """A subclass of the XView2 dataset designed to reformat the original train/test splits.

    This class allows for the selection of particular disasters to be used as the
    training set (in-domain) and test set (out-of-domain). The dataset can be split
    according to the disaster names specified by the user, enabling the model to train
    on one disaster type and evaluate on a different, out-of-domain disaster. The goal
    is to test the generalization ability of models trained on one disaster to perform
    on others.

    .. versionadded:: 0.8
    """

    # Class names after collapsing the xBD damage labels to a binary
    # building-segmentation task (see ``__getitem__``).
    binary_classes = ('background', 'building')

    # Disaster events accepted in ``id_ood_disaster`` entries.
    valid_disasters = (
        'hurricane-harvey',
        'socal-fire',
        'hurricane-matthew',
        'mexico-earthquake',
        'guatemala-volcano',
        'santa-rosa-wildfire',
        'palu-tsunami',
        'hurricane-florence',
        'hurricane-michael',
        'midwest-flooding',
    )

def __init__(
    self,
    root: str = 'data',
    split: Literal['train', 'test'] = 'train',
    id_ood_disaster: Sequence[dict[str, str]] = (
        {'disaster_name': 'hurricane-matthew', 'pre-post': 'post'},
        {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'},
    ),
    transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
    checksum: bool = False,
) -> None:
    """Initialize the XView2DistShift dataset instance.

    Args:
        root: Root directory where the dataset is located.
        split: One of "train" or "test". "train" yields the in-distribution
            (ID) disaster samples, "test" the out-of-distribution (OOD) ones.
        id_ood_disaster: Two dicts selecting the ID (first) and OOD (second)
            disaster. Each dict requires a 'disaster_name' key (one of
            ``valid_disasters``) and a 'pre-post' key choosing which imagery
            of that disaster to use: 'pre', 'post', or 'both'.
        transforms: a function/transform that takes input sample and its target as
            entry and returns a transformed version
        checksum: if True, check the MD5 of the downloaded files (may be slow)

    Raises:
        AssertionError: If *split* is invalid or *id_ood_disaster* does not
            contain exactly two entries.
        ValueError: If an entry lacks the 'disaster_name' or 'pre-post' key,
            or names a disaster not in ``valid_disasters``.
        DatasetNotFoundError: If dataset is not found.
    """
    assert split in ['train', 'test'], "Split must be either 'train' or 'test'."
    assert len(id_ood_disaster) == 2, (
        'id_ood_disaster must contain exactly two items'
    )

    # Single validation pass (the original duplicated this loop): key
    # presence first, then membership in the known disaster list.
    for disaster in id_ood_disaster:
        if 'disaster_name' not in disaster:
            raise ValueError(
                "Each disaster entry must contain a 'disaster_name' key."
            )
        if disaster['disaster_name'] not in self.valid_disasters:
            raise ValueError(
                f'Invalid disaster name: {disaster["disaster_name"]}. '
                f'Valid options are: {", ".join(self.valid_disasters)}'
            )
        if 'pre-post' not in disaster:
            raise ValueError(
                "Each disaster entry must contain 'disaster_name' and 'pre-post' keys."
            )

    self.root = root
    self.split = split
    self.transforms = transforms
    self.checksum = checksum

    self._verify()

    # Index every image in both original xView2 splits exactly once.
    self.all_files = self._initialize_files(root)

    # Partition into ID ('train') and OOD ('test') file lists.
    self.split_files: dict[str, list[dict[str, str]]] = (
        self._load_split_files_by_disaster_and_type(
            self.all_files, id_ood_disaster[0], id_ood_disaster[1]
        )
    )

    train_size, test_size = self.get_id_ood_sizes()
    print(f'ID sample len: {train_size}, OOD sample len: {test_size}')
def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
    """Return an index within the dataset.

    Args:
        index: index to return

    Returns:
        data and label at that index
    """
    entry = self.split_files[self.split][index]

    image = self._load_image(entry['image'])
    mask = self._load_target(entry['mask']).long()

    # Collapse the multi-class damage mask into building vs. background:
    # class 2 (minor damage) is kept as building (1), while classes 3 and 4
    # (major damage / destroyed) are mapped to background (0).
    mask[mask == 2] = 1
    mask[(mask == 3) | (mask == 4)] = 0

    sample = {'image': image, 'mask': mask}
    if self.transforms:
        sample = self.transforms(sample)

    return sample

def __len__(self) -> int:
    """Return the number of data points in the dataset.

    Returns:
        length of the dataset
    """
    # Only the active split ('train' = ID, 'test' = OOD) counts.
    active_files = self.split_files[self.split]
    return len(active_files)

def get_id_ood_sizes(self) -> tuple[int, int]:
    """Return the number of samples in the train and test splits."""
    id_count = len(self.split_files['train'])
    ood_count = len(self.split_files['test'])
    return (id_count, ood_count)

def _initialize_files(self, root: str) -> list[dict[str, str]]:
    """Index all image/mask pairs across both original xView2 splits.

    Args:
        root: root directory of the dataset

    Returns:
        one dict per image with 'image' and 'mask' paths plus the
        '<disaster>_<sample>' 'basename' used for disaster filtering
    """
    all_files = []
    for split in self.metadata.keys():
        image_root = os.path.join(root, split, 'images')
        mask_root = os.path.join(root, split, 'targets')
        images = sorted(glob.glob(os.path.join(image_root, '*.png')))

        for img in images:
            # e.g. 'mexico-earthquake_00000001_post_disaster.png'
            img_name = os.path.basename(img)
            event_name, sample_number = img_name.split('_')[:2]

            # BUG FIX: pair each image with the target of its *own* pre/post
            # phase. Previously every image was paired with the pre-disaster
            # target, so post-disaster samples carried the wrong labels
            # (pre targets have no damage classes for __getitem__ to remap).
            mask_name = img_name.replace('.png', '_target.png')

            all_files.append(
                {
                    'image': img,
                    'mask': os.path.join(mask_root, mask_name),
                    'basename': f'{event_name}_{sample_number}',
                }
            )
    return all_files

def _load_split_files_by_disaster_and_type(
    self,
    files: list[dict[str, str]],
    id_disaster: dict[str, str],
    ood_disaster: dict[str, str],
) -> dict[str, list[dict[str, str]]]:
    """Return train and test filepaths based on disaster name and pre/post type.

    Args:
        files: file records produced by ``_initialize_files``.
        id_disaster: Dictionary specifying in-domain (ID) disaster and type,
            e.g., {'disaster_name': 'guatemala-volcano', 'pre-post': 'pre'};
            'pre-post' may be 'pre', 'post', or 'both'.
        ood_disaster: Dictionary specifying out-of-domain (OOD) disaster and type,
            e.g., {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'}.

    Returns:
        A dictionary containing 'train' (ID) and 'test' (OOD) lists of
        ``{'image': ..., 'mask': ...}`` records.
    """

    def _matches(selector: dict[str, str], disaster: str, phase: str) -> bool:
        # A record matches when the disaster name agrees and the selector
        # accepts its pre/post phase ('both' accepts either).
        return (
            disaster == selector['disaster_name']
            and selector['pre-post'] in ('both', phase)
        )

    # Previously the two branches duplicated their logic, the train branch
    # ran post->pre path rewrites that were always no-ops (pre images already
    # contain 'pre_disaster'), and test entries kept an extra 'basename' key
    # that train entries dropped. Both splits now emit identical record shapes.
    splits: dict[str, list[dict[str, str]]] = {'train': [], 'test': []}
    for file_info in files:
        disaster_name = file_info['basename'].split('_')[0]
        phase = 'pre' if 'pre_disaster' in file_info['image'] else 'post'
        record = {'image': file_info['image'], 'mask': file_info['mask']}

        if _matches(id_disaster, disaster_name, phase):
            splits['train'].append(record)
        if _matches(ood_disaster, disaster_name, phase):
            splits['test'].append(record)

    return splits