Skip to content

Commit 25d69e7

Browse files
authored
Merge pull request #456 from alan-turing-institute/400-revise-input-types
Revise input types (#400)
2 parents a0a0781 + 52a8439 commit 25d69e7

File tree

8 files changed

+83
-85
lines changed

8 files changed

+83
-85
lines changed

autoemulate/experimental/data/preprocessors.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
from abc import ABC, abstractmethod
22

33
import torch
4-
from autoemulate.experimental.types import InputLike
4+
from autoemulate.experimental.types import TensorLike
55

66

77
class Preprocessor(ABC):
88
@abstractmethod
99
def __init__(*args, **kwargs): ...
1010

1111
@abstractmethod
12-
def preprocess(self, x: InputLike) -> InputLike: ...
12+
def preprocess(self, x: TensorLike) -> TensorLike: ...
1313

1414

1515
class Standardizer(Preprocessor):

autoemulate/experimental/emulators/base.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from autoemulate.experimental.data.preprocessors import Preprocessor
77
from autoemulate.experimental.data.utils import InputTypeMixin
88
from autoemulate.experimental.data.validation import ValidationMixin
9-
from autoemulate.experimental.types import InputLike, OutputLike, TuneConfig
9+
from autoemulate.experimental.types import OutputLike, TensorLike, TuneConfig
1010

1111

1212
class Emulator(ABC, ValidationMixin):
@@ -17,26 +17,26 @@ class Emulator(ABC, ValidationMixin):
1717
"""
1818

1919
@abstractmethod
20-
def _fit(self, x: InputLike, y: InputLike | None): ...
20+
def _fit(self, x: TensorLike, y: TensorLike): ...
2121

22-
def fit(self, x: InputLike, y: InputLike | None):
22+
def fit(self, x: TensorLike, y: TensorLike):
2323
self._check(x, y)
2424
self._fit(x, y)
2525

2626
@abstractmethod
2727
def __init__(
28-
self, x: InputLike | None = None, y: InputLike | None = None, **kwargs
28+
self, x: TensorLike | None = None, y: TensorLike | None = None, **kwargs
2929
): ...
3030

3131
@classmethod
3232
def model_name(cls) -> str:
3333
return cls.__name__
3434

3535
@abstractmethod
36-
def _predict(self, x: InputLike) -> OutputLike:
36+
def _predict(self, x: TensorLike) -> OutputLike:
3737
pass
3838

39-
def predict(self, x: InputLike) -> OutputLike:
39+
def predict(self, x: TensorLike) -> OutputLike:
4040
self._check(x, None)
4141
output = self._predict(x)
4242
self._check_output(output)
@@ -93,7 +93,7 @@ class PyTorchBackend(nn.Module, Emulator, InputTypeMixin, Preprocessor):
9393
loss_fn: nn.Module = nn.MSELoss()
9494
optimizer: optim.Optimizer
9595

96-
def preprocess(self, x):
96+
def preprocess(self, x: TensorLike) -> TensorLike:
9797
if self.preprocessor is None:
9898
return x
9999
return self.preprocessor.preprocess(x)
@@ -107,15 +107,15 @@ def loss_func(self, y_pred, y_true):
107107

108108
def _fit(
109109
self,
110-
x: InputLike,
111-
y: InputLike | None,
110+
x: TensorLike,
111+
y: TensorLike,
112112
):
113113
"""
114114
Train a PyTorchBackend model.
115115
116116
Parameters
117117
----------
118-
X: InputLike
118+
X: TensorLike
119119
Input features as numpy array, PyTorch tensor, or DataLoader.
120120
y: OutputLike or None
121121
Target values (not needed if x is a DataLoader).
@@ -160,7 +160,7 @@ def _fit(
160160
if self.verbose and (epoch + 1) % (self.epochs // 10 or 1) == 0:
161161
print(f"Epoch [{epoch + 1}/{self.epochs}], Loss: {avg_epoch_loss:.4f}")
162162

163-
def _predict(self, x: InputLike) -> OutputLike:
163+
def _predict(self, x: TensorLike) -> OutputLike:
164164
self.eval()
165165
x = self.preprocess(x)
166166
return self(x)

autoemulate/experimental/emulators/gaussian_process/exact.py

+18-27
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55
import torch
66
from gpytorch import ExactMarginalLogLikelihood
77
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
8-
from gpytorch.kernels import (
9-
ScaleKernel,
10-
)
8+
from gpytorch.kernels import ScaleKernel
119
from gpytorch.likelihoods import MultitaskGaussianLikelihood
1210
from torch import nn
1311

@@ -26,15 +24,12 @@
2624
zero_mean,
2725
)
2826
from autoemulate.experimental.data.preprocessors import Preprocessor, Standardizer
29-
from autoemulate.experimental.emulators.base import (
30-
Emulator,
31-
InputTypeMixin,
32-
)
27+
from autoemulate.experimental.emulators.base import Emulator, InputTypeMixin
3328
from autoemulate.experimental.emulators.gaussian_process import (
3429
CovarModuleFn,
3530
MeanModuleFn,
3631
)
37-
from autoemulate.experimental.types import InputLike, OutputLike
32+
from autoemulate.experimental.types import OutputLike, TensorLike
3833
from autoemulate.utils import set_random_seed
3934

4035

@@ -53,8 +48,8 @@ class GaussianProcessExact(
5348

5449
def __init__( # noqa: PLR0913 allow too many arguments since all currently required
5550
self,
56-
x: InputLike,
57-
y: InputLike,
51+
x: TensorLike,
52+
y: TensorLike,
5853
likelihood_cls: type[MultitaskGaussianLikelihood] = MultitaskGaussianLikelihood,
5954
mean_module_fn: MeanModuleFn = constant_mean,
6055
covar_module_fn: CovarModuleFn = rbf,
@@ -68,6 +63,7 @@ def __init__( # noqa: PLR0913 allow too many arguments since all currently requ
6863
if random_state is not None:
6964
set_random_seed(random_state)
7065

66+
# TODO (#422): update the call here to check or call e.g. `_ensure_2d`
7167
x, y = self._convert_to_tensors(x, y)
7268

7369
# Initialize the mean and covariance modules
@@ -85,8 +81,6 @@ def __init__( # noqa: PLR0913 allow too many arguments since all currently requ
8581
)
8682
)
8783

88-
assert isinstance(y, torch.Tensor)
89-
assert isinstance(x, torch.Tensor)
9084
self.n_features_in_ = x.shape[1]
9185
self.n_outputs_ = y.shape[1] if y.ndim > 1 else 1
9286

@@ -108,7 +102,6 @@ def __init__( # noqa: PLR0913 allow too many arguments since all currently requ
108102

109103
# Init must be called with preprocessed data
110104
x_preprocessed = self.preprocess(x)
111-
assert isinstance(x_preprocessed, torch.Tensor)
112105
gpytorch.models.ExactGP.__init__(
113106
self,
114107
train_inputs=x_preprocessed,
@@ -127,24 +120,21 @@ def __init__( # noqa: PLR0913 allow too many arguments since all currently requ
127120
def is_multioutput():
128121
return True
129122

130-
def preprocess(self, x: InputLike) -> InputLike:
123+
def preprocess(self, x: TensorLike) -> TensorLike:
131124
"""Preprocess the input data using the preprocessor."""
132125
if self.preprocessor is not None:
133126
x = self.preprocessor.preprocess(x)
134127
return x
135128

136-
def forward(self, x: InputLike):
137-
assert isinstance(x, torch.Tensor)
129+
def forward(self, x: TensorLike):
138130
mean = self.mean_module(x)
139-
140131
assert isinstance(mean, torch.Tensor)
141132
covar = self.covar_module(x)
142-
143133
return MultitaskMultivariateNormal.from_batch_mvn(
144134
MultivariateNormal(mean, covar)
145135
)
146136

147-
def log_epoch(self, epoch: int, loss: torch.Tensor):
137+
def log_epoch(self, epoch: int, loss: TensorLike):
148138
logger = logging.getLogger(__name__)
149139
assert self.likelihood.noise is not None
150140
msg = (
@@ -153,15 +143,16 @@ def log_epoch(self, epoch: int, loss: torch.Tensor):
153143
)
154144
logger.info(msg)
155145

156-
def _fit(self, x: InputLike, y: InputLike | None):
146+
def _fit(self, x: TensorLike, y: TensorLike):
157147
self.train()
158148
self.likelihood.train()
159-
# Ensure tensors and correct shapes
160-
x, y = self._convert_to_tensors(self._convert_to_dataset(x, y))
149+
150+
# TODO: move conversion out of _fit() and instead rely on for impl check
151+
x, y = self._convert_to_tensors(x, y)
152+
161153
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
162154
mll = ExactMarginalLogLikelihood(self.likelihood, self)
163155
x = self.preprocess(x)
164-
assert isinstance(x, torch.Tensor)
165156

166157
# Set the training data in case changed since init
167158
self.set_train_data(x, y, strict=False)
@@ -176,14 +167,14 @@ def _fit(self, x: InputLike, y: InputLike | None):
176167
self.log_epoch(epoch, loss)
177168
optimizer.step()
178169

179-
def _predict(self, x: InputLike) -> OutputLike:
170+
def _predict(self, x: TensorLike) -> OutputLike:
180171
self.eval()
181-
x = self.preprocess(x)
182-
x_tensor = self._convert_to_tensors(x)
172+
# TODO: remove upon implementing validation
183173
if not isinstance(x, torch.Tensor):
184174
msg = f"x ({x}) must be a torch.Tensor"
185175
raise ValueError(msg)
186-
return self(x_tensor)
176+
x = self.preprocess(x)
177+
return self(x)
187178

188179
@staticmethod
189180
def get_tune_config():

autoemulate/experimental/emulators/lightgbm.py

+18-18
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,8 @@
33
from sklearn.utils.validation import check_array, check_is_fitted, check_X_y
44
from torch import Tensor
55

6-
from autoemulate.experimental.emulators.base import (
7-
Emulator,
8-
InputTypeMixin,
9-
)
10-
from autoemulate.experimental.types import InputLike, OutputLike
6+
from autoemulate.experimental.emulators.base import Emulator, InputTypeMixin
7+
from autoemulate.experimental.types import OutputLike, TensorLike
118

129

1310
class LightGBM(Emulator, InputTypeMixin):
@@ -20,8 +17,8 @@ class LightGBM(Emulator, InputTypeMixin):
2017

2118
def __init__( # noqa: PLR0913 allow too many arguments since all currently required
2219
self,
23-
x: InputLike | None = None,
24-
y: InputLike | None = None,
20+
x: TensorLike | None = None,
21+
y: TensorLike | None = None,
2522
boosting_type: str = "gbdt",
2623
num_leaves: int = 31,
2724
max_depth: int = -1,
@@ -68,28 +65,30 @@ def __init__( # noqa: PLR0913 allow too many arguments since all currently requ
6865
def is_multioutput() -> bool:
6966
return False
7067

71-
def _fit(self, x: InputLike, y: InputLike | None):
68+
def _fit(self, x: TensorLike, y: TensorLike):
7269
"""
7370
Fits the emulator to the data.
7471
The model expects the input data to be:
7572
x (features): 2D array
7673
y (target): 1D array
7774
"""
7875

79-
x, y = self._convert_to_numpy(x, y)
76+
x_np, y_np = self._convert_to_numpy(x, y)
8077

81-
if y is None:
78+
# TODO (#422): move to validation
79+
if y_np is None:
8280
msg = "y must be provided."
8381
raise ValueError(msg)
84-
if y.ndim > 2:
85-
msg = f"y must be 1D or 2D array. Found {y.ndim}D array."
82+
if y_np.ndim > 2:
83+
msg = f"y must be 1D or 2D array. Found {y_np.ndim}D array."
8684
raise ValueError(msg)
87-
if y.ndim == 2: # _convert_to_numpy may return 2D y
88-
y = y.ravel() # Ensure y is 1-dimensional
85+
if y_np.ndim == 2: # _convert_to_numpy may return 2D y
86+
y_np = y_np.ravel() # Ensure y is 1-dimensional
8987

90-
self.n_features_in_ = x.shape[1]
88+
self.n_features_in_ = x_np.shape[1]
9189

92-
x, y = check_X_y(x, y, y_numeric=True)
90+
# TODO (#422): move to validation
91+
x_np, y_np = check_X_y(x_np, y_np, y_numeric=True)
9392

9493
self.model_ = LGBMRegressor(
9594
boosting_type=self.boosting_type,
@@ -113,12 +112,13 @@ def _fit(self, x: InputLike, y: InputLike | None):
113112
verbose=self.verbose,
114113
)
115114

116-
self.model_.fit(x, y)
115+
self.model_.fit(x_np, y_np)
117116
self.is_fitted_ = True
118117

119-
def _predict(self, x: InputLike) -> OutputLike:
118+
def _predict(self, x: TensorLike) -> OutputLike:
120119
"""Predicts the output of the emulator for a given input."""
121120
x = check_array(x)
121+
# TODO (#422): move to predict() and consider if required
122122
check_is_fitted(self, "is_fitted_")
123123
y_pred = self.model_.predict(x)
124124
# Ensure the output is a 2D tensor array with shape (n_samples, 1)

autoemulate/experimental/emulators/neural_processes/conditional_neural_process.py

+12-14
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch.utils
44
import torch.utils.data
55
from autoemulate.experimental.emulators.base import PyTorchBackend
6-
from autoemulate.experimental.types import DistributionLike, InputLike, TensorLike
6+
from autoemulate.experimental.types import DistributionLike, TensorLike
77
from torch import nn
88
from torch.utils.data import Dataset
99

@@ -246,8 +246,8 @@ class CNPModule(PyTorchBackend):
246246

247247
def __init__( # noqa: PLR0913
248248
self,
249-
x: InputLike,
250-
y: InputLike,
249+
x: TensorLike,
250+
y: TensorLike,
251251
hidden_dim: int = 32,
252252
latent_dim: int = 16,
253253
hidden_layers_enc: int = 2,
@@ -287,9 +287,10 @@ def __init__( # noqa: PLR0913
287287
Batch size for training.
288288
"""
289289
super().__init__()
290-
x_, y_ = self._convert_to_tensors(x, y)
291-
self.input_dim = x_.shape[1]
292-
self.output_dim = y_.shape[1]
290+
# TODO (#422): update the call here to check or call e.g. `_ensure_2d`
291+
x, y = self._convert_to_tensors(x, y)
292+
self.input_dim = x.shape[1]
293+
self.output_dim = y.shape[1]
293294
self.encoder = Encoder(
294295
self.input_dim,
295296
self.output_dim,
@@ -347,11 +348,7 @@ def forward(
347348
reinterpreted_batch_ndims=1,
348349
)
349350

350-
def _fit(
351-
self,
352-
x: InputLike,
353-
y: InputLike | None,
354-
):
351+
def _fit(self, x: TensorLike, y: TensorLike):
355352
"""
356353
Fit the model to the data.
357354
Note the batching of data is done internally in the method.
@@ -364,8 +361,8 @@ def _fit(
364361
"""
365362
self.train()
366363

367-
# TODO: revisit as part of https://github.com/alan-turing-institute/autoemulate/issues/400
368364
# Save off all X_train and y_train
365+
# TODO (#422): update the call here to check or call e.g. `_ensure_2d`
369366
self.x_train, self.y_train = self._convert_to_tensors(x, y)
370367

371368
# Convert dataset to CNP Dataset
@@ -415,7 +412,7 @@ def _fit(
415412
if self.verbose and (epoch + 1) % (self.epochs // 10 or 1) == 0:
416413
print(f"Epoch [{epoch + 1}/{self.epochs}], Loss: {avg_epoch_loss:.4f}")
417414

418-
def _predict(self, x: InputLike) -> DistributionLike:
415+
def _predict(self, x: TensorLike) -> DistributionLike:
419416
"""
420417
Predict uses the training data as the context data and the input x as the target
421418
data. The data is preprocessed within the method.
@@ -432,14 +429,15 @@ def _predict(self, x: InputLike) -> DistributionLike:
432429
Note the distribution is a single tensor of shape (n_points, output_dim).
433430
434431
"""
432+
# TODO: add to validation _check
435433
if self.x_train is None or self.y_train is None:
436434
msg = "Model has not been trained. Please call fit() before predict()."
437435
raise ValueError(msg)
438436

439437
self.eval()
440438
x = self.preprocess(x)
441439

442-
# Convert x to a dataset
440+
# TODO: add to validation _check
443441
x_target = self._convert_to_tensors(x)
444442

445443
# Sort splitting into context and target

0 commit comments

Comments
 (0)