-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathcombined_dataset.py
More file actions
47 lines (38 loc) · 1.76 KB
/
combined_dataset.py
File metadata and controls
47 lines (38 loc) · 1.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from collections.abc import Sequence
from datetime import datetime
import numpy as np
from numpy.typing import NDArray
from torch.utils.data import Dataset
from .zebra_dataset import ZebraDataset
class CombinedDataset(Dataset):
    """Combine several datasets into one, singling out a named target.

    Indexing returns a tuple with one entry per non-target dataset (in the
    order they were passed in) followed by the entry from the target dataset.
    """

    def __init__(self, datasets: Sequence["ZebraDataset"], *, target: str) -> None:
        """Split ``datasets`` into the target and the remaining datasets.

        Args:
            datasets: Datasets to combine; one of them must be named ``target``.
            target: Name of the dataset to use as the target. If several
                datasets share this name, the first one is used.

        Raises:
            ValueError: If no dataset is named ``target``.
        """
        super().__init__()
        # Pass a default to next() so a missing target produces a descriptive
        # error instead of a bare, message-less StopIteration.
        selected = next((ds for ds in datasets if ds.name == target), None)
        if selected is None:
            available = [ds.name for ds in datasets]
            msg = f"No dataset named {target!r}; available names: {available}"
            raise ValueError(msg)
        self.target = selected
        self.datasets = [ds for ds in datasets if ds != self.target]

    def __len__(self) -> int:
        """Return the length of the shortest dataset, target included."""
        return min(len(ds) for ds in (*self.datasets, self.target))

    def __getitem__(self, idx: int) -> tuple[NDArray[np.float32], ...]:
        """Return item ``idx`` of every dataset as a tuple, with the target last."""
        return (*(ds[idx] for ds in self.datasets), self.target[idx])

    def date_from_index(self, idx: int) -> datetime:
        """Return the date at position ``idx`` of the target dataset's dates.

        The stored value is normalised to second precision before conversion,
        so dates held at other ``datetime64`` resolutions (for example days or
        nanoseconds) are accepted as well; any sub-second part is truncated.
        """
        np_datetime = self.target.dataset.dates[idx]
        # At second resolution, .item() yields a datetime.datetime.
        return np.datetime64(np_datetime, "s").item()

    @property
    def end_date(self) -> np.datetime64:
        """Return the end date shared by the datasets.

        Raises:
            ValueError: If the datasets do not all have the same end date.
        """
        return self._shared_date("end_date", "end")

    @property
    def start_date(self) -> np.datetime64:
        """Return the start date shared by the datasets.

        Raises:
            ValueError: If the datasets do not all have the same start date.
        """
        return self._shared_date("start_date", "start")

    def _shared_date(self, attribute: str, label: str) -> np.datetime64:
        """Return the single value of ``attribute`` common to the datasets.

        Only the non-target datasets are compared. When there are none (the
        target is the only dataset), the target's own value is returned.
        """
        sources = self.datasets or [self.target]
        unique = {getattr(ds, attribute) for ds in sources}
        if len(unique) != 1:
            msg = f"Datasets have {len(unique)} different {label} dates"
            raise ValueError(msg)
        return unique.pop()