Skip to content

Commit 3b39d1f

Browse files
committed
Merge remote-tracking branch 'origin/main' into 44-hydra
2 parents c236171 + 145f1d6 commit 3b39d1f

10 files changed

Lines changed: 366 additions & 246 deletions

File tree

notebooks/00_exploration.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@
249249
"metadata": {},
250250
"outputs": [],
251251
"source": [
252-
"from auto_cast.metrics.spatiotemporal import MSE\n",
252+
"from auto_cast.metrics.spatiotemporal_old import MSE\n",
253253
"\n",
254254
"assert trues is not None\n",
255255
"assert preds.shape == trues.shape\n",
@@ -284,7 +284,7 @@
284284
"metadata": {},
285285
"outputs": [],
286286
"source": [
287-
"from auto_cast.metrics.spatiotemporal import RMSE\n",
287+
"from auto_cast.metrics.spatiotemporal_old import RMSE\n",
288288
"\n",
289289
"assert trues is not None\n",
290290
"rmse_error = RMSE()(preds, trues, 2)"

src/auto_cast/encoders/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,9 @@ def encode_batch(
5858
encoded_inputs = self.encode(batch)
5959

6060
# Assign output fields to inputs to be encoded identically in this default impl
61-
# Use replace to avoid mutating the original batch
62-
output_batch = replace(batch, input_fields=batch.output_fields)
61+
# Create a new batch with output fields as input fields to prevent mutation
62+
output_batch = replace(batch, input_fields=batch.output_fields.clone())
63+
6364
encoded_outputs = self.encode(output_batch)
6465

6566
# Return encoded batch

src/auto_cast/metrics/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,14 @@
1-
from .spatiotemporal import MAE, MSE, NMAE, NMSE, NRMSE, RMSE, VMSE, VRMSE, LInfinity
1+
from .spatiotemporal import (
2+
MAE,
3+
MSE,
4+
NMAE,
5+
NMSE,
6+
NRMSE,
7+
RMSE,
8+
VMSE,
9+
VRMSE,
10+
LInfinity,
11+
)
212

313
__all__ = ["MAE", "MSE", "NMAE", "NMSE", "NRMSE", "RMSE", "VMSE", "VRMSE", "LInfinity"]
414

src/auto_cast/metrics/base.py

Lines changed: 0 additions & 55 deletions
This file was deleted.

0 commit comments

Comments
 (0)