Skip to content

Commit 16d6293

Browse files
committed
covering edge scenarios per Adam's comments
1 parent 50cd637 commit 16d6293

File tree

2 files changed

+27
-16
lines changed

2 files changed

+27
-16
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,4 +147,4 @@ venv.bak/
147147
dmypy.json
148148

149149
# Pyre type checker
150-
.pyre/
150+
.pyre/

torchgeo/datasets/xview.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import glob
77
import os
8-
from collections.abc import Callable
8+
from collections.abc import Callable, Sequence
99
from typing import ClassVar, Literal
1010

1111
import matplotlib.pyplot as plt
@@ -297,9 +297,9 @@ class XView2DistShift(XView2):
297297
.. versionadded:: 0.8
298298
"""
299299

300-
binary_classes: ClassVar[tuple[str, str]] = ('background', 'building')
300+
binary_classes = ('background', 'building')
301301

302-
valid_disasters: ClassVar[list[str]] = [
302+
valid_disasters = (
303303
'hurricane-harvey',
304304
'socal-fire',
305305
'hurricane-matthew',
@@ -310,18 +310,17 @@ class XView2DistShift(XView2):
310310
'hurricane-florence',
311311
'hurricane-michael',
312312
'midwest-flooding',
313-
]
313+
)
314314

315315
def __init__(
316316
self,
317317
root: str = 'data',
318318
split: Literal['train', 'test'] = 'train',
319-
id_ood_disaster: list[dict[str, str]] = [
319+
id_ood_disaster: Sequence[dict[str, str]] = (
320320
{'disaster_name': 'hurricane-matthew', 'pre-post': 'post'},
321321
{'disaster_name': 'mexico-earthquake', 'pre-post': 'post'},
322-
],
323-
transforms: Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]]
324-
| None = None,
322+
),
323+
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
325324
checksum: bool = False,
326325
) -> None:
327326
"""Initialize the XView2DistShift dataset instance.
@@ -340,14 +339,26 @@ def __init__(
340339
DatasetNotFoundError: If dataset is not found.
341340
"""
342341
assert split in ['train', 'test'], "Split must be either 'train' or 'test'."
342+
assert len(id_ood_disaster) == 2, (
343+
'id_ood_disaster must contain exactly two items'
344+
)
343345

344-
if (
345-
id_ood_disaster[0]['disaster_name'] not in self.valid_disasters
346-
or id_ood_disaster[1]['disaster_name'] not in self.valid_disasters
347-
):
348-
raise ValueError(
349-
f'Invalid disaster names. Valid options are: {", ".join(self.valid_disasters)}'
350-
)
346+
for disaster in id_ood_disaster:
347+
if 'disaster_name' not in disaster:
348+
raise ValueError(
349+
"Each disaster entry must contain a 'disaster_name' key."
350+
)
351+
if disaster['disaster_name'] not in self.valid_disasters:
352+
raise ValueError(
353+
f'Invalid disaster name: {disaster["disaster_name"]}. '
354+
f'Valid options are: {", ".join(self.valid_disasters)}'
355+
)
356+
357+
for disaster in id_ood_disaster:
358+
if 'disaster_name' not in disaster or 'pre-post' not in disaster:
359+
raise ValueError(
360+
"Each disaster entry must contain 'disaster_name' and 'pre-post' keys."
361+
)
351362

352363
self.root = root
353364
self.split = split

0 commit comments

Comments
 (0)