Skip to content

Commit 5dfa053

Browse files
authored
Merge pull request #62 from alan-turing-institute/metrics-dev
Add metrics (#20)
2 parents 73dae21 + d20deca commit 5dfa053

6 files changed

Lines changed: 394 additions & 17 deletions

File tree

notebooks/00_exploration.ipynb

Lines changed: 67 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -72,17 +72,25 @@
7272
]
7373
},
7474
{
75-
"cell_type": "markdown",
75+
"cell_type": "code",
76+
"execution_count": null,
7677
"id": "5",
7778
"metadata": {},
79+
"outputs": [],
80+
"source": []
81+
},
82+
{
83+
"cell_type": "markdown",
84+
"id": "6",
85+
"metadata": {},
7886
"source": [
7987
"### Example batch\n"
8088
]
8189
},
8290
{
8391
"cell_type": "code",
8492
"execution_count": null,
85-
"id": "6",
93+
"id": "7",
8694
"metadata": {},
8795
"outputs": [],
8896
"source": [
@@ -94,7 +102,7 @@
94102
{
95103
"cell_type": "code",
96104
"execution_count": null,
97-
"id": "7",
105+
"id": "8",
98106
"metadata": {},
99107
"outputs": [],
100108
"source": [
@@ -126,7 +134,7 @@
126134
{
127135
"cell_type": "code",
128136
"execution_count": null,
129-
"id": "8",
137+
"id": "9",
130138
"metadata": {},
131139
"outputs": [],
132140
"source": [
@@ -135,7 +143,7 @@
135143
},
136144
{
137145
"cell_type": "markdown",
138-
"id": "9",
146+
"id": "10",
139147
"metadata": {},
140148
"source": [
141149
"### Run trainer\n"
@@ -144,21 +152,21 @@
144152
{
145153
"cell_type": "code",
146154
"execution_count": null,
147-
"id": "10",
155+
"id": "11",
148156
"metadata": {},
149157
"outputs": [],
150158
"source": [
151159
"import lightning as L\n",
152160
"\n",
153161
"device = \"mps\" # \"cpu\"\n",
154162
"# device = \"cpu\"\n",
155-
"trainer = L.Trainer(max_epochs=1, accelerator=device, log_every_n_steps=10)\n",
163+
"trainer = L.Trainer(max_epochs=10, accelerator=device, log_every_n_steps=10)\n",
156164
"trainer.fit(model, datamodule.train_dataloader(), datamodule.val_dataloader())"
157165
]
158166
},
159167
{
160168
"cell_type": "markdown",
161-
"id": "11",
169+
"id": "12",
162170
"metadata": {},
163171
"source": [
164172
"### Run the evaluation\n"
@@ -167,7 +175,7 @@
167175
{
168176
"cell_type": "code",
169177
"execution_count": null,
170-
"id": "12",
178+
"id": "13",
171179
"metadata": {},
172180
"outputs": [],
173181
"source": [
@@ -176,7 +184,7 @@
176184
},
177185
{
178186
"cell_type": "markdown",
179-
"id": "13",
187+
"id": "14",
180188
"metadata": {},
181189
"source": [
182190
"### Example rollout\n"
@@ -185,7 +193,7 @@
185193
{
186194
"cell_type": "code",
187195
"execution_count": null,
188-
"id": "14",
196+
"id": "15",
189197
"metadata": {},
190198
"outputs": [],
191199
"source": [
@@ -196,7 +204,7 @@
196204
{
197205
"cell_type": "code",
198206
"execution_count": null,
199-
"id": "15",
207+
"id": "16",
200208
"metadata": {},
201209
"outputs": [],
202210
"source": [
@@ -209,7 +217,7 @@
209217
{
210218
"cell_type": "code",
211219
"execution_count": null,
212-
"id": "16",
220+
"id": "17",
213221
"metadata": {},
214222
"outputs": [],
215223
"source": [
@@ -220,7 +228,18 @@
220228
{
221229
"cell_type": "code",
222230
"execution_count": null,
223-
"id": "17",
231+
"id": "18",
232+
"metadata": {},
233+
"outputs": [],
234+
"source": [
235+
"assert preds.shape == trues.shape\n",
236+
"mse_error = MSE()(preds, trues, trues)"
237+
]
238+
},
239+
{
240+
"cell_type": "code",
241+
"execution_count": null,
242+
"id": "19",
224243
"metadata": {},
225244
"outputs": [],
226245
"source": [
@@ -230,18 +249,49 @@
230249
{
231250
"cell_type": "code",
232251
"execution_count": null,
233-
"id": "18",
252+
"id": "20",
234253
"metadata": {},
235254
"outputs": [],
236255
"source": [
237256
"assert trues is not None\n",
238257
"print(trues.shape)\n"
239258
]
259+
},
260+
{
261+
"cell_type": "code",
262+
"execution_count": null,
263+
"id": "21",
264+
"metadata": {},
265+
"outputs": [],
266+
"source": [
267+
"from the_well.benchmark.metrics import RMSE, MAE, MSE\n",
268+
"from the_well.data.datasets import WellMetadata\n",
269+
"\n",
270+
"rmse_error = RMSE.eval(preds, trues, WellMetadata)"
271+
]
272+
},
273+
{
274+
"cell_type": "code",
275+
"execution_count": null,
276+
"id": "22",
277+
"metadata": {},
278+
"outputs": [],
279+
"source": [
280+
"WellMetadata.n_spatial_dims"
281+
]
282+
},
283+
{
284+
"cell_type": "code",
285+
"execution_count": null,
286+
"id": "23",
287+
"metadata": {},
288+
"outputs": [],
289+
"source": []
240290
}
241291
],
242292
"metadata": {
243293
"kernelspec": {
244-
"display_name": ".venv",
294+
"display_name": "auto-cast",
245295
"language": "python",
246296
"name": "python3"
247297
},
@@ -255,7 +305,7 @@
255305
"name": "python",
256306
"nbconvert_exporter": "python",
257307
"pygments_lexer": "ipython3",
258-
"version": "3.12.12"
308+
"version": "3.12.9"
259309
}
260310
},
261311
"nbformat": 4,

src/auto_cast/metrics/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Public entry point for the metrics package: re-exports the spatiotemporal metrics."""

from .spatiotemporal import MAE, MSE, NMAE, NMSE, NRMSE, RMSE, VMSE, VRMSE, LInfinity

# Names exposed by ``from auto_cast.metrics import *``.
__all__ = ["MAE", "MSE", "NMAE", "NMSE", "NRMSE", "RMSE", "VMSE", "VRMSE", "LInfinity"]

# Convenience collection of every metric class, e.g. for iterating in an
# evaluation loop.  (Not part of ``__all__``; import it explicitly.)
ALL_METRICS = (MSE, MAE, NMAE, NMSE, RMSE, NRMSE, VMSE, VRMSE, LInfinity)

src/auto_cast/metrics/base.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import numpy as np
2+
import torch
3+
from torch import nn
4+
5+
from auto_cast.types import TensorBCTSPlus
6+
7+
8+
class Metric(nn.Module):
    """
    Base class for metrics.

    Calling an instance goes through :meth:`forward`, which standardizes
    the input arguments (converting numpy arrays to torch tensors) and
    checks the dimensions of the input tensors before delegating to the
    subclass-provided :meth:`score`.

    Positional call arguments:
        y_pred: torch.Tensor | np.ndarray
            Predicted values tensor.
        y_true: torch.Tensor | np.ndarray
            Target values tensor.
        n_spatial_dims: int
            Number of spatial dimensions in the inputs.
        **kwargs : dict
            Additional arguments forwarded to :meth:`score`.
    """

    def forward(self, *args, **kwargs):
        # BUG FIX: the original checked ``len(args) >= 2`` (and its message
        # said "two") yet unpacked three positionals below, so calling with
        # exactly two arguments raised an opaque ValueError from tuple
        # unpacking instead of this explicit assertion.
        assert len(args) >= 3, (
            "At least three arguments required (y_pred, y_true, n_spatial_dims)"
        )
        y_pred, y_true, n_spatial_dims = args[:3]

        # Convert y_pred and y_true to torch.Tensor if they are np.ndarray.
        if isinstance(y_pred, np.ndarray):
            y_pred = torch.from_numpy(y_pred)
        if isinstance(y_true, np.ndarray):
            y_true = torch.from_numpy(y_true)
        assert isinstance(y_pred, torch.Tensor), (
            "y_pred must be a torch.Tensor or np.ndarray"
        )
        assert isinstance(y_true, torch.Tensor), (
            "y_true must be a torch.Tensor or np.ndarray"
        )

        # Each input needs the spatial dims plus at least one further
        # leading dimension (e.g. channel or batch).
        assert y_pred.ndim >= n_spatial_dims + 1, (
            "y_pred must have at least n_spatial_dims + 1 dimensions"
        )
        assert y_true.ndim >= n_spatial_dims + 1, (
            "y_true must have at least n_spatial_dims + 1 dimensions"
        )
        return self.score(y_pred, y_true, n_spatial_dims, **kwargs)

    @staticmethod
    def score(
        y_pred: "TensorBCTSPlus",
        y_true: "TensorBCTSPlus",
        n_spatial_dims: int,
        **kwargs,
    ):
        """Compute the metric value. Must be overridden by subclasses."""
        raise NotImplementedError

0 commit comments

Comments
 (0)