diff --git a/autoemulate/core/compare.py b/autoemulate/core/compare.py
index a6804c61a..6f28bf0c8 100644
--- a/autoemulate/core/compare.py
+++ b/autoemulate/core/compare.py
@@ -15,8 +15,8 @@
 from autoemulate.core.metrics import (
     R2,
     TorchMetrics,
-    get_metric_config,
-    get_metric_configs,
+    get_metric,
+    get_metrics,
 )
 from autoemulate.core.model_selection import bootstrap, evaluate
 from autoemulate.core.plotting import (
@@ -143,8 +143,8 @@ def __init__(

         # Setup metrics. If evaluation_metrics is None, default to ["r2", "rmse"]
         evaluation_metrics = evaluation_metrics or ["r2", "rmse"]
-        self.evaluation_metrics = get_metric_configs(evaluation_metrics)
-        self.tuning_metric = get_metric_config(tuning_metric)
+        self.evaluation_metrics = get_metrics(evaluation_metrics)
+        self.tuning_metric = get_metric(tuning_metric)

         # Transforms to search over
         self.x_transforms_list = [
diff --git a/autoemulate/core/metrics.py b/autoemulate/core/metrics.py
index e0d7c5b10..964e0ecf4 100644
--- a/autoemulate/core/metrics.py
+++ b/autoemulate/core/metrics.py
@@ -48,11 +48,11 @@ class TorchMetrics(Metric):

     Parameters
     ----------
-    metric : MetricLike
+    metric: MetricLike
         The torchmetrics metric class or partial.
-    name : str
+    name: str
         Display name for the metric. If None, uses the class name of the metric.
-    maximize : bool
+    maximize: bool
         Whether higher values are better.
     """
@@ -83,8 +83,12 @@ def __call__(
             y_pred = y_pred.rsample((n_samples,)).mean(dim=0)
         metric = self.metric()
         metric.to(y_pred.device)
-        # Assume first dim is a batch dim, flatten others for metric calculation
-        metric.update(y_pred.flatten(start_dim=1), y_true.flatten(start_dim=1))
+
+        # Assume first dim is a batch dim if >=2D, flatten others for metric calculation
+        metric.update(
+            y_pred.flatten(start_dim=1) if y_pred.ndim > 1 else y_pred,
+            y_true.flatten(start_dim=1) if y_true.ndim > 1 else y_true,
+        )
         return metric.compute()
@@ -244,32 +248,30 @@ def __call__(
 }


-def get_metric_config(
-    metric: str | TorchMetrics,
-) -> TorchMetrics:
-    """Convert various metric specifications to MetricConfig.
+def get_metric(metric: str | Metric) -> Metric:
+    """Convert a metric specification to a `Metric`.

     Parameters
     ----------
-    metric : str | type[torchmetrics.Metric] | partial[torchmetrics.Metric] | Metric
+    metric: str | Metric
         The metric specification. Can be:
         - A string shortcut like "r2", "rmse", "mse", "mae"
         - A Metric instance (returned as-is)

     Returns
     -------
-    TorchMetrics
-        The metric configuration.
+    Metric
+        The metric.

     Raises
     ------
     ValueError
-        If the metric specification is invalid or name is not provided when required.
-
+        If the metric specification is neither a string registered in
+        AVAILABLE_METRICS nor a `Metric` instance.
     """
-    # If already a TorchMetric, return as-is
-    if isinstance(metric, TorchMetrics):
+    # If already a Metric, return as-is
+    if isinstance(metric, Metric):
         return metric

     if isinstance(metric, str):
@@ -286,25 +288,17 @@ def get_metric_config(
     )


-def get_metric_configs(
-    metrics: Sequence[str | TorchMetrics],
-) -> list[TorchMetrics]:
-    """Convert a list of metric specifications to MetricConfig objects.
+def get_metrics(metrics: Sequence[str | Metric]) -> list[Metric]:
+    """Convert a list of metric specifications to a list of `Metric`s.

     Parameters
     ----------
-    metrics : Sequence[str | TorchMetrics]
+    metrics: Sequence[str | Metric]
         Sequence of metric specifications.

     Returns
     -------
-    list[TorchMetrics]
-        List of metric configurations.
+    list[Metric]
+        List of metrics.
""" - result_metrics = [] - - for m in metrics: - config = get_metric_config(m) if isinstance(m, (str | TorchMetrics)) else m - result_metrics.append(config) - - return result_metrics + return [get_metric(m) for m in metrics] diff --git a/autoemulate/core/model_selection.py b/autoemulate/core/model_selection.py index 1c1bbd055..5b0237dc8 100644 --- a/autoemulate/core/model_selection.py +++ b/autoemulate/core/model_selection.py @@ -10,7 +10,7 @@ get_torch_device, move_tensors_to_device, ) -from autoemulate.core.metrics import R2, Metric, TorchMetrics, get_metric_configs +from autoemulate.core.metrics import R2, Metric, get_metrics from autoemulate.core.types import ( DeviceLike, ModelParams, @@ -63,7 +63,7 @@ def cross_validate( y_transforms: list[Transform] | None = None, device: DeviceLike = "cpu", random_seed: int | None = None, - metrics: list[TorchMetrics] | None = None, + metrics: list[Metric] | None = None, ): """ Cross validate model performance using the given `cv` strategy. @@ -100,7 +100,7 @@ def cross_validate( # Setup metrics if metrics is None: - metrics = get_metric_configs(["r2", "rmse"]) + metrics = get_metrics(["r2", "rmse"]) cv_results = {metric.name: [] for metric in metrics} device = get_torch_device(device) @@ -158,7 +158,7 @@ def bootstrap( n_bootstraps: int | None = 100, n_samples: int = 1000, device: str | torch.device = "cpu", - metrics: list[TorchMetrics] | None = None, + metrics: list[Metric] | None = None, ) -> dict[str, tuple[float, float]]: """ Get bootstrap estimates of metrics. @@ -193,7 +193,7 @@ def bootstrap( # Setup metrics if metrics is None: - metrics = get_metric_configs(["r2", "rmse"]) + metrics = get_metrics(["r2", "rmse"]) # If no bootstraps are specified, fall back to a single evaluation on given data if n_bootstraps is None: diff --git a/autoemulate/core/tuner.py b/autoemulate/core/tuner.py index df726ae63..b3443e8b8 100644 --- a/autoemulate/core/tuner.py +++ b/autoemulate/core/tuner.py @@ -6,7 +6,7 @@ from torch.distributions import Transform from autoemulate.core.device import TorchDeviceMixin -from autoemulate.core.metrics import TorchMetrics, get_metric_config +from autoemulate.core.metrics import Metric, get_metric from autoemulate.core.model_selection import cross_validate from autoemulate.core.types import ( DeviceLike, @@ -46,7 +46,7 @@ def __init__( n_iter: int = 10, device: DeviceLike | None = None, random_seed: int | None = None, - tuning_metric: str | TorchMetrics = "r2", + tuning_metric: str | Metric = "r2", ): TorchDeviceMixin.__init__(self, device=device) self.n_iter = n_iter @@ -60,7 +60,7 @@ def __init__( self.dataset = self._convert_to_dataset(x_tensor, y_tensor) # Setup tuning metric - self.tuning_metric = get_metric_config(tuning_metric) + self.tuning_metric = get_metric(tuning_metric) if random_seed is not None: set_random_seed(seed=random_seed) diff --git a/tests/core/test_metrics.py b/tests/core/test_metrics.py index 87fa16297..8e676d203 100644 --- a/tests/core/test_metrics.py +++ b/tests/core/test_metrics.py @@ -15,8 +15,8 @@ CRPSMetric, Metric, TorchMetrics, - get_metric_config, - get_metric_configs, + get_metric, + get_metrics, ) from torch.distributions import Normal @@ -176,7 +176,7 @@ def test_mae_computation(): def test_get_metric_config_with_string_r2(): """Test get_metric_config with 'r2' string.""" - config = get_metric_config("r2") + config = get_metric("r2") assert config == R2 assert config.name == "r2" @@ -185,7 +185,7 @@ def test_get_metric_config_with_string_r2(): def 
@@ -185,7 +185,7 @@ def test_get_metric_config_with_string_rmse():
     """Test get_metric_config with 'rmse' string."""
-    config = get_metric_config("rmse")
+    config = get_metric("rmse")

     assert config == RMSE
     assert config.name == "rmse"
@@ -194,7 +194,7 @@ def test_get_metric_config_with_string_rmse():

 def test_get_metric_config_with_string_mse():
     """Test get_metric_config with 'mse' string."""
-    config = get_metric_config("mse")
+    config = get_metric("mse")

     assert config == MSE
     assert config.name == "mse"
@@ -203,7 +203,7 @@ def test_get_metric_config_with_string_mse():

 def test_get_metric_config_with_string_mae():
     """Test get_metric_config with 'mae' string."""
-    config = get_metric_config("mae")
+    config = get_metric("mae")

     assert config == MAE
     assert config.name == "mae"
@@ -212,9 +212,9 @@ def test_get_metric_config_with_string_mae():

 def test_get_metric_config_case_insensitive():
     """Test get_metric_config is case insensitive."""
-    config_upper = get_metric_config("R2")
-    config_lower = get_metric_config("r2")
-    config_mixed = get_metric_config("R2")
+    config_upper = get_metric("R2")
+    config_lower = get_metric("r2")
+    config_mixed = get_metric("R2")

     assert config_upper == config_lower == config_mixed == R2
@@ -225,7 +225,7 @@ def test_get_metric_config_with_torchmetrics_instance():
         metric=torchmetrics.R2Score, name="custom_r2", maximize=True
     )
-    config = get_metric_config(custom_metric)
+    config = get_metric(custom_metric)

     assert config == custom_metric
     assert config.name == "custom_r2"
@@ -234,7 +234,7 @@ def test_get_metric_config_with_torchmetrics_instance():

 def test_get_metric_config_invalid_string():
     """Test get_metric_config with invalid string raises ValueError."""
     with pytest.raises(ValueError, match="Unknown metric shortcut") as excinfo:
-        get_metric_config("invalid_metric")
+        get_metric("invalid_metric")

     assert "Unknown metric shortcut" in str(excinfo.value)
     assert "invalid_metric" in str(excinfo.value)
@@ -244,7 +244,7 @@ def test_get_metric_config_invalid_string():

 def test_get_metric_config_unsupported_type():
     """Test get_metric_config with unsupported type raises ValueError."""
     with pytest.raises(ValueError, match="Unsupported metric type") as excinfo:
-        get_metric_config(123)  # type: ignore[arg-type]
+        get_metric(123)  # type: ignore[arg-type]

     assert "Unsupported metric type" in str(excinfo.value)
@@ -252,7 +252,7 @@ def test_get_metric_config_unsupported_type():

 def test_get_metric_config_with_none():
     """Test get_metric_config with None raises ValueError."""
     with pytest.raises(ValueError, match="Unsupported metric type") as excinfo:
-        get_metric_config(None)  # type: ignore[arg-type]
+        get_metric(None)  # type: ignore[arg-type]

     assert "Unsupported metric type" in str(excinfo.value)
@@ -263,7 +263,7 @@ def test_get_metric_config_with_none():

 def test_get_metric_configs_with_strings():
     """Test get_metric_configs with list of strings."""
     metrics = ["r2", "rmse", "mse"]
-    configs = get_metric_configs(metrics)
+    configs = get_metrics(metrics)

     assert len(configs) == 3
     assert configs[0] == R2
@@ -278,7 +278,7 @@ def test_get_metric_configs_with_mixed_types():
     )

     metrics = ["r2", custom_metric, "mse"]
-    configs = get_metric_configs(metrics)
+    configs = get_metrics(metrics)

     assert len(configs) == 3
     assert configs[0] == R2
@@ -288,7 +288,7 @@ def test_get_metric_configs_with_empty_list():
     """Test get_metric_configs with empty list."""
-    configs = get_metric_configs([])
+    configs = get_metrics([])

     assert len(configs) == 0
     assert configs == []
@@ -296,7 +296,7 @@ def test_get_metric_configs_with_empty_list():

 def test_get_metric_configs_with_single_metric():
     """Test get_metric_configs with single metric."""
-    configs = get_metric_configs(["r2"])
+    configs = get_metrics(["r2"])

     assert len(configs) == 1
     assert configs[0] == R2
@@ -305,7 +305,7 @@ def test_get_metric_configs_with_single_metric():

 def test_get_metric_configs_with_all_available_metrics():
     """Test get_metric_configs with all available metrics."""
     metrics = list(AVAILABLE_METRICS.keys())
-    configs = get_metric_configs(metrics)
+    configs = get_metrics(metrics)

     assert len(configs) == len(AVAILABLE_METRICS)
@@ -320,7 +320,7 @@ def test_get_metric_configs_with_torchmetrics_instances():
         metric=torchmetrics.MeanSquaredError, name="mse_1", maximize=False
     )

-    configs = get_metric_configs([metric1, metric2])
+    configs = get_metrics([metric1, metric2])

     assert len(configs) == 2
     assert configs[0] == metric1
@@ -330,7 +330,7 @@ def test_get_metric_configs_with_torchmetrics_instances():

 def test_get_metric_configs_case_insensitive():
     """Test get_metric_configs is case insensitive for strings."""
     metrics = ["R2", "RMSE", "mse", "MaE", "Crps"]
-    configs = get_metric_configs(metrics)
+    configs = get_metrics(metrics)

     assert len(configs) == 5
     assert configs[0] == R2
@@ -396,18 +396,13 @@ def test_metric_with_multidimensional_tensors():

 def test_metric_configs_workflow():
     """Test complete workflow of getting and using metric configs."""
     # Get configs from strings
-    configs = get_metric_configs(["r2", "rmse"])
+    metrics = get_metrics(["r2", "rmse"])

     # Use configs to compute metrics
     y_pred = torch.tensor([1.0, 2.0, 3.0])
     y_true = torch.tensor([1.0, 2.0, 3.0])

-    results = {}
-    for config in configs:
-        metric = config.metric()
-        metric.update(y_pred, y_true)
-        results[config.name] = metric.compute()
-
+    results = {metric.name: metric(y_pred, y_true) for metric in metrics}
     assert "r2" in results
     assert "rmse" in results
     assert torch.isclose(results["r2"], torch.tensor(1.0))  # Perfect R2
@@ -509,7 +504,7 @@ def test_crps_aggregation_across_batch():

 def test_get_metric_config_crps():
     """Test get_metric_config with 'crps' string."""
-    config = get_metric_config("crps")
+    config = get_metric("crps")

     assert config == CRPS
     assert isinstance(config, CRPSMetric)
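Usage note (not part of the diff): a minimal sketch of the renamed helpers in use, mirroring the workflow exercised in `test_metric_configs_workflow`. The import path and call signatures are taken from the hunks above; the tensor values are illustrative only.

```python
import torch

from autoemulate.core.metrics import get_metric, get_metrics

# Resolve string shortcuts (or Metric instances) to Metric objects.
metrics = get_metrics(["r2", "rmse"])
tuning_metric = get_metric("r2")

y_pred = torch.tensor([1.0, 2.0, 3.0])
y_true = torch.tensor([1.0, 2.0, 3.0])

# Each Metric is callable: metric(y_pred, y_true) returns a scalar tensor,
# and 1D tensors are handled by the ndim guard added in TorchMetrics.__call__.
results = {metric.name: metric(y_pred, y_true) for metric in metrics}
```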