-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtypes.py
More file actions
95 lines (75 loc) · 2.38 KB
/
types.py
File metadata and controls
95 lines (75 loc) · 2.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from collections.abc import Iterator, Mapping, Sequence
from dataclasses import dataclass
from typing import Any, Self, TypedDict
from jaxtyping import Float
from numpy import float32
from numpy.typing import NDArray
from omegaconf import DictConfig
from torch import Tensor
# Shape-annotated aliases built with jaxtyping's ``Float``; the string in each
# alias names the axes in order.
# NumPy float32 array laid out as (channels, height, width).
ArrayCHW = Float[NDArray[float32], "channels height width"]
# NumPy float32 array with a leading time axis: (time, channels, height, width).
ArrayTCHW = Float[NDArray[float32], "time channels height width"]
# torch Tensor laid out as (batch, channels, height, width).
TensorNCHW = Float[Tensor, "batch channels height width"]
# torch Tensor with a time axis: (batch, time, channels, height, width).
TensorNTCHW = Float[Tensor, "batch time channels height width"]
@dataclass
class AnemoiCreateArgs:
    """Argument bundle for an Anemoi ``create`` invocation.

    The purpose is inferred from the class name; confirm against the caller.
    Only ``config`` and ``path`` must be supplied; every other field has a
    default.
    """

    # Configuration object (presumably the dataset recipe -- TODO confirm).
    config: DictConfig
    # Presumably the dataset output location -- TODO confirm.
    path: str
    # Placeholder value; nothing in this module reads it.
    command: str = "unused"
    overwrite: bool = False
    processes: int = 0
    threads: int = 0
@dataclass
class AnemoiInspectArgs:
    """Argument bundle for an Anemoi ``inspect`` invocation.

    The purpose is inferred from the class name; confirm against the caller.
    All fields are required and positional order matches declaration order.
    """

    # Boolean switches and the dataset path, in constructor order.
    detailed: bool
    path: str
    progress: bool
    size: bool
    statistics: bool
class DataloaderArgs(TypedDict):
    """Typed keyword-argument dict for building a data loader.

    Presumably unpacked into ``torch.utils.data.DataLoader``; confirm against
    the caller. The ``None``-annotated keys must always hold ``None``.
    """

    batch_size: int
    # The three keys below are pinned to None by their annotation.
    sampler: None
    batch_sampler: None
    drop_last: bool
    worker_init_fn: None
class DataSpace:
channels: int
name: str
shape: tuple[int, int]
def __init__(self, channels: int, name: str, shape: Sequence[int]) -> None:
self.channels = int(channels)
self.name = name
self.shape = (int(shape[0]), int(shape[1]))
@property
def chw(self) -> tuple[int, int, int]:
"""Return a tuple of [channels, height, width]."""
return (self.channels, *self.shape)
@classmethod
def from_dict(cls, config: DictConfig | dict[str, Any]) -> Self:
return cls(
channels=config["channels"], name=config["name"], shape=config["shape"]
)
def to_dict(self) -> DictConfig:
"""Return the DataSpace as a DictConfig."""
return DictConfig(
{"channels": self.channels, "name": self.name, "shape": self.shape}
)
@dataclass
class ModelTestOutput(Mapping[str, Tensor]):
    """Result of a single model test step.

    Besides attribute access, instances act as a read-only mapping whose keys
    are ``"prediction"``, ``"target"`` and ``"loss"`` (in that order). The
    ``Mapping`` base then supplies ``keys``/``items``/``get``/``in``, so
    ``dict(output)`` and ``**output`` both work.
    """

    # (batch, time, channels, height, width), per the TensorNTCHW alias.
    prediction: TensorNTCHW
    # Same layout as ``prediction``.
    target: TensorNTCHW
    # Loss for the step (presumably a scalar tensor -- TODO confirm).
    loss: Tensor

    def __getitem__(self, key: str) -> Tensor:
        # Check each known key in declaration order. Unknown keys raise
        # KeyError, which Mapping.__contains__ and Mapping.get rely on.
        for field_name in ("prediction", "target", "loss"):
            if key == field_name:
                return getattr(self, field_name)
        raise KeyError(f"Key {key} not found in ModelTestOutput")

    def __iter__(self) -> Iterator[str]:
        yield from ("prediction", "target", "loss")

    def __len__(self) -> int:
        # Number of keys produced by __iter__.
        return 3