
Commit 3406cbe

JasonKChow authored and facebook-github-bot committed

Monotonic rejection model and generator (facebookresearch#458)

Summary: Adds GPU support to the monotonic rejection model. Since the model is tied to its generator, the generators are made GPU-ready as well.

Differential Revision: D65638150

1 parent d096c6a commit 3406cbe

11 files changed: +283 -47 lines

aepsych/generators/monotonic_rejection_generator.py (+1 -1)

@@ -101,7 +101,7 @@ def gen(
         )

         # Augment bounds with deriv indicator
-        bounds = torch.cat((model.bounds_, torch.zeros(2, 1)), dim=1)
+        bounds = torch.cat((model.bounds_, torch.zeros(2, 1).to(model.device)), dim=1)
         # Fix deriv indicator to 0 during optimization
         fixed_features = {(bounds.shape[1] - 1): 0.0}
         # Fix explore features to random values
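Context for this one-liner: torch.zeros allocates on the CPU by default, and torch.cat requires all of its inputs to live on one device, so the old line failed whenever model.bounds_ lived on a GPU. A minimal standalone sketch of the failure mode and the fix (bounds_ here is a stand-in for model.bounds_, not aepsych code):

import torch

bounds_ = torch.ones(2, 3)  # stand-in for model.bounds_
if torch.cuda.is_available():
    bounds_ = bounds_.cuda()

# torch.zeros defaults to CPU; without .to(...), torch.cat raises a
# RuntimeError when bounds_ is on a GPU. Matching devices works in both cases.
bounds = torch.cat((bounds_, torch.zeros(2, 1).to(bounds_.device)), dim=1)
print(bounds.shape, bounds.device)  # torch.Size([2, 4]) on bounds_'s device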

aepsych/means/constant_partial_grad.py (+1 -1)

@@ -26,6 +26,6 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         idx = input[..., -1].to(dtype=torch.long) > 0
         mean_fit = super(ConstantMeanPartialObsGrad, self).forward(input[..., ~idx, :])
         sz = mean_fit.shape[:-1] + torch.Size([input.shape[-2]])
-        mean = torch.zeros(sz)
+        mean = torch.zeros(sz).to(input)
         mean[~idx] = mean_fit
         return mean
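The .to(input) form used here takes a tensor rather than a device and copies both the dtype and the device of that tensor, which matters because this mean must match inputs that may be float64 and/or on a GPU. A small self-contained illustration:

import torch

inp = torch.randn(4, 3, dtype=torch.float64)  # could equally be a CUDA tensor
mean = torch.zeros(4).to(inp)  # tensor-form .to() copies dtype AND device
assert mean.dtype == inp.dtype and mean.device == inp.device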

aepsych/models/base.py (+6 -1)

@@ -415,7 +415,12 @@ def set_train_data(self, inputs=None, targets=None, strict=False):
     def device(self) -> torch.device:
         # We assume all models have some parameters and all models will only use one device
         # notice that this has no setting, don't let users set device, use .to().
-        return next(self.parameters()).device
+        try:
+            return next(self.parameters()).device
+        except (
+            AttributeError
+        ):  # Fallback for cases where we need device before we have params
+            return torch.device("cpu")

     @property
     def train_inputs(self) -> Optional[Tuple[torch.Tensor]]:
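The try/except makes the device property usable before any parameters have been registered (e.g., partway through __init__), falling back to CPU. A minimal sketch of the pattern; DeviceMixin is a hypothetical stand-in for AEPsychModelDeviceMixin, and catching StopIteration as well (for parameterless modules) is an extra safeguard not present in the diff:

import torch
from torch import nn

class DeviceMixin:
    """Hypothetical stand-in for AEPsychModelDeviceMixin."""

    @property
    def device(self) -> torch.device:
        try:
            # Assumes all parameters share one device; read it off the first.
            return next(self.parameters()).device
        except (AttributeError, StopIteration):
            # Before parameters exist (e.g., mid-__init__), fall back to CPU.
            return torch.device("cpu")

class TinyModel(DeviceMixin, nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(2, 1)

print(TinyModel().device)  # cpu until the model is moved with .to()/.cuda()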

aepsych/models/derivative_gp.py (+3 -1)

@@ -13,6 +13,7 @@
 import torch
 from aepsych.kernels.rbf_partial_grad import RBFKernelPartialObsGrad
 from aepsych.means.constant_partial_grad import ConstantMeanPartialObsGrad
+from aepsych.models.base import AEPsychModelDeviceMixin
 from botorch.models.gpytorch import GPyTorchModel
 from gpytorch.distributions import MultivariateNormal
 from gpytorch.kernels import Kernel
@@ -22,7 +23,7 @@
 from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy


-class MixedDerivativeVariationalGP(gpytorch.models.ApproximateGP, GPyTorchModel):
+class MixedDerivativeVariationalGP(gpytorch.models.ApproximateGP, AEPsychModelDeviceMixin, GPyTorchModel):
     """A variational GP with mixed derivative observations.

     For more on GPs with derivative observations, see e.g. Riihimaki & Vehtari 2010.
@@ -99,6 +100,7 @@ def __init__(
         self._num_outputs = 1
         self.train_inputs = (train_x,)
         self.train_targets = train_y
+        self.to(self.device)  # Needed to prep for below
         self(train_x)  # Necessary for CholeskyVariationalDistribution

     def forward(self, x: torch.Tensor) -> MultivariateNormal:
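The inserted self.to(self.device) matters because the parameters and buffers built in __init__ start on the CPU, while the warm-up call self(train_x) on the next line must run with everything on one device. A toy analogue of the pattern, under the assumption that the warm-up forward is what primes the variational distribution (WarmupModel is illustrative, not aepsych code):

import torch
from torch import nn

class WarmupModel(nn.Module):
    """Toy analogue of the pattern above, not aepsych code."""

    def __init__(self, train_x: torch.Tensor) -> None:
        super().__init__()
        self.linear = nn.Linear(train_x.shape[-1], 1)
        # Parameters created above start on CPU even if train_x is on a GPU;
        # move the whole module first so the warm-up forward sees one device.
        self.to(train_x.device)
        self(train_x)  # warm-up pass, analogous to priming the variational dist

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

WarmupModel(torch.randn(3, 2))  # works on CPU and, unmodified, on CUDA inputs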

aepsych/models/monotonic_projection_gp.py (+10 -6)

@@ -136,14 +136,17 @@ def posterior(
             # using numpy because torch doesn't support vectorized linspace,
             # pytorch/issues/61292
             grid: Union[np.ndarray, torch.Tensor] = np.linspace(
-                self.lb[dim],
-                X[:, dim].numpy(),
+                self.lb[dim].cpu().numpy(),
+                X[:, dim].cpu().numpy(),
                 s + 1,
             )  # (s+1 x n)
             grid = torch.tensor(grid[:-1, :], dtype=X.dtype)  # Drop x; (s x n)
             X_aug[(1 + i * s) : (1 + (i + 1) * s), :, dim] = grid
         # X_aug[0, :, :] is X, and then subsequent indices are points in the grids
         # Predict marginal distributions on X_aug
+
+        X = X.to(self.device)
+        X_aug = X_aug.to(self.device)
         with torch.no_grad():
             post_aug = super().posterior(X=X_aug)
             mu_aug = post_aug.mean.squeeze()  # (m*s+1 x n)
@@ -158,12 +161,13 @@ def posterior(
         # Adjust the whole covariance matrix to accomadate the projected marginals
         with torch.no_grad():
             post = super().posterior(X=X)
-            R = cov2corr(post.distribution.covariance_matrix.squeeze().numpy())
-            S_proj = torch.tensor(corr2cov(R, sigma_proj.numpy()), dtype=X.dtype)
+            R = cov2corr(post.distribution.covariance_matrix.squeeze().cpu().numpy())
+            S_proj = torch.tensor(corr2cov(R, sigma_proj.cpu().numpy()), dtype=X.dtype)
             mvn_proj = gpytorch.distributions.MultivariateNormal(
-                mu_proj.unsqueeze(0),
-                S_proj.unsqueeze(0),
+                mu_proj.unsqueeze(0).to(self.device),
+                S_proj.unsqueeze(0).to(self.device),
             )
+
        return GPyTorchPosterior(mvn_proj)

     def sample(self, x: torch.Tensor, num_samples: int) -> torch.Tensor:
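The .cpu().numpy() pairs here reflect that Tensor.numpy() is only defined for CPU tensors; .cpu() is a no-op when the tensor is already on the CPU, so the round trip is safe on both devices. A short standalone sketch mirroring the vectorized np.linspace call above:

import numpy as np
import torch

x = torch.linspace(0.0, 1.0, 5)
if torch.cuda.is_available():
    x = x.cuda()

# .numpy() only works on CPU tensors; .cpu() first makes this device-agnostic.
arr = np.linspace(0.0, x.cpu().numpy(), 3)  # vectorized over x, shape (3, 5)
grid = torch.tensor(arr, dtype=x.dtype).to(x.device)  # back to x's device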

aepsych/models/monotonic_rejection_gp.py (+14 -11)

@@ -18,7 +18,7 @@
 from aepsych.factory.monotonic import monotonic_mean_covar_factory
 from aepsych.kernels.rbf_partial_grad import RBFKernelPartialObsGrad
 from aepsych.means.constant_partial_grad import ConstantMeanPartialObsGrad
-from aepsych.models.base import AEPsychMixin
+from aepsych.models.base import AEPsychModelDeviceMixin
 from aepsych.models.utils import select_inducing_points
 from aepsych.utils import _process_bounds, promote_0d
 from botorch.fit import fit_gpytorch_mll
@@ -32,7 +32,7 @@
 from torch import Tensor


-class MonotonicRejectionGP(AEPsychMixin, ApproximateGP):
+class MonotonicRejectionGP(AEPsychModelDeviceMixin, ApproximateGP):
     """A monotonic GP using rejection sampling.

     This takes the same insight as in e.g. Riihimäki & Vehtari 2010 (that the derivative of a GP
@@ -83,15 +83,15 @@ def __init__(
             objective (Optional[MCAcquisitionObjective], optional): Transformation of GP to apply before computing acquisition function. Defaults to identity transform for gaussian likelihood, probit transform for probit-bernoulli.
             extra_acqf_args (Optional[Dict[str, object]], optional): Additional arguments to pass into the acquisition function. Defaults to None.
         """
-        self.lb, self.ub, self.dim = _process_bounds(lb, ub, dim)
+        lb, ub, self.dim = _process_bounds(lb, ub, dim)
         if likelihood is None:
             likelihood = BernoulliLikelihood()

         self.inducing_size = num_induc
         self.inducing_point_method = inducing_point_method
         inducing_points = select_inducing_points(
             inducing_size=self.inducing_size,
-            bounds=self.bounds,
+            bounds=torch.stack((lb, ub)),
             method="sobol",
         )

@@ -134,7 +134,9 @@ def __init__(

         super().__init__(variational_strategy)

-        self.bounds_ = torch.stack([self.lb, self.ub])
+        self.register_buffer("lb", lb)
+        self.register_buffer("ub", ub)
+        self.register_buffer("bounds_", torch.stack([self.lb, self.ub]))
         self.mean_module = mean_module
         self.covar_module = covar_module
         self.likelihood = likelihood
@@ -144,7 +146,7 @@ def __init__(
         self.num_samples = num_samples
         self.num_rejection_samples = num_rejection_samples
         self.fixed_prior_mean = fixed_prior_mean
-        self.inducing_points = inducing_points
+        self.register_buffer("inducing_points", inducing_points)

     def fit(self, train_x: Tensor, train_y: Tensor, **kwargs) -> None:
         """Fit the model
@@ -161,7 +163,7 @@ def fit(self, train_x: Tensor, train_y: Tensor, **kwargs) -> None:
             X=self.train_inputs[0],
             bounds=self.bounds,
             method=self.inducing_point_method,
-        )
+        ).to(self.device)
         self._set_model(train_x, train_y)

     def _set_model(
@@ -284,13 +286,14 @@ def predict_probability(
         return self.predict(x, probability_space=True)

     def _augment_with_deriv_index(self, x: Tensor, indx) -> Tensor:
+        x = x.to(self.device)
         return torch.cat(
-            (x, indx * torch.ones(x.shape[0], 1)),
+            (x, indx * torch.ones(x.shape[0], 1).to(self.device)),
             dim=1,
         )

     def _get_deriv_constraint_points(self) -> Tensor:
-        deriv_cp = torch.tensor([])
+        deriv_cp = torch.tensor([]).to(self.device)
         for i in self.monotonic_idxs:
             induc_i = self._augment_with_deriv_index(self.inducing_points, i + 1)
             deriv_cp = torch.cat((deriv_cp, induc_i), dim=0)
@@ -299,8 +302,8 @@ def _get_deriv_constraint_points(self) -> Tensor:
     @classmethod
     def from_config(cls, config: Config) -> MonotonicRejectionGP:
         classname = cls.__name__
-        num_induc = config.gettensor(classname, "num_induc", fallback=25)
-        num_samples = config.gettensor(classname, "num_samples", fallback=250)
+        num_induc = config.getint(classname, "num_induc", fallback=25)
+        num_samples = config.getint(classname, "num_samples", fallback=250)
         num_rejection_samples = config.getint(
             classname, "num_rejection_samples", fallback=5000
         )
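Registering lb, ub, bounds_, and inducing_points as buffers rather than plain tensor attributes is what lets them follow the model across .to()/.cuda() calls and be included in its state_dict. A minimal sketch of the idea; the Bounded class is illustrative only:

import torch
from torch import nn

class Bounded(nn.Module):
    """Illustration only; mirrors the buffer registration above."""

    def __init__(self, lb: torch.Tensor, ub: torch.Tensor) -> None:
        super().__init__()
        # Unlike plain tensor attributes, buffers follow .to()/.cuda() calls
        # and are saved in state_dict, so bounds always track the model.
        self.register_buffer("lb", lb)
        self.register_buffer("ub", ub)
        self.register_buffer("bounds_", torch.stack([self.lb, self.ub]))

m = Bounded(torch.zeros(2), torch.ones(2))
# m.cuda() would move lb, ub, and bounds_ together with any parameters
print(m.bounds_)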

aepsych/strategy.py (+3 -14)

@@ -71,13 +71,6 @@ class Strategy(object):

     _n_eval_points: int = 1000

-    no_gpu_acqfs = (
-        MonotonicMCAcquisition,
-        MonotonicBernoulliMCMutualInformation,
-        MonotonicMCPosteriorVariance,
-        MonotonicMCLSE,
-    )
-
     def __init__(
         self,
         generator: Union[AEPsychGenerator, ParameterTransformedGenerator],
@@ -182,13 +175,7 @@ def __init__(
             )
             self.generator_device = torch.device("cpu")
         else:
-            if hasattr(generator, "acqf") and generator.acqf in self.no_gpu_acqfs:
-                warnings.warn(
-                    f"GPU requested for acquistion function {type(generator.acqf).__name__}, but this acquisiton function does not support GPU! Using CPU instead.",
-                    UserWarning,
-                )
-                self.generator_device = torch.device("cpu")
-            elif not torch.cuda.is_available():
+            if not torch.cuda.is_available():
                 warnings.warn(
                     f"GPU requested for generator {type(generator).__name__}, but no GPU found! Using CPU instead.",
                     UserWarning,
@@ -283,9 +270,11 @@ def normalize_inputs(
             x = x[None, :]

         if self.x is not None:
+            x = x.to(self.x)
             x = torch.cat((self.x, x), dim=0)

         if self.y is not None:
+            y = y.to(self.y)
             y = torch.cat((self.y, y), dim=0)

         # Ensure the correct dtype
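The two added .to() calls normalize incoming points against the stored data before concatenation; as with the mean module earlier, passing a tensor to .to() matches dtype and device in one step. A standalone sketch (stored_x stands in for self.x):

import torch

stored_x = torch.randn(10, 2)  # stand-in for self.x; may live on a GPU
new_x = torch.tensor([[0.5, 0.5]])  # incoming point arrives on CPU by default

new_x = new_x.to(stored_x)  # match stored dtype and device first
stored_x = torch.cat((stored_x, new_x), dim=0)
print(stored_x.shape)  # torch.Size([11, 2])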

tests_gpu/acquisition/__init__.py (+6)

@@ -0,0 +1,6 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta, Inc. and its affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+51

@@ -0,0 +1,51 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from aepsych.acquisition.monotonic_rejection import MonotonicMCLSE
+from aepsych.acquisition.objective import ProbitObjective
+from aepsych.models.derivative_gp import MixedDerivativeVariationalGP
+from botorch.acquisition.objective import IdentityMCObjective
+from botorch.utils.testing import BotorchTestCase
+
+
+class TestMonotonicAcq(BotorchTestCase):
+    def test_monotonic_acq_gpu(self):
+        # Init
+        train_X_aug = torch.tensor(
+            [[0.0, 0.0, 0.0], [1.0, 1.0, 0.0], [2.0, 2.0, 0.0]]
+        ).cuda()
+        deriv_constraint_points = torch.tensor(
+            [[0.0, 0.0, 1.0], [1.0, 1.0, 1.0], [2.0, 2.0, 1.0]]
+        ).cuda()
+        train_Y = torch.tensor([[1.0], [2.0], [3.0]]).cuda()
+
+        m = MixedDerivativeVariationalGP(
+            train_x=train_X_aug, train_y=train_Y, inducing_points=train_X_aug
+        ).cuda()
+        acq = MonotonicMCLSE(
+            model=m,
+            deriv_constraint_points=deriv_constraint_points,
+            num_samples=5,
+            num_rejection_samples=8,
+            target=1.9,
+        )
+        self.assertTrue(isinstance(acq.objective, IdentityMCObjective))
+        acq = MonotonicMCLSE(
+            model=m,
+            deriv_constraint_points=deriv_constraint_points,
+            num_samples=5,
+            num_rejection_samples=8,
+            target=1.9,
+            objective=ProbitObjective(),
+        ).cuda()
+        # forward
+        acq(train_X_aug)
+        Xfull = torch.cat((train_X_aug, acq.deriv_constraint_points), dim=0)
+        posterior = m.posterior(Xfull)
+        samples = acq.sampler(posterior)
+        self.assertEqual(samples.shape, torch.Size([5, 6, 1]))
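A note on running this file: the test constructs every tensor with .cuda() and so assumes a CUDA device is present. When running such suites on mixed hardware, a common guard is unittest's skipUnless decorator; a hedged sketch, not part of this commit:

import unittest

import torch

@unittest.skipUnless(torch.cuda.is_available(), "CUDA device required")
class TestGPUSmoke(unittest.TestCase):
    def test_device(self) -> None:
        x = torch.zeros(3, device="cuda")
        self.assertEqual(x.device.type, "cuda")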
