@@ -18,7 +18,7 @@
 from aepsych.factory.monotonic import monotonic_mean_covar_factory
 from aepsych.kernels.rbf_partial_grad import RBFKernelPartialObsGrad
 from aepsych.means.constant_partial_grad import ConstantMeanPartialObsGrad
-from aepsych.models.base import AEPsychMixin
+from aepsych.models.base import AEPsychModelDeviceMixin
 from aepsych.models.utils import select_inducing_points
 from aepsych.utils import _process_bounds, promote_0d
 from botorch.fit import fit_gpytorch_mll
@@ -32,7 +32,7 @@
 from torch import Tensor
 
 
-class MonotonicRejectionGP(AEPsychMixin, ApproximateGP):
+class MonotonicRejectionGP(AEPsychModelDeviceMixin, ApproximateGP):
     """A monotonic GP using rejection sampling.
 
     This takes the same insight as in e.g. Riihimäki & Vehtari 2010 (that the derivative of a GP
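Note: the mixin swap above is what supplies the `self.device` attribute used by the hunks below. As a rough sketch of the pattern only (not AEPsychModelDeviceMixin's actual code), a device-aware mixin can infer its device from the tensors registered on the module:

    import torch

    class DeviceMixinSketch:
        # Sketch only, assuming the class is mixed into a torch.nn.Module
        # with at least one parameter or buffer: report where the module's
        # tensors live so helpers can call some_tensor.to(self.device).
        @property
        def device(self) -> torch.device:
            for t in self.parameters():
                return t.device
            for t in self.buffers():
                return t.device
            return torch.device("cpu")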
@@ -83,15 +83,15 @@ def __init__(
             objective (Optional[MCAcquisitionObjective], optional): Transformation of GP to apply before computing acquisition function. Defaults to identity transform for gaussian likelihood, probit transform for probit-bernoulli.
             extra_acqf_args (Optional[Dict[str, object]], optional): Additional arguments to pass into the acquisition function. Defaults to None.
         """
-        self.lb, self.ub, self.dim = _process_bounds(lb, ub, dim)
+        lb, ub, self.dim = _process_bounds(lb, ub, dim)
         if likelihood is None:
             likelihood = BernoulliLikelihood()
 
         self.inducing_size = num_induc
         self.inducing_point_method = inducing_point_method
         inducing_points = select_inducing_points(
             inducing_size=self.inducing_size,
-            bounds=self.bounds,
+            bounds=torch.stack((lb, ub)),
             method="sobol",
         )
 
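Note: `torch.stack((lb, ub))` builds the same 2 x dim bounds tensor that the old `self.bounds` property exposed; the change avoids reading instance state that is no longer set at this point in `__init__` (the `lb`/`ub` buffers are registered later). A quick illustration:

    import torch

    lb = torch.tensor([0.0, 0.0])
    ub = torch.tensor([1.0, 10.0])
    bounds = torch.stack((lb, ub))  # shape (2, dim): row 0 = lower, row 1 = upper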
@@ -134,7 +134,9 @@ def __init__(
 
         super().__init__(variational_strategy)
 
-        self.bounds_ = torch.stack([self.lb, self.ub])
+        self.register_buffer("lb", lb)
+        self.register_buffer("ub", ub)
+        self.register_buffer("bounds_", torch.stack([self.lb, self.ub]))
         self.mean_module = mean_module
         self.covar_module = covar_module
         self.likelihood = likelihood
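Note: `register_buffer` is standard PyTorch for attaching non-trainable tensors to a module: buffers move with `model.to(device)` and are included in `state_dict`, unlike plain tensor attributes. A minimal demonstration:

    import torch

    class Demo(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("lb", torch.zeros(2))  # tracked, not trained
            self.plain = torch.zeros(2)                 # untracked attribute

    demo = Demo()
    print("lb" in demo.state_dict())     # True: buffers are saved/loaded
    print("plain" in demo.state_dict())  # False: plain attributes are not
    # demo.to("cuda") would move demo.lb but leave demo.plain on the CPU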
@@ -144,7 +146,7 @@ def __init__(
         self.num_samples = num_samples
         self.num_rejection_samples = num_rejection_samples
         self.fixed_prior_mean = fixed_prior_mean
-        self.inducing_points = inducing_points
+        self.register_buffer("inducing_points", inducing_points)
 
     def fit(self, train_x: Tensor, train_y: Tensor, **kwargs) -> None:
         """Fit the model
@@ -161,7 +163,7 @@ def fit(self, train_x: Tensor, train_y: Tensor, **kwargs) -> None:
             X=self.train_inputs[0],
             bounds=self.bounds,
             method=self.inducing_point_method,
-        )
+        ).to(self.device)
         self._set_model(train_x, train_y)
 
     def _set_model(
@@ -284,13 +286,14 @@ def predict_probability(
         return self.predict(x, probability_space=True)
 
     def _augment_with_deriv_index(self, x: Tensor, indx) -> Tensor:
+        x = x.to(self.device)
         return torch.cat(
-            (x, indx * torch.ones(x.shape[0], 1)),
+            (x, indx * torch.ones(x.shape[0], 1).to(self.device)),
             dim=1,
         )
 
     def _get_deriv_constraint_points(self) -> Tensor:
-        deriv_cp = torch.tensor([])
+        deriv_cp = torch.tensor([]).to(self.device)
         for i in self.monotonic_idxs:
             induc_i = self._augment_with_deriv_index(self.inducing_points, i + 1)
             deriv_cp = torch.cat((deriv_cp, induc_i), dim=0)
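Note: the `.to(self.device)` calls added in this hunk all guard against the same failure mode: factory calls like `torch.ones` and `torch.tensor([])` allocate on the CPU by default, and `torch.cat` refuses to mix devices. For example (assuming a CUDA device is available):

    import torch

    x = torch.rand(4, 2, device="cuda")
    col = torch.ones(x.shape[0], 1)     # allocated on the CPU by default
    # torch.cat((x, col), dim=1)        # RuntimeError: tensors on different devices
    out = torch.cat((x, col.to(x.device)), dim=1)  # OK: shape (4, 3), on cuda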
@@ -299,8 +302,8 @@ def _get_deriv_constraint_points(self) -> Tensor:
     @classmethod
     def from_config(cls, config: Config) -> MonotonicRejectionGP:
         classname = cls.__name__
-        num_induc = config.gettensor(classname, "num_induc", fallback=25)
-        num_samples = config.gettensor(classname, "num_samples", fallback=250)
+        num_induc = config.getint(classname, "num_induc", fallback=25)
+        num_samples = config.getint(classname, "num_samples", fallback=250)
         num_rejection_samples = config.getint(
             classname, "num_rejection_samples", fallback=5000
         )
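Note: the `gettensor` to `getint` fix matters because `num_induc` and `num_samples` are consumed as plain Python ints (inducing-point and sample counts), not tensors. Assuming AEPsych's `Config` follows `configparser` semantics, as its accessor names suggest:

    from configparser import ConfigParser

    cfg = ConfigParser()
    cfg.read_string("[MonotonicRejectionGP]\nnum_induc = 30\n")

    # getint parses the raw string into an int, using the fallback if absent
    num_induc = cfg.getint("MonotonicRejectionGP", "num_induc", fallback=25)       # 30
    num_samples = cfg.getint("MonotonicRejectionGP", "num_samples", fallback=250)  # 250
    # a tensor-valued accessor would hand a 0-d tensor to call sites that
    # expect ints (e.g. sizes), which is why the integer accessor is correct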