Custom disaster-based train/test splits for xView2 dataset #2416


Status: Open. Wants to merge 41 commits into main.

Conversation

burakekim
Contributor

@burakekim burakekim commented Nov 18, 2024

cc: @calebrob6

XView2DistShift is a subclass of XView2 designed to modify the original train/test splits. Similar to EuroSATSpatial #2074, this class enables domain adaptation and out-of-distribution (OOD) detection experiments.

From the docstring:

This class allows for the selection of particular disasters to be used as the training set (in-domain) and test set (out-of-domain). The dataset can be split according to the disaster names specified by the user, enabling the model to train on one disaster type and evaluate on a different, out-of-domain disaster. The goal is to test the generalization ability of models trained on one disaster to perform on another.
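In spirit, the split logic described above amounts to partitioning the file list by disaster name. A minimal sketch of the idea (the record structure and helper name are assumptions for illustration, not the actual implementation):

```python
# Hypothetical sketch of disaster-based ID/OOD splitting; the record
# structure and helper name are assumptions, not TorchGeo's actual code.
def split_by_disaster(
    files: list[dict[str, str]], id_disaster: str, ood_disaster: str
) -> tuple[list[dict[str, str]], list[dict[str, str]]]:
    """Partition records into in-domain (train) and out-of-domain (test) sets."""
    train = [f for f in files if f['disaster'] == id_disaster]
    test = [f for f in files if f['disaster'] == ood_disaster]
    return train, test

files = [
    {'disaster': 'hurricane-matthew', 'image': 'a.png'},
    {'disaster': 'mexico-earthquake', 'image': 'b.png'},
    {'disaster': 'hurricane-matthew', 'image': 'c.png'},
]
train, test = split_by_disaster(files, 'hurricane-matthew', 'mexico-earthquake')
print(len(train), len(test))  # → 2 1
```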

@github-actions github-actions bot added documentation Improvements or additions to documentation datasets Geospatial or benchmark datasets testing Continuous integration testing labels Nov 18, 2024
@adamjstewart
Collaborator

We decided on EuroSAT Spatial before, why switch to XView2 Dist Shift now? Will there be any corresponding citations for these new splits?

It would be nice to move more of the shared code into the XView2 base class so that the only thing that needs to change in this subclass is the URLs. How different are these datasets?

@adamjstewart adamjstewart modified the milestones: 0.6.2, 0.7.0 Nov 18, 2024
@burakekim
Contributor Author

burakekim commented Nov 18, 2024

We decided on EuroSAT Spatial before, why switch to XView2 Dist Shift now?

Spatial refers to the type of distribution shift revealed by the splits when they are rearranged. XView2 consists of multiple disasters, and the distribution shift is determined by the user's choice: the user can select any disaster as the training set and another as the test set, which introduces varying types of distribution shift. These shifts range from near- to far-distribution shifts, depending on how different the disasters in the two splits are. And here, the difference is not limited to spatial factors; it also includes temporal and contextual differences. That is why Spatial would be a misleading name for XView2.

One alternative could be standardizing the naming for these subset datasets with a suffix like OOD or DistShift. What do you think?

It would be nice to move more of the shared code in the XView2 base class so that the only thing that needs to be changed in this subclass is the URLs. How different are these datasets?

They are basically the same dataset but with different splits. XView2DistShift allows users to select specific disasters for training and testing sets.

Are you suggesting we curate the filenames for all disasters as HF links and dynamically load them as training or testing sets based on input? This approach would save us from _initialize_files and _load_split_files_by_disaster_and_type, but not __getitem__ and __len__.
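For illustration, a filename-based filter along those lines might look like the following. This is a sketch, and it assumes the xView2 naming convention `<disaster>_<tile>_<pre|post>_disaster.png`; the helper name is made up:

```python
# Sketch of filtering filenames by disaster and pre/post tag; assumes the
# xView2 convention '<disaster>_<tile>_<pre|post>_disaster.png'.
def filter_files(filenames: list[str], disaster: str, prepost: str) -> list[str]:
    return [
        name
        for name in filenames
        if name.startswith(disaster + '_') and f'_{prepost}_disaster' in name
    ]

names = [
    'hurricane-harvey_00000001_pre_disaster.png',
    'hurricane-harvey_00000001_post_disaster.png',
    'mexico-earthquake_00000002_post_disaster.png',
]
print(filter_files(names, 'hurricane-harvey', 'post'))
# → ['hurricane-harvey_00000001_post_disaster.png']
```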

@calebrob6
Member

Great dataset, highly recommend.

@burakekim
Contributor Author

burakekim commented Feb 1, 2025

This is how it works:


```python
id_ood_disaster = [
    {'disaster_name': 'hurricane-matthew', 'pre-post': 'post'},
    {'disaster_name': 'mexico-earthquake', 'pre-post': 'post'},
]

xview2 = XView2DistShift(
    root=root,
    split='test',
    id_ood_disaster=id_ood_disaster,
)
```

> ID sample len: 311, OOD sample len: 159

All the existing methods are revised to make XView2DistShift work as intended. I cannot see a way to prune further (unless I upstream a method or two to XView2, but that would be too much refactoring).

If it looks good, I can go ahead with unit tests.

@adamjstewart, just to loop you in: As you may have noticed, we (cc: @calebrob6) are upstreaming some modifications to existing datasets to make them suitable for assessing models under controlled domain shifts. This unlocks a whole new research dimension in TG, enabling users to explore robustness, generalization ability, anomaly detection, novelty detection, OOD detection and more. If it reaches a certain level of maturity, I could even consider spinning it off as a standalone toolkit that also involves methods like our recent OOD detector!

@adamjstewart adamjstewart removed this from the 0.7.0 milestone Mar 23, 2025
@github-actions github-actions bot removed the testing Continuous integration testing label Apr 19, 2025
@adamjstewart adamjstewart added this to the 0.8.0 milestone Apr 20, 2025
@burakekim
Contributor Author

Better now. Thanks for the review @adamjstewart!

@burakekim burakekim requested a review from adamjstewart April 22, 2025 10:38
Collaborator

@adamjstewart adamjstewart left a comment


Still a lot of concerns about how id/ood arguments are handled.

class TestXView2DistShift:
    @pytest.fixture(params=['train', 'test'])
    def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> XView2DistShift:
        monkeypatch.setattr(

Might not even need to monkeypatch this if you remove checksum=True; I never bother with this.

    split=split,
    id_ood_disaster=[
        {'disaster_name': 'hurricane-harvey', 'pre-post': 'post'},
        {'disaster_name': 'hurricane-harvey', 'pre-post': 'post'},

Is this duplication required for some reason?

assert set(torch.unique(x['mask']).tolist()).issubset({0, 1}) # binary mask

def test_len(self, dataset: XView2DistShift) -> None:
    assert len(dataset) > 0

Let's test a specific length to ensure it behaves as expected. Can use an if-statement if it's different for train and test.
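The suggestion could be sketched like this; `FakeDistShift` and the counts are stand-ins for the real `XView2DistShift` fixture and its fake test data, not the actual values:

```python
# Sketch of an exact-length test; FakeDistShift and the counts (2 and 1)
# are placeholders for the real fixture and its fake test data.
class FakeDistShift:
    def __init__(self, split: str) -> None:
        self.split = split
        # Pretend the fake data has 2 ID (train) and 1 OOD (test) samples.
        self._files = ['a', 'b'] if split == 'train' else ['c']

    def __len__(self) -> int:
        return len(self._files)

def check_len(dataset: FakeDistShift) -> None:
    # Use an if-expression since the expected length differs per split.
    expected = 2 if dataset.split == 'train' else 1
    assert len(dataset) == expected

for split in ('train', 'test'):
    check_len(FakeDistShift(split))
print('ok')  # → ok
```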

ValueError, match="Each disaster entry must contain a 'disaster_name' key."
):
XView2DistShift(
root='tests/data/xview2',

Suggested change:
- root='tests/data/xview2',
+ root=os.path.join('tests', 'data', 'xview2'),

Windows

Args:
    root: Root directory where the dataset is located.
    split: One of "train" or "test".
    id_ood_disaster: List containing in-distribution and out-of-distribution disaster names.

Needs more description. Clarify that both disaster_name and pre-post are required. Also explain what they mean; I have no idea.

)

for disaster in id_ood_disaster:
    if 'disaster_name' not in disaster:

This check is duplicated below...

Comment on lines +337 to +338
AssertionError: If *split* is invalid.
ValueError: If a disaster name in *id_ood_disaster* is not one of the valid disasters.

Seems like many other undocumented cases where these are raised

# Split logic by disaster and pre-post type
self.split_files: dict[str, list[dict[str, str]]] = (
    self._load_split_files_by_disaster_and_type(
        self.all_files, id_ood_disaster[0], id_ood_disaster[1]

If I pass in a list of length 100, are the later 98 entries ignored? A dict seems like a bad choice for this. Why not use 4 different parameters so they can be more clearly documented and type checked?
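The alternative being suggested might look like this sketch; the function and parameter names are assumptions, not the reviewer's exact proposal:

```python
# Sketch of the reviewer's alternative: four explicit, type-checked
# parameters instead of a list of dicts (names are assumptions).
from typing import Literal

def make_split(
    id_disaster: str,
    id_prepost: Literal['pre', 'post'],
    ood_disaster: str,
    ood_prepost: Literal['pre', 'post'],
) -> tuple[dict[str, str], dict[str, str]]:
    """Build the ID and OOD disaster specs from explicit arguments."""
    return (
        {'disaster_name': id_disaster, 'pre-post': id_prepost},
        {'disaster_name': ood_disaster, 'pre-post': ood_prepost},
    )

print(make_split('hurricane-matthew', 'post', 'mexico-earthquake', 'post'))
```

With separate parameters, extra entries cannot be silently ignored and a type checker can flag an invalid 'pre-post' value.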
