Custom disaster-based train/test splits for xView2 dataset #2416
We decided on EuroSAT Spatial before, so why switch to XView2 Dist Shift now? Will there be any corresponding citations for these new splits? It would be nice to move more of the shared code into the XView2 base class so that the only thing that needs to change in this subclass is the URLs. How different are these datasets?
Spatial refers to the type of distribution shift revealed by the splits when they are rearranged. XView2 consists of multiple disasters, and the distribution shift is determined by the user's choice: the user can select any disaster as the training set and another as the test set, which introduces varying types of distribution shift. These shifts range from near-distribution to far-distribution, depending on how different the disasters in the splits are. Here the difference is not limited to spatial factors but also includes temporal and contextual differences. That is why Spatial would be a misleading name for XView2. One alternative would be to standardize the naming of these subset datasets with a suffix like OOD or DistShift. What do you think?
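To make the disaster-based split idea concrete, here is a minimal sketch in plain Python. It is not the code from this PR: the helper name `split_by_disaster`, the example file list, and the assumed `<disaster>_<id>_<pre|post>_disaster.png` naming pattern are all illustrative assumptions.

```python
# Hypothetical file names; the naming pattern is an assumption for this sketch.
FILES = [
    'hurricane-harvey_00000001_pre_disaster.png',
    'hurricane-harvey_00000001_post_disaster.png',
    'socal-fire_00000002_pre_disaster.png',
    'socal-fire_00000002_post_disaster.png',
]


def split_by_disaster(files, train_disaster, test_disaster, pre_post='post'):
    """Put one disaster in 'train' and another in 'test' (hypothetical helper)."""
    splits = {'train': [], 'test': []}
    for name in files:
        # 'a_b_c_d.png' -> ['a', 'b', 'c', 'd.png']
        disaster, _, phase, _ = name.split('_', 3)
        if phase != pre_post:
            continue  # keep only the requested pre- or post-disaster imagery
        if disaster == train_disaster:
            splits['train'].append(name)
        elif disaster == test_disaster:
            splits['test'].append(name)
    return splits


splits = split_by_disaster(FILES, 'hurricane-harvey', 'socal-fire')
```

Choosing two similar disasters would give a near-distribution shift, while two very different ones would give a far-distribution shift.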
They are basically the same dataset but with different splits. XView2DistShift allows users to select specific disasters for training and testing sets. Are you suggesting we curate the filenames for all disasters as HF links and dynamically load them as training or testing sets based on input? This approach would save us from
Great dataset. Highly recommend.
This is how it works:
All the existing methods are revised to make XView2DistShift work as it should. I cannot seem to find a way to prune further (unless I upstream a method or two to

If it looks good, I can go ahead with unit tests.

@adamjstewart, just to loop you in: As you may have noticed, we (cc: @calebrob6) are upstreaming some modifications to existing datasets to make them suitable for assessing models under controlled domain shifts. This unlocks a whole new research dimension in TG, enabling users to explore robustness, generalization ability, anomaly detection, novelty detection, OOD detection and more. If it reaches a certain level of maturity, I could even consider spinning it off as a standalone toolkit that also involves methods like our recent OOD detector!
Adam for the win Co-authored-by: Adam J. Stewart <[email protected]>
Co-authored-by: Adam J. Stewart <[email protected]>
Co-authored-by: Adam J. Stewart <[email protected]>
Better now. Thanks for the review @adamjstewart!
Still a lot of concerns about how id/ood arguments are handled.
class TestXView2DistShift:
    @pytest.fixture(params=['train', 'test'])
    def dataset(self, monkeypatch: MonkeyPatch, request: SubRequest) -> XView2DistShift:
        monkeypatch.setattr(
Might not even need to monkeypatch this if you remove `checksum=True`; I never bother with this.
split=split,
id_ood_disaster=[
    {'disaster_name': 'hurricane-harvey', 'pre-post': 'post'},
    {'disaster_name': 'hurricane-harvey', 'pre-post': 'post'},
Is this duplication required for some reason?
    assert set(torch.unique(x['mask']).tolist()).issubset({0, 1})  # binary mask

def test_len(self, dataset: XView2DistShift) -> None:
    assert len(dataset) > 0
Let's test a specific length to ensure it behaves as expected. Can use an if-statement if it's different for train and test.
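A possible shape for that check, sketched with a stand-in class so it runs on its own. `FakeDataset`, `check_len`, and the sizes 2 and 1 are hypothetical; the real expected lengths depend on the files under the test data directory, and in the test suite the dataset would come from the pytest fixture.

```python
class FakeDataset:
    """Stand-in for the dataset under test (hypothetical, for illustration)."""

    def __init__(self, split: str) -> None:
        self.split = split
        # Assumed sizes: two training samples and one test sample.
        self.files = ['a', 'b'] if split == 'train' else ['c']

    def __len__(self) -> int:
        return len(self.files)


def check_len(dataset: FakeDataset) -> None:
    # Assert an exact length per split instead of just `> 0`.
    if dataset.split == 'train':
        assert len(dataset) == 2
    else:
        assert len(dataset) == 1


for split in ('train', 'test'):
    check_len(FakeDataset(split))
```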
    ValueError, match="Each disaster entry must contain a 'disaster_name' key."
):
    XView2DistShift(
        root='tests/data/xview2',
Suggested change:
- root='tests/data/xview2',
+ root=os.path.join('tests', 'data', 'xview2'),
Windows
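For context on the Windows remark, a small sketch using only the standard library. `ntpath` and `posixpath` are the Windows and POSIX flavours of `os.path`, used here so both results can be shown on any platform.

```python
import ntpath
import os
import posixpath

parts = ('tests', 'data', 'xview2')

# os.path.join uses the separator of the platform the tests run on.
root = os.path.join(*parts)

# The platform-specific modules show what each OS would produce.
windows_root = ntpath.join(*parts)   # backslash-separated
posix_root = posixpath.join(*parts)  # forward-slash-separated
```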
Args:
    root: Root directory where the dataset is located.
    split: One of "train" or "test".
    id_ood_disaster: List containing in-distribution and out-of-distribution disaster names.
Needs more description. Clarify that both `disaster_name` and `pre-post` are required. Also explain what they mean; I have no idea.
)

for disaster in id_ood_disaster:
    if 'disaster_name' not in disaster:
This check is duplicated below...
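One way to avoid repeating the check is to keep all entry validation in a single helper that is called once per entry. This is only a sketch: `validate_disaster_entry` and the two-item `VALID_DISASTERS` set are hypothetical, and only the first error message is taken from the test shown in this thread.

```python
# Hypothetical subset of valid disaster names, for illustration only.
VALID_DISASTERS = {'hurricane-harvey', 'socal-fire'}


def validate_disaster_entry(entry: dict[str, str]) -> None:
    """Run every check for one entry in one place (hypothetical helper)."""
    if 'disaster_name' not in entry:
        raise ValueError("Each disaster entry must contain a 'disaster_name' key.")
    if 'pre-post' not in entry:
        # Assumed message; the PR's wording for this case is not shown.
        raise ValueError("Each disaster entry must contain a 'pre-post' key.")
    if entry['disaster_name'] not in VALID_DISASTERS:
        raise ValueError(f"Invalid disaster name: {entry['disaster_name']}")


for entry in [
    {'disaster_name': 'hurricane-harvey', 'pre-post': 'post'},
    {'disaster_name': 'socal-fire', 'pre-post': 'post'},
]:
    validate_disaster_entry(entry)
```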
    AssertionError: If *split* is invalid.
    ValueError: If a disaster name in *id_ood_disaster* is not one of the valid disasters.
Seems like many other undocumented cases where these are raised
# Split logic by disaster and pre-post type
self.split_files: dict[str, list[dict[str, str]]] = (
    self._load_split_files_by_disaster_and_type(
        self.all_files, id_ood_disaster[0], id_ood_disaster[1]
If I pass in a dict of length 100, are the later 98 arguments ignored? Seems like a dict is a bad choice for this. Why not use 4 different parameters so they can be more clearly documented and type checked?
cc: @calebrob6
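A rough sketch of what four separately typed parameters could look like. The names `DistShiftArgs`, `id_disaster`, `ood_disaster`, `id_pre_post`, and `ood_pre_post` are hypothetical and not taken from the PR; the same four fields could equally be plain constructor parameters.

```python
from dataclasses import dataclass
from typing import Literal

# Hypothetical subset of valid disaster names, for illustration only.
VALID_DISASTERS = ('hurricane-harvey', 'socal-fire')


@dataclass(frozen=True)
class DistShiftArgs:
    """Four explicit fields instead of a list of dicts (hypothetical)."""

    id_disaster: str   # in-distribution disaster, used for training
    ood_disaster: str  # out-of-distribution disaster, used for testing
    id_pre_post: Literal['pre', 'post'] = 'post'
    ood_pre_post: Literal['pre', 'post'] = 'post'

    def __post_init__(self) -> None:
        for name in (self.id_disaster, self.ood_disaster):
            if name not in VALID_DISASTERS:
                raise ValueError(f'Invalid disaster name: {name}')
        for phase in (self.id_pre_post, self.ood_pre_post):
            if phase not in ('pre', 'post'):
                raise ValueError(f"Expected 'pre' or 'post', got {phase!r}")


args = DistShiftArgs(id_disaster='hurricane-harvey', ood_disaster='socal-fire')
```

With this shape there is no way to pass extra entries that get silently ignored, and each field can be documented and type checked on its own.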
`XView2DistShift` is a subclass of `XView2` designed to modify the original train/test splits. Similar to EuroSATSpatial #2074, this class enables domain adaptation and out-of-distribution (OOD) detection experiments.

From the docstring: