Commit 47be238

Merge pull request #397 from alan-turing-institute/371-compare
Add compare for experimental
2 parents: 20f42e1 + c105b1d

12 files changed: +193 -18 lines

autoemulate/experimental/compare.py

+95
@@ -0,0 +1,95 @@
import logging
import warnings
from typing import Any

import numpy as np
from sklearn.model_selection import BaseCrossValidator, KFold

from autoemulate.experimental.data.utils import InputTypeMixin
from autoemulate.experimental.emulators import ALL_EMULATORS
from autoemulate.experimental.emulators.base import Emulator
from autoemulate.experimental.model_selection import cross_validate
from autoemulate.experimental.tuner import Tuner
from autoemulate.experimental.types import InputLike


class AutoEmulate(InputTypeMixin):
    def __init__(
        self,
        x: InputLike,
        y: InputLike,
        models: list[type[Emulator]] | None = None,
    ):
        # TODO: refactor in https://github.com/alan-turing-institute/autoemulate/issues/400
        x, y = self._convert_to_tensors(x, y)

        # Set default models if None
        updated_models = self.get_models(models)

        # Filter models to only be those that can handle multioutput data
        if y.shape[1] > 1:
            updated_models = self.filter_models_if_multioutput(
                updated_models, models is not None
            )

        self.models = updated_models
        self.train_val, self.test = self._random_split(self._convert_to_dataset(x, y))

    @staticmethod
    def all_emulators() -> list[type[Emulator]]:
        return ALL_EMULATORS

    def get_models(self, models: list[type[Emulator]] | None) -> list[type[Emulator]]:
        if models is None:
            return self.all_emulators()
        return models

    def filter_models_if_multioutput(
        self, models: list[type[Emulator]], warn: bool
    ) -> list[type[Emulator]]:
        updated_models = []
        for model in models:
            if not model.is_multioutput():
                if warn:
                    msg = (
                        f"Model ({model}) is not multioutput but the data is "
                        f"multioutput. Skipping model ({model})..."
                    )
                    warnings.warn(msg, stacklevel=2)
            else:
                updated_models.append(model)
        return updated_models

    def log_compare(self, model_cls, best_model_config, r2_score, rmse_score):
        logger = logging.getLogger(__name__)
        msg = (
            f"Model: {model_cls.__name__}, "
            f"Best params: {best_model_config}, "
            f"R2 score: {r2_score:.3f}, "
            f"RMSE score: {rmse_score:.3f}"
        )
        logger.info(msg)

    def compare(
        self, n_iter: int = 10, cv: type[BaseCrossValidator] = KFold
    ) -> dict[str, dict[str, Any]]:
        tuner = Tuner(self.train_val, y=None, n_iter=n_iter)
        models_evaluated = {}
        for model_cls in self.models:
            scores, configs = tuner.run(model_cls)
            best_score_idx = scores.index(max(scores))
            best_model_config = configs[best_score_idx]
            cv_results = cross_validate(
                cv(), self.train_val.dataset, model_cls, **best_model_config
            )
            r2_score, rmse_score = (
                np.mean(cv_results["r2"]),
                np.mean(cv_results["rmse"]),
            )
            models_evaluated[model_cls.__name__] = {
                "config": best_model_config,
                "r2_score": r2_score,
                "rmse_score": rmse_score,
            }
            self.log_compare(model_cls, best_model_config, r2_score, rmse_score)
        return models_evaluated
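
A minimal usage sketch of the new class (the array shapes, values, and n_iter below are illustrative, not taken from the diff):

import numpy as np

from autoemulate.experimental.compare import AutoEmulate

x = np.random.rand(50, 2)  # 50 samples, 2 input dimensions
y = np.random.rand(50, 3)  # 3 outputs, so single-output emulators are filtered out
ae = AutoEmulate(x, y)  # models=None defaults to ALL_EMULATORS
results = ae.compare(n_iter=5)  # tune each emulator, then cross-validate its best config
# results maps each emulator class name to {"config", "r2_score", "rmse_score"}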

autoemulate/experimental/data/utils.py

+20 -1
@@ -1,7 +1,9 @@
 import numpy as np
 import torch
+import torch.utils
+import torch.utils.data
 from autoemulate.experimental.types import InputLike
-from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split
+from torch.utils.data import DataLoader, Dataset, Subset, TensorDataset, random_split


 class InputTypeMixin:
@@ -31,6 +33,8 @@ def _convert_to_dataset(
             dataset = TensorDataset(x)
         elif isinstance(x, Dataset) and y is None:
             dataset = x
+        elif isinstance(x, DataLoader) and y is None:
+            dataset = x.dataset
         else:
             raise ValueError(
                 f"Unsupported type for x ({type(x)}). Must be numpy array or PyTorch "
@@ -69,6 +73,21 @@ def _convert_to_tensors(
         Convert InputLike x, y to Tensor or tuple of Tensors.
         """
         dataset = self._convert_to_dataset(x, y)
+
+        # Handle Subset of TensorDataset
+        if isinstance(dataset, Subset):
+            if isinstance(dataset.dataset, TensorDataset):
+                tensors = dataset.dataset.tensors
+                indices = dataset.indices
+
+                # Use indexing to get subset tensors
+                subset_tensors = tuple(tensor[indices] for tensor in tensors)
+                dataset = TensorDataset(*subset_tensors)
+            else:
+                raise ValueError(
+                    f"Subset must wrap a TensorDataset. Found {type(dataset.dataset)}."
+                )
+
         if isinstance(dataset, TensorDataset):
             if len(dataset.tensors) > 2:
                 raise ValueError(
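
For context, a small self-contained sketch of what the new Subset branch does (toy tensors; the variable names here are illustrative):

import torch
from torch.utils.data import TensorDataset, random_split

x = torch.arange(10, dtype=torch.float32).reshape(-1, 1)
y = 2 * x
train, test = random_split(TensorDataset(x, y), [8, 2])  # each split is a Subset

# Materialise the Subset back into a TensorDataset by indexing the wrapped
# tensors with the subset's indices, as _convert_to_tensors now does.
train_dataset = TensorDataset(*(t[train.indices] for t in train.dataset.tensors))
assert len(train_dataset) == 8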

autoemulate/experimental/emulators/__init__.py

+5
@@ -0,0 +1,5 @@
from .gaussian_process.exact import GaussianProcessExact
from .lightgbm import LightGBM
from .neural_processes.conditional_neural_process import CNPModule

ALL_EMULATORS = [GaussianProcessExact, LightGBM, CNPModule]

autoemulate/experimental/emulators/base.py

+6 -5
@@ -5,11 +5,7 @@

 from autoemulate.experimental.data.preprocessors import Preprocessor
 from autoemulate.experimental.data.utils import InputTypeMixin
-from autoemulate.experimental.types import (
-    InputLike,
-    OutputLike,
-    TuneConfig,
-)
+from autoemulate.experimental.types import InputLike, OutputLike, TuneConfig


 class Emulator(ABC):
@@ -35,6 +31,11 @@ def fit(self, x: InputLike, y: InputLike | None): ...
     def predict(self, x: InputLike) -> OutputLike:
         pass

+    @staticmethod
+    @abstractmethod
+    def is_multioutput() -> bool:
+        """Flag to indicate if the model is multioutput or not."""
+
     @staticmethod
     @abstractmethod
     def get_tune_config() -> TuneConfig:

autoemulate/experimental/emulators/gaussian_process/exact.py

+4
@@ -123,6 +123,10 @@ def __init__( # noqa: PLR0913 allow too many arguments since all currently requ
         self.batch_size = batch_size
         self.activation = activation

+    @staticmethod
+    def is_multioutput():
+        return True
+
     def preprocess(self, x: InputLike) -> InputLike:
         """Preprocess the input data using the preprocessor."""
         if self.preprocessor is not None:

autoemulate/experimental/emulators/lightgbm.py

+4
@@ -64,6 +64,10 @@ def __init__( # noqa: PLR0913 allow too many arguments since all currently requ
         self.importance_type = importance_type
         self.verbose = verbose

+    @staticmethod
+    def is_multioutput() -> bool:
+        return False
+
     def fit(self, x: InputLike, y: InputLike | None):
         """
         Fits the emulator to the data.

autoemulate/experimental/emulators/neural_processes/conditional_neural_process.py

+2 -1
@@ -478,6 +478,7 @@ def get_tune_config():
             "hidden_layers_dec": [1, 2, 4],
             "activation": [nn.ReLU],
             "min_context_points": [4, 5, 6],
-            "offset_context_points": [4, 6],
+            "offset_context_points": [4, 5],
+            # max_context_points must be less than n_episodes
             "n_episodes": [12, 13, 14],
         }

autoemulate/experimental/model_selection.py

+11 -4
@@ -1,3 +1,5 @@
+from typing import Any
+
 import numpy as np
 import torchmetrics
 from sklearn.model_selection import BaseCrossValidator
@@ -7,6 +9,7 @@
 from autoemulate.experimental.types import (
     DistributionLike,
     InputLike,
+    ModelConfig,
     OutputLike,
     TensorLike,
 )
@@ -60,8 +63,8 @@ def evaluate(
 def cross_validate(
     cv: BaseCrossValidator,
     dataset: Dataset,
-    model: Emulator,
-    batch_size: int = 16,
+    model: type[Emulator],
+    **kwargs: Any,
 ):
     """
     Cross validate model performance using the given `cv` strategy.
@@ -81,7 +84,9 @@ def cross_validate(
     dict[str, list[float]]
         Contains r2 and rmse scores computed for each cross validation fold.
     """
+    best_model_config: ModelConfig = kwargs
     cv_results = {"r2": [], "rmse": []}
+    batch_size = best_model_config.get("batch_size", 16)
     for train_idx, val_idx in cv.split(dataset):  # type: ignore TODO: identify type handling here
         # create train/val data subsets
         # convert idx to list to satisfy type checker
@@ -91,13 +96,15 @@ def cross_validate(
         val_loader = DataLoader(val_subset, batch_size=batch_size)

         # fit model
-        model.fit(train_loader, y=None)
+        x, y = next(iter(train_loader))
+        m = model(x, y, **best_model_config)
+        m.fit(x, y)

         # evaluate on batches
         r2_metric = torchmetrics.R2Score()
         mse_metric = torchmetrics.MeanSquaredError()
         for x_batch, y_batch in val_loader:
-            y_batch_pred = model.predict(x_batch)
+            y_batch_pred = m.predict(x_batch)
             _update(y_batch, y_batch_pred, r2_metric)
             _update(y_batch, y_batch_pred, mse_metric)
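
A brief sketch of the updated call pattern, mirroring how compare() uses it (the dataset and the batch_size value are illustrative; GaussianProcessExact is one of the emulators in ALL_EMULATORS and accepts batch_size in its constructor):

import torch
from sklearn.model_selection import KFold
from torch.utils.data import TensorDataset

from autoemulate.experimental.emulators import GaussianProcessExact
from autoemulate.experimental.model_selection import cross_validate

x = torch.rand(32, 2)
y = torch.rand(32, 1)
dataset = TensorDataset(x, y)

# The emulator is now passed as a class; tuned hyperparameters arrive as **kwargs,
# and "batch_size" (default 16) is read from them for the fold data loaders.
results = cross_validate(KFold(n_splits=2), dataset, GaussianProcessExact, batch_size=8)
# results["r2"] and results["rmse"] each hold one score per fold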

tests/experimental/test_experimental_base.py

+4
@@ -119,6 +119,10 @@ def get_tune_config():
             "batch_size": [16],
         }

+    @staticmethod
+    def is_multioutput():
+        return False
+
     def setup_method(self):
         """
         Define the PyTorchBackend instance.

tests/experimental/test_experimental_compare.py

+31
@@ -0,0 +1,31 @@
from autoemulate.experimental.compare import AutoEmulate
from autoemulate.experimental.emulators import ALL_EMULATORS


def test_compare(sample_data_y2d):
    x, y = sample_data_y2d
    ae = AutoEmulate(x, y)
    results = ae.compare(10)
    print(results)


def test_compare_user_models(sample_data_y2d, recwarn):
    x, y = sample_data_y2d
    ae = AutoEmulate(x, y, models=ALL_EMULATORS)
    results = ae.compare(1)
    print(results)
    assert len(recwarn) == 1
    assert str(recwarn.pop().message) == (
        "Model (<class 'autoemulate.experimental.emulators.lightgbm.Li"
        "ghtGBM'>) is not multioutput but the data is multioutput. Skipping model "
        "(<class 'autoemulate.experimental.emulators.lightgbm.LightGBM'>)..."
    )


def test_compare_y1d(sample_data_y1d):
    x, y = sample_data_y1d
    # TODO: add handling when 1D
    y = y.reshape(-1, 1)
    ae = AutoEmulate(x, y)
    results = ae.compare(10)
    print(results)

tests/experimental/test_experimental_conditional_neural_process.py

+3 -3
@@ -66,7 +66,7 @@ def test_cnp_module_predict_fails_with_calling_fit_first(sample_data_y1d):

 def test_tune_gp(sample_data_y1d):
     x, y = sample_data_y1d
-    tuner = Tuner(x, y, n_iter=5)
+    tuner = Tuner(x, y, n_iter=20)
     scores, configs = tuner.run(CNPModule)
-    assert len(scores) == 5
-    assert len(configs) == 5
+    assert len(scores) == 20
+    assert len(configs) == 20

tests/experimental/test_experimental_model_selection.py

+8 -4
@@ -11,7 +11,7 @@ def test_cross_validate():
     Test cross_validate can be called with any sklearn.model_selection class.
     """

-    class DummyEmulator(Emulator):
+    class DummyEmulator(Emulator, torch.nn.Module):
         def __init__(self, x=None, y=None, **kwargs):
             pass

@@ -25,22 +25,26 @@ def predict(self, x):
         def get_tune_config():
             return {}

+        @staticmethod
+        def is_multioutput():
+            return False
+
     x = torch.tensor(np.arange(32)).float()
     y = 2 * x
     dataset = TensorDataset(x, y)

-    emulator = DummyEmulator()
+    emulator_cls = DummyEmulator

     # KFold
-    results = cross_validate(KFold(n_splits=2), dataset, emulator)
+    results = cross_validate(KFold(n_splits=2), dataset, emulator_cls)
     assert "r2" in results
     assert "rmse" in results
     assert len(results["r2"]) == 2
     assert len(results["rmse"]) == 2

     # LeavePOut: LOO raised an error with torchmetrics R2Score since it requires at
     # least 2 samples
-    results = cross_validate(LeavePOut(p=2), dataset, emulator)
+    results = cross_validate(LeavePOut(p=2), dataset, emulator_cls)
     expected_n = (x.shape[0] * (x.shape[0] - 1)) / 2
     assert len(results["r2"]) == expected_n
     assert len(results["rmse"]) == expected_n
