Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 20 additions & 26 deletions autoemulate/core/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,12 @@ def __call__(
y_pred = y_pred.rsample((n_samples,)).mean(dim=0)
metric = self.metric()
metric.to(y_pred.device)
# Assume first dim is a batch dim, flatten others for metric calculation
metric.update(y_pred.flatten(start_dim=1), y_true.flatten(start_dim=1))

# Assume first dim is a batch dim if >=2D, flatten others for metric calculation
metric.update(
y_pred.flatten(start_dim=1) if y_pred.ndim > 1 else y_pred,
y_true.flatten(start_dim=1) if y_true.ndim > 1 else y_true,
)
return metric.compute()


Expand Down Expand Up @@ -244,32 +248,30 @@ def __call__(
}


def get_metric_config(
metric: str | TorchMetrics,
) -> TorchMetrics:
def get_metric_config(metric: str | Metric) -> Metric:
"""Convert various metric specifications to MetricConfig.

Parameters
----------
metric : str | type[torchmetrics.Metric] | partial[torchmetrics.Metric] | Metric
metric : str | Metric
The metric specification. Can be:
- A string shortcut like "r2", "rmse", "mse", "mae"
- A Metric instance (returned as-is)

Returns
-------
TorchMetrics
The metric configuration.
Metric
The metric.

Raises
------
ValueError
If the metric specification is invalid or name is not provided when required.

If the metric specification is not a string (and registered in
AVAILABLE_METRICS) or Metric instance.

"""
# If already a TorchMetric, return as-is
if isinstance(metric, TorchMetrics):
# If already a Metric, return as-is
if isinstance(metric, Metric):
return metric

if isinstance(metric, str):
Expand All @@ -286,25 +288,17 @@ def get_metric_config(
)


def get_metric_configs(
metrics: Sequence[str | TorchMetrics],
) -> list[TorchMetrics]:
"""Convert a list of metric specifications to MetricConfig objects.
def get_metric_configs(metrics: Sequence[str | Metric]) -> list[Metric]:
"""Convert a list of metric specifications to Metric objects.

Parameters
----------
metrics : Sequence[str | TorchMetrics]
metrics : Sequence[str | Metrics]
Sequence of metric specifications.

Returns
-------
list[TorchMetrics]
List of metric configurations.
list[Metric]
List of metrics.
"""
result_metrics = []

for m in metrics:
config = get_metric_config(m) if isinstance(m, (str | TorchMetrics)) else m
result_metrics.append(config)

return result_metrics
return [get_metric_config(m) for m in metrics]
6 changes: 3 additions & 3 deletions autoemulate/core/model_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
get_torch_device,
move_tensors_to_device,
)
from autoemulate.core.metrics import R2, Metric, TorchMetrics, get_metric_configs
from autoemulate.core.metrics import R2, Metric, get_metric_configs
from autoemulate.core.types import (
DeviceLike,
ModelParams,
Expand Down Expand Up @@ -63,7 +63,7 @@ def cross_validate(
y_transforms: list[Transform] | None = None,
device: DeviceLike = "cpu",
random_seed: int | None = None,
metrics: list[TorchMetrics] | None = None,
metrics: list[Metric] | None = None,
):
"""
Cross validate model performance using the given `cv` strategy.
Expand Down Expand Up @@ -158,7 +158,7 @@ def bootstrap(
n_bootstraps: int | None = 100,
n_samples: int = 1000,
device: str | torch.device = "cpu",
metrics: list[TorchMetrics] | None = None,
metrics: list[Metric] | None = None,
) -> dict[str, tuple[float, float]]:
"""
Get bootstrap estimates of metrics.
Expand Down
4 changes: 2 additions & 2 deletions autoemulate/core/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torch.distributions import Transform

from autoemulate.core.device import TorchDeviceMixin
from autoemulate.core.metrics import TorchMetrics, get_metric_config
from autoemulate.core.metrics import Metric, get_metric_config
from autoemulate.core.model_selection import cross_validate
from autoemulate.core.types import (
DeviceLike,
Expand Down Expand Up @@ -46,7 +46,7 @@ def __init__(
n_iter: int = 10,
device: DeviceLike | None = None,
random_seed: int | None = None,
tuning_metric: str | TorchMetrics = "r2",
tuning_metric: str | Metric = "r2",
):
TorchDeviceMixin.__init__(self, device=device)
self.n_iter = n_iter
Expand Down
9 changes: 2 additions & 7 deletions tests/core/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,18 +396,13 @@ def test_metric_with_multidimensional_tensors():
def test_metric_configs_workflow():
"""Test complete workflow of getting and using metric configs."""
# Get configs from strings
configs = get_metric_configs(["r2", "rmse"])
metrics = get_metric_configs(["r2", "rmse"])

# Use configs to compute metrics
y_pred = torch.tensor([1.0, 2.0, 3.0])
y_true = torch.tensor([1.0, 2.0, 3.0])

results = {}
for config in configs:
metric = config.metric()
metric.update(y_pred, y_true)
results[config.name] = metric.compute()

results = {metric.name: metric(y_pred, y_true) for metric in metrics}
assert "r2" in results
assert "rmse" in results
assert torch.isclose(results["r2"], torch.tensor(1.0)) # Perfect R2
Expand Down