5
5
6
6
import glob
7
7
import os
8
- from collections .abc import Callable
8
+ from collections .abc import Callable , Sequence
9
9
from typing import ClassVar , Literal
10
10
11
11
import matplotlib .pyplot as plt
@@ -297,9 +297,9 @@ class XView2DistShift(XView2):
297
297
.. versionadded:: 0.8
298
298
"""
299
299
300
- binary_classes : ClassVar [ tuple [ str , str ]] = ('background' , 'building' )
300
+ binary_classes = ('background' , 'building' )
301
301
302
- valid_disasters : ClassVar [ list [ str ]] = [
302
+ valid_disasters = (
303
303
'hurricane-harvey' ,
304
304
'socal-fire' ,
305
305
'hurricane-matthew' ,
@@ -310,18 +310,17 @@ class XView2DistShift(XView2):
310
310
'hurricane-florence' ,
311
311
'hurricane-michael' ,
312
312
'midwest-flooding' ,
313
- ]
313
+ )
314
314
315
315
def __init__ (
316
316
self ,
317
317
root : str = 'data' ,
318
318
split : Literal ['train' , 'test' ] = 'train' ,
319
- id_ood_disaster : list [dict [str , str ]] = [
319
+ id_ood_disaster : Sequence [dict [str , str ]] = (
320
320
{'disaster_name' : 'hurricane-matthew' , 'pre-post' : 'post' },
321
321
{'disaster_name' : 'mexico-earthquake' , 'pre-post' : 'post' },
322
- ],
323
- transforms : Callable [[dict [str , torch .Tensor ]], dict [str , torch .Tensor ]]
324
- | None = None ,
322
+ ),
323
+ transforms : Callable [[dict [str , Tensor ]], dict [str , Tensor ]] | None = None ,
325
324
checksum : bool = False ,
326
325
) -> None :
327
326
"""Initialize the XView2DistShift dataset instance.
@@ -340,14 +339,26 @@ def __init__(
340
339
DatasetNotFoundError: If dataset is not found.
341
340
"""
342
341
assert split in ['train' , 'test' ], "Split must be either 'train' or 'test'."
342
+ assert len (id_ood_disaster ) == 2 , (
343
+ 'id_ood_disaster must contain exactly two items'
344
+ )
343
345
344
- if (
345
- id_ood_disaster [0 ]['disaster_name' ] not in self .valid_disasters
346
- or id_ood_disaster [1 ]['disaster_name' ] not in self .valid_disasters
347
- ):
348
- raise ValueError (
349
- f'Invalid disaster names. Valid options are: { ", " .join (self .valid_disasters )} '
350
- )
346
+ for disaster in id_ood_disaster :
347
+ if 'disaster_name' not in disaster :
348
+ raise ValueError (
349
+ "Each disaster entry must contain a 'disaster_name' key."
350
+ )
351
+ if disaster ['disaster_name' ] not in self .valid_disasters :
352
+ raise ValueError (
353
+ f'Invalid disaster name: { disaster ["disaster_name" ]} . '
354
+ f'Valid options are: { ", " .join (self .valid_disasters )} '
355
+ )
356
+
357
+ for disaster in id_ood_disaster :
358
+ if 'disaster_name' not in disaster or 'pre-post' not in disaster :
359
+ raise ValueError (
360
+ "Each disaster entry must contain 'disaster_name' and 'pre-post' keys."
361
+ )
351
362
352
363
self .root = root
353
364
self .split = split
0 commit comments