Skip to content

Commit 7e9a288

Browse files
committed
Fix pyright lints, update test, fix for 1D
1 parent 638b446 commit 7e9a288

File tree

4 files changed

+13
-14
lines changed

4 files changed

+13
-14
lines changed

autoemulate/core/metrics.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,12 @@ def __call__(
8383
y_pred = y_pred.rsample((n_samples,)).mean(dim=0)
8484
metric = self.metric()
8585
metric.to(y_pred.device)
86-
# Assume first dim is a batch dim, flatten others for metric calculation
87-
metric.update(y_pred.flatten(start_dim=1), y_true.flatten(start_dim=1))
86+
87+
# Assume first dim is a batch dim if >=2D, flatten others for metric calculation
88+
metric.update(
89+
y_pred.flatten(start_dim=1) if y_pred.ndim > 1 else y_pred,
90+
y_true.flatten(start_dim=1) if y_true.ndim > 1 else y_true,
91+
)
8892
return metric.compute()
8993

9094

autoemulate/core/model_selection.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
get_torch_device,
1111
move_tensors_to_device,
1212
)
13-
from autoemulate.core.metrics import R2, Metric, TorchMetrics, get_metric_configs
13+
from autoemulate.core.metrics import R2, Metric, get_metric_configs
1414
from autoemulate.core.types import (
1515
DeviceLike,
1616
ModelParams,
@@ -63,7 +63,7 @@ def cross_validate(
6363
y_transforms: list[Transform] | None = None,
6464
device: DeviceLike = "cpu",
6565
random_seed: int | None = None,
66-
metrics: list[TorchMetrics] | None = None,
66+
metrics: list[Metric] | None = None,
6767
):
6868
"""
6969
Cross validate model performance using the given `cv` strategy.
@@ -158,7 +158,7 @@ def bootstrap(
158158
n_bootstraps: int | None = 100,
159159
n_samples: int = 1000,
160160
device: str | torch.device = "cpu",
161-
metrics: list[TorchMetrics] | None = None,
161+
metrics: list[Metric] | None = None,
162162
) -> dict[str, tuple[float, float]]:
163163
"""
164164
Get bootstrap estimates of metrics.

autoemulate/core/tuner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from torch.distributions import Transform
77

88
from autoemulate.core.device import TorchDeviceMixin
9-
from autoemulate.core.metrics import TorchMetrics, get_metric_config
9+
from autoemulate.core.metrics import Metric, get_metric_config
1010
from autoemulate.core.model_selection import cross_validate
1111
from autoemulate.core.types import (
1212
DeviceLike,
@@ -46,7 +46,7 @@ def __init__(
4646
n_iter: int = 10,
4747
device: DeviceLike | None = None,
4848
random_seed: int | None = None,
49-
tuning_metric: str | TorchMetrics = "r2",
49+
tuning_metric: str | Metric = "r2",
5050
):
5151
TorchDeviceMixin.__init__(self, device=device)
5252
self.n_iter = n_iter

tests/core/test_metrics.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -396,18 +396,13 @@ def test_metric_with_multidimensional_tensors():
396396
def test_metric_configs_workflow():
397397
"""Test complete workflow of getting and using metric configs."""
398398
# Get configs from strings
399-
configs = get_metric_configs(["r2", "rmse"])
399+
metrics = get_metric_configs(["r2", "rmse"])
400400

401401
# Use configs to compute metrics
402402
y_pred = torch.tensor([1.0, 2.0, 3.0])
403403
y_true = torch.tensor([1.0, 2.0, 3.0])
404404

405-
results = {}
406-
for config in configs:
407-
metric = config.metric()
408-
metric.update(y_pred, y_true)
409-
results[config.name] = metric.compute()
410-
405+
results = {metric.name: metric(y_pred, y_true) for metric in metrics}
411406
assert "r2" in results
412407
assert "rmse" in results
413408
assert torch.isclose(results["r2"], torch.tensor(1.0)) # Perfect R2

0 commit comments

Comments
 (0)