-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathzebra_dataset.py
More file actions
87 lines (73 loc) · 2.71 KB
/
zebra_dataset.py
File metadata and controls
87 lines (73 loc) · 2.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from pathlib import Path
from collections.abc import Sequence
import numpy as np
from anemoi.datasets.data import open_dataset
from anemoi.datasets.data.dataset import Dataset as AnemoiDataset
from cachetools import LRUCache, cachedmethod
from torch.utils.data import Dataset
from ice_station_zebra.types import ArrayCHW, ArrayTCHW, DataSpace
class ZebraDataset(Dataset):
    """Torch-style dataset that lazily wraps an Anemoi dataset.

    The underlying Anemoi dataset is only opened the first time it is needed
    (see the ``dataset`` property). Individual timesteps are exposed through
    ``__getitem__`` and batches of dated timesteps through ``get_tchw``.
    """

    def __init__(
        self,
        name: str,
        input_files: list[Path],
        *,
        start: str | None = None,
        end: str | None = None,
    ) -> None:
        """A dataset for use by Zebra

        Dataset shape is: time; variables; ensembles; position
        We reshape each time point to: variables; pos_x; pos_y

        Args:
            name: Name reported by ``name`` and attached to the opened
                Anemoi dataset.
            input_files: Paths handed to ``open_dataset`` on first access.
            start: Optional ``start`` argument forwarded to ``open_dataset``.
            end: Optional ``end`` argument forwarded to ``open_dataset``.
        """
        super().__init__()
        # Per-instance cache backing ``index_from_date`` (at most 128 entries).
        self._cache: LRUCache = LRUCache(maxsize=128)
        # Stays ``None`` until the ``dataset`` property is first accessed.
        self._dataset: AnemoiDataset | None = None
        self._end = end
        self._input_files = input_files
        self._name = name
        self._start = start

    @property
    def dataset(self) -> AnemoiDataset:
        """Load the underlying Anemoi dataset."""
        # Compare against ``None`` explicitly: the dataset supports ``len()``,
        # so a truthiness test would call ``__len__`` on every access and would
        # re-open a dataset that happens to contain zero timesteps.
        if self._dataset is None:
            self._dataset = open_dataset(
                self._input_files, start=self._start, end=self._end
            )
            self._dataset._name = self._name
        return self._dataset

    @property
    def end_date(self) -> np.datetime64:
        """Return the end date of the dataset."""
        return self.dataset.end_date

    @property
    def name(self) -> str:
        """Return the name of the dataset."""
        return self._name

    @property
    def space(self) -> DataSpace:
        """Return the data space for this dataset.

        The channel count is taken from axis 1 of the dataset shape (the
        variables axis) and the spatial shape from ``field_shape``.
        """
        return DataSpace(
            channels=self.dataset.shape[1],
            name=self.name,
            shape=self.dataset.field_shape,
        )

    @property
    def start_date(self) -> np.datetime64:
        """Return the start date of the dataset."""
        return self.dataset.start_date

    def __len__(self) -> int:
        """Return the total length of the dataset"""
        return len(self.dataset)

    def __getitem__(self, idx: int) -> ArrayCHW:
        """Return the data for a single timestep in [C, H, W] format"""
        # The raw timestep is reshaped to the ``chw`` shape of the data space.
        return self.dataset[idx].reshape(self.space.chw)

    @cachedmethod(lambda self: self._cache)
    def index_from_date(self, date: np.datetime64) -> int:
        """Return the index of a given date in the dataset.

        Results are memoised in the per-instance LRU cache, keyed on ``date``.
        """
        # ``to_index`` returns a 3-tuple; only the first element is used here.
        # NOTE(review): the meaning of the second argument (0) and of the two
        # discarded values is not visible in this file - confirm against
        # the Anemoi API.
        idx, _, _ = self.dataset.to_index(date, 0)
        return idx

    def get_tchw(self, dates: Sequence[np.datetime64]) -> ArrayTCHW:
        """Return the data for a series of timesteps in [T, C, H, W] format

        Each date is resolved to an index with ``index_from_date`` and the
        per-timestep arrays are stacked along a new leading axis, in the
        order given by ``dates``.

        Raises:
            ValueError: If ``dates`` is empty (``np.stack`` needs at least
                one array).
        """
        return np.stack(
            [self[self.index_from_date(target_date)] for target_date in dates],
            axis=0,
        )