-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathcombined_dataset.py
More file actions
119 lines (100 loc) · 4.62 KB
/
combined_dataset.py
File metadata and controls
119 lines (100 loc) · 4.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from collections.abc import Sequence
import numpy as np
from torch.utils.data import Dataset
from icenet_mp.types import ArrayTCHW
from .single_dataset import SingleDataset
class CombinedDataset(Dataset):
    """Combine several SingleDatasets into paired history/forecast windows.

    Each item pairs ``n_history_steps`` consecutive time steps from every input
    dataset with the ``n_forecast_steps`` time steps of the target dataset that
    immediately follow the history window.
    """

    def __init__(
        self,
        datasets: Sequence[SingleDataset],
        target_group_name: str,
        target_variables: Sequence[str],
        *,
        n_forecast_steps: int = 1,
        n_history_steps: int = 1,
    ) -> None:
        """Initialise a combined dataset from a sequence of SingleDatasets.

        One of the datasets must be the target and all must have the same frequency. The
        number of forecast and history steps can be set, which will determine the shape
        of the NTCHW tensors returned by __getitem__.

        Args:
            datasets: The datasets to combine; all of them are used as inputs.
            target_group_name: Name of the dataset that the target is derived from.
            target_variables: Variables to keep in the target subset.
            n_forecast_steps: Number of target time steps returned per item.
            n_history_steps: Number of input time steps returned per item.

        Raises:
            ValueError: If no dataset is called ``target_group_name`` or if the
                datasets do not all share a single frequency.

        """
        super().__init__()

        # Store the number of forecast and history steps
        self.n_forecast_steps = n_forecast_steps
        self.n_history_steps = n_history_steps

        self.inputs = list(datasets)

        # Look up the target explicitly: a bare next() would leak StopIteration,
        # which is confusing here and turns into RuntimeError inside generators.
        target_source = next(
            (ds for ds in self.inputs if ds.name == target_group_name), None
        )
        if target_source is None:
            available_names = [ds.name for ds in self.inputs]
            msg = (
                f"No dataset named '{target_group_name}' was found; "
                f"available datasets: {available_names}."
            )
            raise ValueError(msg)

        # Create a new dataset for the target with only the selected variables
        self.target = target_source.subset(variables=target_variables)

        # Require that all datasets have the same frequency
        frequencies = sorted({ds.frequency for ds in self.inputs})  # type: ignore[type-var]
        if len(frequencies) != 1:
            msg = f"Cannot combine datasets with different frequencies: {frequencies}."
            raise ValueError(msg)
        self.frequency = frequencies[0]

        # Lazy-load dates on first request
        self._available_dates: list[np.datetime64] | None = None

    @property
    def dates(self) -> list[np.datetime64]:
        """Get the sorted list of window start dates that are valid in all datasets.

        A date is valid when every history step starting at it exists in all input
        datasets and every subsequent forecast step exists in the target.

        Raises:
            ValueError: If there is no valid date. This is raised on every access,
                not only on the first one.

        """
        if self._available_dates is None:
            # Identify dates that exist in all input datasets
            input_date_set = set.intersection(*(set(ds.dates) for ds in self.inputs))
            target_date_set = set(self.target.dates)
            available_dates = sorted(
                available_date
                for available_date in input_date_set
                # ... if all inputs have n_history_steps starting on this date
                if all(
                    date in input_date_set
                    for date in self.get_history_steps(available_date)
                )
                # ... and if the target has n_forecast_steps after the history dates
                and all(
                    date in target_date_set
                    for date in self.get_forecast_steps(available_date)
                )
            )
            if not available_dates:
                msg = (
                    "CombinedDataset has no valid dates. This can happen when there "
                    "are no valid windows given the configured history/forecast steps or "
                    "when the input datasets do not have overlapping time ranges."
                )
                raise ValueError(msg)
            # Cache only after validation so that an empty result is never stored
            # and silently returned by a later access.
            self._available_dates = available_dates
        return self._available_dates

    @property
    def end_date(self) -> np.datetime64:
        """Return the end date of the dataset (the last valid window start)."""
        return self.dates[-1]

    @property
    def start_date(self) -> np.datetime64:
        """Return the start date of the dataset (the first valid window start)."""
        return self.dates[0]

    def __len__(self) -> int:
        """Return the total length of the dataset (the number of valid dates)."""
        return len(self.dates)

    def __getitem__(self, idx: int) -> dict[str, ArrayTCHW]:
        """Return the data for a single timestep as a dictionary.

        Args:
            idx: Position of the window start date within ``self.dates``.

        Returns:
            A dictionary with dataset names as keys and an array as the value, plus
            a "target" key holding the target data. The shape of each array is:
            - input datasets: [n_history_steps, C_input_k, H_input_k, W_input_k]
            - target dataset: [n_forecast_steps, C_target, H_target, W_target]

        """
        start_date = self.dates[idx]
        history_steps = self.get_history_steps(start_date)
        forecast_steps = self.get_forecast_steps(start_date)
        # The "target" entry is merged last, so it takes precedence over any
        # input dataset that happens to use the same name.
        return {ds.name: ds.get_tchw(history_steps) for ds in self.inputs} | {
            "target": self.target.get_tchw(forecast_steps)
        }

    def get_forecast_steps(self, start_date: np.datetime64) -> list[np.datetime64]:
        """Return list of consecutive forecast dates for a given start date.

        The first forecast date lies ``n_history_steps`` frequency intervals after
        ``start_date``, i.e. directly after the last history date.
        """
        return [
            start_date + (idx + self.n_history_steps) * self.frequency
            for idx in range(self.n_forecast_steps)
        ]

    def get_history_steps(self, start_date: np.datetime64) -> list[np.datetime64]:
        """Return list of consecutive history dates for a given start date.

        The first history date is ``start_date`` itself.
        """
        return [
            start_date + idx * self.frequency for idx in range(self.n_history_steps)
        ]