8 changes: 4 additions & 4 deletions autoemulate/core/compare.py
@@ -15,8 +15,8 @@
from autoemulate.core.metrics import (
R2,
TorchMetrics,
get_metric_config,
get_metric_configs,
get_metric,
get_metrics,
)
from autoemulate.core.model_selection import bootstrap, evaluate
from autoemulate.core.plotting import (
@@ -143,8 +143,8 @@ def __init__(

# Setup metrics. If evaluation_metrics is None, default to ["r2", "rmse"]
evaluation_metrics = evaluation_metrics or ["r2", "rmse"]
self.evaluation_metrics = get_metric_configs(evaluation_metrics)
self.tuning_metric = get_metric_config(tuning_metric)
self.evaluation_metrics = get_metrics(evaluation_metrics)
self.tuning_metric = get_metric(tuning_metric)

# Transforms to search over
self.x_transforms_list = [
54 changes: 24 additions & 30 deletions autoemulate/core/metrics.py
@@ -48,11 +48,11 @@ class TorchMetrics(Metric):

Parameters
----------
metric : MetricLike
metric: MetricLike
The torchmetrics metric class or partial.
name : str
name: str
Display name for the metric. If None, uses the class name of the metric.
maximize : bool
maximize: bool
Whether higher values are better.
"""

@@ -83,8 +83,12 @@ def __call__(
y_pred = y_pred.rsample((n_samples,)).mean(dim=0)
metric = self.metric()
metric.to(y_pred.device)
# Assume first dim is a batch dim, flatten others for metric calculation
metric.update(y_pred.flatten(start_dim=1), y_true.flatten(start_dim=1))

# Assume first dim is a batch dim if >=2D, flatten others for metric calculation
metric.update(
y_pred.flatten(start_dim=1) if y_pred.ndim > 1 else y_pred,
y_true.flatten(start_dim=1) if y_true.ndim > 1 else y_true,
)
return metric.compute()
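
A minimal sketch of the dimensionality guard introduced in this hunk (the helper name is hypothetical; the PR inlines the expression): flattening only starts past the batch dimension, and 1D inputs are passed through untouched.

```python
import torch


def _maybe_flatten(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper mirroring the inline guard above: keep 1D tensors
    # as-is, otherwise treat dim 0 as the batch dim and flatten the rest.
    return t.flatten(start_dim=1) if t.ndim > 1 else t


assert _maybe_flatten(torch.rand(3)).shape == (3,)          # 1D: unchanged
assert _maybe_flatten(torch.rand(8, 4, 2)).shape == (8, 8)  # batch dim kept, rest flattened
```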


@@ -244,32 +248,30 @@ def __call__(
}


def get_metric_config(
metric: str | TorchMetrics,
) -> TorchMetrics:
"""Convert various metric specifications to MetricConfig.
def get_metric(metric: str | Metric) -> Metric:
"""Convert metric specification to a `Metric`.

Parameters
----------
metric : str | type[torchmetrics.Metric] | partial[torchmetrics.Metric] | Metric
metric: str | Metric
The metric specification. Can be:
- A string shortcut like "r2", "rmse", "mse", "mae"
- A Metric instance (returned as-is)

Returns
-------
TorchMetrics
The metric configuration.
Metric
The metric.

Raises
------
ValueError
If the metric specification is invalid or name is not provided when required.

If the metric specification is neither a string shortcut registered in
AVAILABLE_METRICS nor a Metric instance.

"""
# If already a TorchMetric, return as-is
if isinstance(metric, TorchMetrics):
# If already a Metric, return as-is
if isinstance(metric, Metric):
return metric

if isinstance(metric, str):
@@ -286,25 +288,17 @@ def get_metric_config(
)


def get_metric_configs(
metrics: Sequence[str | TorchMetrics],
) -> list[TorchMetrics]:
"""Convert a list of metric specifications to MetricConfig objects.
def get_metrics(metrics: Sequence[str | Metric]) -> list[Metric]:
"""Convert a list of metric specifications to list of `Metric`s.

Parameters
----------
metrics : Sequence[str | TorchMetrics]
metrics: Sequence[str | Metric]
Sequence of metric specifications.

Returns
-------
list[TorchMetrics]
List of metric configurations.
list[Metric]
List of metrics.
"""
result_metrics = []

for m in metrics:
config = get_metric_config(m) if isinstance(m, (str | TorchMetrics)) else m
result_metrics.append(config)

return result_metrics
return [get_metric(m) for m in metrics]
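
For reference, a short usage sketch of the renamed helpers, based on the behaviour exercised by the tests further down (tensor values and the custom metric name are illustrative only):

```python
import torch
import torchmetrics

from autoemulate.core.metrics import TorchMetrics, get_metric, get_metrics

# String shortcuts resolve via AVAILABLE_METRICS; Metric instances pass through as-is.
r2 = get_metric("r2")
custom = TorchMetrics(metric=torchmetrics.R2Score, name="custom_r2", maximize=True)
assert get_metric(custom) is custom

# get_metrics accepts a mix of shortcuts and Metric instances.
metrics = get_metrics(["rmse", custom])

# Metrics are callable: metric(y_pred, y_true) -> torch.Tensor
y_pred = torch.tensor([1.0, 2.0, 3.0])
y_true = torch.tensor([1.0, 2.0, 3.0])
scores = {m.name: m(y_pred, y_true) for m in [r2, *metrics]}
```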
10 changes: 5 additions & 5 deletions autoemulate/core/model_selection.py
@@ -10,7 +10,7 @@
get_torch_device,
move_tensors_to_device,
)
from autoemulate.core.metrics import R2, Metric, TorchMetrics, get_metric_configs
from autoemulate.core.metrics import R2, Metric, get_metrics
from autoemulate.core.types import (
DeviceLike,
ModelParams,
@@ -63,7 +63,7 @@ def cross_validate(
y_transforms: list[Transform] | None = None,
device: DeviceLike = "cpu",
random_seed: int | None = None,
metrics: list[TorchMetrics] | None = None,
metrics: list[Metric] | None = None,
):
"""
Cross validate model performance using the given `cv` strategy.
@@ -100,7 +100,7 @@

# Setup metrics
if metrics is None:
metrics = get_metric_configs(["r2", "rmse"])
metrics = get_metrics(["r2", "rmse"])

cv_results = {metric.name: [] for metric in metrics}
device = get_torch_device(device)
@@ -158,7 +158,7 @@ def bootstrap(
n_bootstraps: int | None = 100,
n_samples: int = 1000,
device: str | torch.device = "cpu",
metrics: list[TorchMetrics] | None = None,
metrics: list[Metric] | None = None,
) -> dict[str, tuple[float, float]]:
"""
Get bootstrap estimates of metrics.
@@ -193,7 +193,7 @@

# Setup metrics
if metrics is None:
metrics = get_metric_configs(["r2", "rmse"])
metrics = get_metrics(["r2", "rmse"])

# If no bootstraps are specified, fall back to a single evaluation on given data
if n_bootstraps is None:
6 changes: 3 additions & 3 deletions autoemulate/core/tuner.py
@@ -6,7 +6,7 @@
from torch.distributions import Transform

from autoemulate.core.device import TorchDeviceMixin
from autoemulate.core.metrics import TorchMetrics, get_metric_config
from autoemulate.core.metrics import Metric, get_metric
from autoemulate.core.model_selection import cross_validate
from autoemulate.core.types import (
DeviceLike,
@@ -46,7 +46,7 @@ def __init__(
n_iter: int = 10,
device: DeviceLike | None = None,
random_seed: int | None = None,
tuning_metric: str | TorchMetrics = "r2",
tuning_metric: str | Metric = "r2",
):
TorchDeviceMixin.__init__(self, device=device)
self.n_iter = n_iter
@@ -60,7 +60,7 @@ def __init__(
self.dataset = self._convert_to_dataset(x_tensor, y_tensor)

# Setup tuning metric
self.tuning_metric = get_metric_config(tuning_metric)
self.tuning_metric = get_metric(tuning_metric)

if random_seed is not None:
set_random_seed(seed=random_seed)
51 changes: 23 additions & 28 deletions tests/core/test_metrics.py
@@ -15,8 +15,8 @@
CRPSMetric,
Metric,
TorchMetrics,
get_metric_config,
get_metric_configs,
get_metric,
get_metrics,
)
from torch.distributions import Normal

@@ -176,7 +176,7 @@ def test_mae_computation():

def test_get_metric_config_with_string_r2():
"""Test get_metric_config with 'r2' string."""
config = get_metric_config("r2")
config = get_metric("r2")

assert config == R2
assert config.name == "r2"
@@ -185,7 +185,7 @@ def test_get_metric_config_with_string_r2():

def test_get_metric_config_with_string_rmse():
"""Test get_metric_config with 'rmse' string."""
config = get_metric_config("rmse")
config = get_metric("rmse")

assert config == RMSE
assert config.name == "rmse"
@@ -194,7 +194,7 @@ def test_get_metric_config_with_string_rmse():

def test_get_metric_config_with_string_mse():
"""Test get_metric_config with 'mse' string."""
config = get_metric_config("mse")
config = get_metric("mse")

assert config == MSE
assert config.name == "mse"
@@ -203,7 +203,7 @@ def test_get_metric_config_with_string_mse():

def test_get_metric_config_with_string_mae():
"""Test get_metric_config with 'mae' string."""
config = get_metric_config("mae")
config = get_metric("mae")

assert config == MAE
assert config.name == "mae"
@@ -212,9 +212,9 @@ def test_get_metric_config_with_string_mae():

def test_get_metric_config_case_insensitive():
"""Test get_metric_config is case insensitive."""
config_upper = get_metric_config("R2")
config_lower = get_metric_config("r2")
config_mixed = get_metric_config("R2")
config_upper = get_metric("R2")
config_lower = get_metric("r2")
config_mixed = get_metric("R2")

assert config_upper == config_lower == config_mixed == R2

@@ -225,7 +225,7 @@ def test_get_metric_config_with_torchmetrics_instance():
metric=torchmetrics.R2Score, name="custom_r2", maximize=True
)

config = get_metric_config(custom_metric)
config = get_metric(custom_metric)

assert config == custom_metric
assert config.name == "custom_r2"
@@ -234,7 +234,7 @@ def test_get_metric_config_invalid_string():
def test_get_metric_config_invalid_string():
"""Test get_metric_config with invalid string raises ValueError."""
with pytest.raises(ValueError, match="Unknown metric shortcut") as excinfo:
get_metric_config("invalid_metric")
get_metric("invalid_metric")

assert "Unknown metric shortcut" in str(excinfo.value)
assert "invalid_metric" in str(excinfo.value)
@@ -244,15 +244,15 @@ def test_get_metric_config_invalid_string():
def test_get_metric_config_unsupported_type():
"""Test get_metric_config with unsupported type raises ValueError."""
with pytest.raises(ValueError, match="Unsupported metric type") as excinfo:
get_metric_config(123) # type: ignore[arg-type]
get_metric(123) # type: ignore[arg-type]

assert "Unsupported metric type" in str(excinfo.value)


def test_get_metric_config_with_none():
"""Test get_metric_config with None raises ValueError."""
with pytest.raises(ValueError, match="Unsupported metric type") as excinfo:
get_metric_config(None) # type: ignore[arg-type]
get_metric(None) # type: ignore[arg-type]

assert "Unsupported metric type" in str(excinfo.value)

@@ -263,7 +263,7 @@ def test_get_metric_config_with_none():
def test_get_metric_configs_with_strings():
"""Test get_metric_configs with list of strings."""
metrics = ["r2", "rmse", "mse"]
configs = get_metric_configs(metrics)
configs = get_metrics(metrics)

assert len(configs) == 3
assert configs[0] == R2
@@ -278,7 +278,7 @@ def test_get_metric_configs_with_mixed_types():
)

metrics = ["r2", custom_metric, "mse"]
configs = get_metric_configs(metrics)
configs = get_metrics(metrics)

assert len(configs) == 3
assert configs[0] == R2
@@ -288,15 +288,15 @@ def test_get_metric_configs_with_mixed_types():

def test_get_metric_configs_with_empty_list():
"""Test get_metric_configs with empty list."""
configs = get_metric_configs([])
configs = get_metrics([])

assert len(configs) == 0
assert configs == []


def test_get_metric_configs_with_single_metric():
"""Test get_metric_configs with single metric."""
configs = get_metric_configs(["r2"])
configs = get_metrics(["r2"])

assert len(configs) == 1
assert configs[0] == R2
@@ -305,7 +305,7 @@ def test_get_metric_configs_with_single_metric():
def test_get_metric_configs_with_all_available_metrics():
"""Test get_metric_configs with all available metrics."""
metrics = list(AVAILABLE_METRICS.keys())
configs = get_metric_configs(metrics)
configs = get_metrics(metrics)

assert len(configs) == len(AVAILABLE_METRICS)

@@ -320,7 +320,7 @@ def test_get_metric_configs_with_all_available_metrics():
metric=torchmetrics.MeanSquaredError, name="mse_1", maximize=False
)

configs = get_metric_configs([metric1, metric2])
configs = get_metrics([metric1, metric2])

assert len(configs) == 2
assert configs[0] == metric1
@@ -330,7 +330,7 @@ def test_get_metric_configs_with_torchmetrics_instances():
def test_get_metric_configs_case_insensitive():
"""Test get_metric_configs is case insensitive for strings."""
metrics = ["R2", "RMSE", "mse", "MaE", "Crps"]
configs = get_metric_configs(metrics)
configs = get_metrics(metrics)

assert len(configs) == 5
assert configs[0] == R2
@@ -396,18 +396,13 @@ def test_metric_with_multidimensional_tensors():
def test_metric_configs_workflow():
"""Test complete workflow of getting and using metric configs."""
# Get configs from strings
configs = get_metric_configs(["r2", "rmse"])
metrics = get_metrics(["r2", "rmse"])

# Use configs to compute metrics
y_pred = torch.tensor([1.0, 2.0, 3.0])
y_true = torch.tensor([1.0, 2.0, 3.0])

results = {}
for config in configs:
metric = config.metric()
metric.update(y_pred, y_true)
results[config.name] = metric.compute()

results = {metric.name: metric(y_pred, y_true) for metric in metrics}
assert "r2" in results
assert "rmse" in results
assert torch.isclose(results["r2"], torch.tensor(1.0)) # Perfect R2
@@ -509,7 +504,7 @@ def test_crps_aggregation_across_batch():

def test_get_metric_config_crps():
"""Test get_metric_config with 'crps' string."""
config = get_metric_config("crps")
config = get_metric("crps")

assert config == CRPS
assert isinstance(config, CRPSMetric)
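
Finally, a sketch of the distributional case, assuming the CRPS metric accepts a `torch.distributions` object as `y_pred` (as the tests' `Normal` import suggests); the distribution parameters are made up for the example:

```python
import torch
from torch.distributions import Normal

from autoemulate.core.metrics import get_metric

# Hedged sketch: a distributional prediction scored with CRPS (lower is better).
crps = get_metric("crps")
y_true = torch.zeros(5)
y_pred = Normal(torch.zeros(5), torch.ones(5))  # made-up predictive distribution
score = crps(y_pred, y_true)
```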