5
5
import torch
6
6
from gpytorch import ExactMarginalLogLikelihood
7
7
from gpytorch .distributions import MultitaskMultivariateNormal , MultivariateNormal
8
- from gpytorch .kernels import (
9
- ScaleKernel ,
10
- )
8
+ from gpytorch .kernels import ScaleKernel
11
9
from gpytorch .likelihoods import MultitaskGaussianLikelihood
12
10
from torch import nn
13
11
26
24
zero_mean ,
27
25
)
28
26
from autoemulate .experimental .data .preprocessors import Preprocessor , Standardizer
29
- from autoemulate .experimental .emulators .base import (
30
- Emulator ,
31
- InputTypeMixin ,
32
- )
27
+ from autoemulate .experimental .emulators .base import Emulator , InputTypeMixin
33
28
from autoemulate .experimental .emulators .gaussian_process import (
34
29
CovarModuleFn ,
35
30
MeanModuleFn ,
36
31
)
37
- from autoemulate .experimental .types import InputLike , OutputLike
32
+ from autoemulate .experimental .types import OutputLike , TensorLike
38
33
from autoemulate .utils import set_random_seed
39
34
40
35
@@ -53,8 +48,8 @@ class GaussianProcessExact(
53
48
54
49
def __init__ ( # noqa: PLR0913 allow too many arguments since all currently required
55
50
self ,
56
- x : InputLike ,
57
- y : InputLike ,
51
+ x : TensorLike ,
52
+ y : TensorLike ,
58
53
likelihood_cls : type [MultitaskGaussianLikelihood ] = MultitaskGaussianLikelihood ,
59
54
mean_module_fn : MeanModuleFn = constant_mean ,
60
55
covar_module_fn : CovarModuleFn = rbf ,
@@ -68,6 +63,7 @@ def __init__( # noqa: PLR0913 allow too many arguments since all currently requ
68
63
if random_state is not None :
69
64
set_random_seed (random_state )
70
65
66
+ # TODO (#422): update the call here to check or call e.g. `_ensure_2d`
71
67
x , y = self ._convert_to_tensors (x , y )
72
68
73
69
# Initialize the mean and covariance modules
@@ -85,8 +81,6 @@ def __init__( # noqa: PLR0913 allow too many arguments since all currently requ
85
81
)
86
82
)
87
83
88
- assert isinstance (y , torch .Tensor )
89
- assert isinstance (x , torch .Tensor )
90
84
self .n_features_in_ = x .shape [1 ]
91
85
self .n_outputs_ = y .shape [1 ] if y .ndim > 1 else 1
92
86
@@ -108,7 +102,6 @@ def __init__( # noqa: PLR0913 allow too many arguments since all currently requ
108
102
109
103
# Init must be called with preprocessed data
110
104
x_preprocessed = self .preprocess (x )
111
- assert isinstance (x_preprocessed , torch .Tensor )
112
105
gpytorch .models .ExactGP .__init__ (
113
106
self ,
114
107
train_inputs = x_preprocessed ,
@@ -127,24 +120,21 @@ def __init__( # noqa: PLR0913 allow too many arguments since all currently requ
127
120
def is_multioutput ():
128
121
return True
129
122
130
- def preprocess (self , x : InputLike ) -> InputLike :
123
+ def preprocess (self , x : TensorLike ) -> TensorLike :
131
124
"""Preprocess the input data using the preprocessor."""
132
125
if self .preprocessor is not None :
133
126
x = self .preprocessor .preprocess (x )
134
127
return x
135
128
136
- def forward (self , x : InputLike ):
137
- assert isinstance (x , torch .Tensor )
129
+ def forward (self , x : TensorLike ):
138
130
mean = self .mean_module (x )
139
-
140
131
assert isinstance (mean , torch .Tensor )
141
132
covar = self .covar_module (x )
142
-
143
133
return MultitaskMultivariateNormal .from_batch_mvn (
144
134
MultivariateNormal (mean , covar )
145
135
)
146
136
147
- def log_epoch (self , epoch : int , loss : torch . Tensor ):
137
+ def log_epoch (self , epoch : int , loss : TensorLike ):
148
138
logger = logging .getLogger (__name__ )
149
139
assert self .likelihood .noise is not None
150
140
msg = (
@@ -153,15 +143,16 @@ def log_epoch(self, epoch: int, loss: torch.Tensor):
153
143
)
154
144
logger .info (msg )
155
145
156
- def _fit (self , x : InputLike , y : InputLike | None ):
146
+ def _fit (self , x : TensorLike , y : TensorLike ):
157
147
self .train ()
158
148
self .likelihood .train ()
159
- # Ensure tensors and correct shapes
160
- x , y = self ._convert_to_tensors (self ._convert_to_dataset (x , y ))
149
+
150
+ # TODO: move conversion out of _fit() and instead rely on for impl check
151
+ x , y = self ._convert_to_tensors (x , y )
152
+
161
153
optimizer = torch .optim .Adam (self .parameters (), lr = self .lr )
162
154
mll = ExactMarginalLogLikelihood (self .likelihood , self )
163
155
x = self .preprocess (x )
164
- assert isinstance (x , torch .Tensor )
165
156
166
157
# Set the training data in case changed since init
167
158
self .set_train_data (x , y , strict = False )
@@ -176,14 +167,14 @@ def _fit(self, x: InputLike, y: InputLike | None):
176
167
self .log_epoch (epoch , loss )
177
168
optimizer .step ()
178
169
179
- def _predict (self , x : InputLike ) -> OutputLike :
170
+ def _predict (self , x : TensorLike ) -> OutputLike :
180
171
self .eval ()
181
- x = self .preprocess (x )
182
- x_tensor = self ._convert_to_tensors (x )
172
+ # TODO: remove upon implmenting validation
183
173
if not isinstance (x , torch .Tensor ):
184
174
msg = f"x ({ x } ) must be a torch.Tensor"
185
175
raise ValueError (msg )
186
- return self (x_tensor )
176
+ x = self .preprocess (x )
177
+ return self (x )
187
178
188
179
@staticmethod
189
180
def get_tune_config ():
0 commit comments