Support multi-output models in MES using PosteriorTransform (#904)
Summary: Pull Request resolved: #904

Reviewed By: Balandat

Differential Revision: D30022574

fbshipit-source-id: 6292eea8500c3013fd29efefd736352231316891
saitcakmak authored and facebook-github-bot committed Feb 28, 2022
1 parent 6965426 commit c984666
Showing 5 changed files with 181 additions and 34 deletions.
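For orientation before the diffs, here is a minimal usage sketch of what this PR enables. The model, shapes, and weights are illustrative assumptions, not taken from the diff; `ScalarizedPosteriorTransform` is one `PosteriorTransform` that maps a multi-output posterior to a single-output one.

```python
import torch
from botorch.models import SingleTaskGP
from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy
from botorch.acquisition.objective import ScalarizedPosteriorTransform

# A two-output model; MES previously raised NotImplementedError for these.
train_X = torch.rand(20, 2, dtype=torch.double)
train_Y = torch.rand(20, 2, dtype=torch.double)
model = SingleTaskGP(train_X, train_Y)  # model.num_outputs == 2

# Scalarize the two outputs into a single-output posterior.
pt = ScalarizedPosteriorTransform(weights=torch.tensor([0.7, 0.3], dtype=torch.double))

candidate_set = torch.rand(100, 2, dtype=torch.double)  # discretizes the design space
mes = qMaxValueEntropy(model, candidate_set, posterior_transform=pt)
acq_values = mes(torch.rand(4, 1, 2, dtype=torch.double))  # t-batch of 4 q=1 points
```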
4 changes: 3 additions & 1 deletion botorch/acquisition/active_learning.py
@@ -66,7 +66,9 @@ def __init__(
sampler: The sampler used for drawing fantasy samples. In the basic setting
of a standard GP (default) this is a dummy, since the variance of the
model after conditioning does not actually depend on the sampled values.
- posterior_transform: A PosteriorTransform. Required for multi-output models.
+ posterior_transform: A PosteriorTransform. If using a multi-output model,
+     a PosteriorTransform that transforms the multi-output posterior into a
+     single-output posterior is required.
X_pending: A `n' x d`-dim Tensor of `n'` design points that have been
submitted for function evaluation but have not yet been evaluated.
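A hedged sketch of the documented behavior, assuming the edited docstring belongs to `qNegIntegratedPosteriorVariance` (the acquisition function defined in this file) and using illustrative shapes:

```python
import torch
from botorch.models import SingleTaskGP
from botorch.acquisition.active_learning import qNegIntegratedPosteriorVariance
from botorch.acquisition.objective import ScalarizedPosteriorTransform

model = SingleTaskGP(torch.rand(10, 3), torch.rand(10, 2))  # two outputs
mc_points = torch.rand(256, 3)  # points at which the posterior variance is integrated
pt = ScalarizedPosteriorTransform(weights=torch.ones(2) / 2)

# For a multi-output model, omitting posterior_transform is an error.
qnipv = qNegIntegratedPosteriorVariance(model, mc_points=mc_points, posterior_transform=pt)
acq_values = qnipv(torch.rand(5, 1, 3))  # t-batch of 5 single-point candidates
```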
24 changes: 18 additions & 6 deletions botorch/acquisition/analytic.py
@@ -41,7 +41,9 @@ def __init__(
Args:
model: A fitted single-outcome model.
- posterior_transform: A PosteriorTransform. Required for multi-output models.
+ posterior_transform: A PosteriorTransform. If using a multi-output model,
+     a PosteriorTransform that transforms the multi-output posterior into a
+     single-output posterior is required.
"""
super().__init__(model=model)
posterior_transform = self._deprecate_acqf_objective(
@@ -99,7 +101,9 @@ def __init__(
model: A fitted single-outcome model.
best_f: Either a scalar or a `b`-dim Tensor (batch mode) representing
the best function value observed so far (assumed noiseless).
- posterior_transform: A PosteriorTransform. Required for multi-output models.
+ posterior_transform: A PosteriorTransform. If using a multi-output model,
+     a PosteriorTransform that transforms the multi-output posterior into a
+     single-output posterior is required.
maximize: If True, consider the problem a maximization problem.
"""
super().__init__(model=model, posterior_transform=posterior_transform, **kwargs)
@@ -164,7 +168,9 @@ def __init__(
Args:
model: A fitted single-outcome GP model (must be in batch mode if
candidate sets X will be)
- posterior_transform: A PosteriorTransform. Required for multi-output models.
+ posterior_transform: A PosteriorTransform. If using a multi-output model,
+     a PosteriorTransform that transforms the multi-output posterior into a
+     single-output posterior is required.
maximize: If True, consider the problem a maximization problem. Note
that if `maximize=False`, the posterior mean is negated. As a
consequence `optimize_acqf(PosteriorMean(gp, maximize=False))`
@@ -225,7 +231,9 @@ def __init__(
model: A fitted single-outcome model.
best_f: Either a scalar or a `b`-dim Tensor (batch mode) representing
the best function value observed so far (assumed noiseless).
- posterior_transform: A PosteriorTransform. Required for multi-output models.
+ posterior_transform: A PosteriorTransform. If using a multi-output model,
+     a PosteriorTransform that transforms the multi-output posterior into a
+     single-output posterior is required.
maximize: If True, consider the problem a maximization problem.
"""
super().__init__(model=model, posterior_transform=posterior_transform, **kwargs)
@@ -293,7 +301,9 @@ def __init__(
candidate sets X will be)
beta: Either a scalar or a one-dim tensor with `b` elements (batch mode)
representing the trade-off parameter between mean and covariance
- posterior_transform: A PosteriorTransform. Required for multi-output models.
+ posterior_transform: A PosteriorTransform. If using a multi-output model,
+     a PosteriorTransform that transforms the multi-output posterior into a
+     single-output posterior is required.
maximize: If True, consider the problem a maximization problem.
"""
super().__init__(model=model, posterior_transform=posterior_transform, **kwargs)
@@ -636,7 +646,9 @@ def __init__(
Args:
model: A fitted single-outcome model.
weights: A tensor of shape `q` for scalarization.
- posterior_transform: A PosteriorTransform. Required for multi-output models.
+ posterior_transform: A PosteriorTransform. If using a multi-output model,
+     a PosteriorTransform that transforms the multi-output posterior into a
+     single-output posterior is required.
"""
super().__init__(model=model, posterior_transform=posterior_transform, **kwargs)
self.register_buffer("weights", weights.unsqueeze(dim=0))
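The same pattern applies to each analytic acquisition function touched above. A sketch with `ExpectedImprovement` (illustrative weights); note that `best_f` should then be the best scalarized value observed so far:

```python
import torch
from botorch.models import SingleTaskGP
from botorch.acquisition.analytic import ExpectedImprovement
from botorch.acquisition.objective import ScalarizedPosteriorTransform

train_X = torch.rand(15, 4, dtype=torch.double)
train_Y = torch.rand(15, 2, dtype=torch.double)
model = SingleTaskGP(train_X, train_Y)

weights = torch.tensor([1.0, -0.5], dtype=torch.double)
pt = ScalarizedPosteriorTransform(weights=weights)

# best_f lives on the scalarized (single-output) scale.
best_f = (train_Y @ weights).max()
ei = ExpectedImprovement(model, best_f=best_f, posterior_transform=pt)
ei_values = ei(torch.rand(3, 1, 4, dtype=torch.double))  # one EI value per t-batch
```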
97 changes: 78 additions & 19 deletions botorch/acquisition/max_value_entropy_search.py
@@ -39,6 +39,8 @@
import torch
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.cost_aware import CostAwareUtility, InverseCostWeightedUtility
+ from botorch.acquisition.objective import PosteriorTransform
+ from botorch.exceptions.errors import UnsupportedError
from botorch.models.cost import AffineFidelityCostModel
from botorch.models.model import Model
from botorch.models.utils import check_no_nans
@@ -68,6 +70,7 @@ def __init__(
self,
model: Model,
num_mv_samples: int,
+ posterior_transform: Optional[PosteriorTransform] = None,
maximize: bool = True,
X_pending: Optional[Tensor] = None,
) -> None:
@@ -76,17 +79,18 @@
Args:
model: A fitted single-outcome model.
num_mv_samples: Number of max value samples.
+ posterior_transform: A PosteriorTransform. If using a multi-output model,
+     a PosteriorTransform that transforms the multi-output posterior into a
+     single-output posterior is required.
maximize: If True, consider the problem a maximization problem.
X_pending: A `m x d`-dim Tensor of `m` design points that have been
submitted for function evaluation but have not yet been evaluated.
"""
super().__init__(model=model)

- # Multi-output GP models are not currently supported
- if model.num_outputs != 1:
-     raise NotImplementedError(
-         "Multi-output models are not yet supported by "
-         f"`{self.__class__.__name__}`."
+ if posterior_transform is None and model.num_outputs != 1:
+     raise UnsupportedError(
+         "Must specify a posterior transform when using a multi-output model."
)

# Batched GP models are not currently supported
Expand All @@ -96,10 +100,11 @@ def __init__(
batch_shape = torch.Size()
if len(batch_shape) > 0:
raise NotImplementedError(
"Batched GP models (e.g. fantasized models) are not yet "
"Batched GP models (e.g., fantasized models) are not yet "
f"supported by `{self.__class__.__name__}`."
)
self.num_mv_samples = num_mv_samples
+ self.posterior_transform = posterior_transform
self.maximize = maximize
self.weight = 1.0 if maximize else -1.0
self.set_X_pending(X_pending)
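With this check in the base class, a multi-output model without a transform now fails fast at construction time with `UnsupportedError` (previously `NotImplementedError`). A sketch, assuming the check is inherited by concrete subclasses such as `qMaxValueEntropy`:

```python
import torch
from botorch.exceptions.errors import UnsupportedError
from botorch.models import SingleTaskGP
from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy

model = SingleTaskGP(torch.rand(10, 2), torch.rand(10, 2))  # num_outputs == 2
try:
    qMaxValueEntropy(model, candidate_set=torch.rand(50, 2))
except UnsupportedError as e:
    print(e)  # "Must specify a posterior transform when using a multi-output model."
```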
@@ -116,7 +121,11 @@ def forward(self, X: Tensor) -> Tensor:
A `batch_shape`-dim Tensor of MVE values at the given design points `X`.
"""
# Compute the posterior, posterior mean, variance and std
- posterior = self.model.posterior(X.unsqueeze(-3), observation_noise=False)
+ posterior = self.model.posterior(
+     X.unsqueeze(-3),
+     observation_noise=False,
+     posterior_transform=self.posterior_transform,
+ )
# batch_shape x num_fantasies x (m) x 1
mean = self.weight * posterior.mean.squeeze(-1).squeeze(-1)
variance = posterior.variance.clamp_min(CLAMP_LB).view_as(mean)
@@ -193,6 +202,7 @@ def __init__(
model: Model,
candidate_set: Tensor,
num_mv_samples: int = 10,
+ posterior_transform: Optional[PosteriorTransform] = None,
use_gumbel: bool = True,
maximize: bool = True,
X_pending: Optional[Tensor] = None,
@@ -206,6 +216,9 @@ def __init__(
discretize the design space. Max values are sampled from the
(joint) model posterior over these points.
num_mv_samples: Number of max value samples.
+ posterior_transform: A PosteriorTransform. If using a multi-output model,
+     a PosteriorTransform that transforms the multi-output posterior into a
+     single-output posterior is required.
use_gumbel: If True, use Gumbel approximation to sample the max values.
maximize: If True, consider the problem a maximization problem.
X_pending: A `m x d`-dim Tensor of `m` design points that have been
@@ -231,6 +244,7 @@ def __init__(
super().__init__(
model=model,
num_mv_samples=num_mv_samples,
+ posterior_transform=posterior_transform,
maximize=maximize,
X_pending=X_pending,
)
@@ -275,6 +289,7 @@ def _sample_max_values(
model=self.model,
candidate_set=candidate_set,
num_samples=self.num_mv_samples,
+ posterior_transform=self.posterior_transform,
maximize=self.maximize,
)

@@ -303,6 +318,7 @@ def __init__(
num_fantasies: int = 16,
num_mv_samples: int = 10,
num_y_samples: int = 128,
+ posterior_transform: Optional[PosteriorTransform] = None,
use_gumbel: bool = True,
maximize: bool = True,
X_pending: Optional[Tensor] = None,
@@ -321,6 +337,9 @@ def __init__(
complexity, wall time and memory). Ignored if `X_pending` is `None`.
num_mv_samples: Number of max value samples.
num_y_samples: Number of posterior samples at specific design point `X`.
+ posterior_transform: A PosteriorTransform. If using a multi-output model,
+     a PosteriorTransform that transforms the multi-output posterior into a
+     single-output posterior is required.
use_gumbel: If True, use Gumbel approximation to sample the max values.
maximize: If True, consider the problem a maximization problem.
X_pending: A `m x d`-dim Tensor of `m` design points that have been
@@ -332,6 +351,7 @@ def __init__(
model=model,
candidate_set=candidate_set,
num_mv_samples=num_mv_samples,
+ posterior_transform=posterior_transform,
use_gumbel=use_gumbel,
maximize=maximize,
X_pending=X_pending,
@@ -397,7 +417,11 @@ def _compute_information_gain(
given design points `X` (`num_fantasies=1` for non-fantasized models).
"""
# compute the std_m, variance_m with noisy observation
- posterior_m = self.model.posterior(X.unsqueeze(-3), observation_noise=True)
+ posterior_m = self.model.posterior(
+     X.unsqueeze(-3),
+     observation_noise=True,
+     posterior_transform=self.posterior_transform,
+ )
# batch_shape x num_fantasies x (m) x (1 + num_trace_observations)
mean_m = self.weight * posterior_m.mean.squeeze(-1)
# batch_shape x num_fantasies x (m) x (1 + num_trace_observations)
@@ -489,7 +513,7 @@ class qLowerBoundMaxValueEntropy(DiscreteMaxValueBase):
the mutual information between max values and a batch of candidate points `X`.
See [Moss2021gibbon]_ for a detailed discussion.
- The model must be single-outcome.
+ The model must be single-outcome, unless using a PosteriorTransform.
q > 1 is supported through greedy batch filling.
Example:
@@ -527,7 +551,9 @@ def _compute_information_gain(
# doing posterior computations twice

# compute the mean_m, variance_m with noisy observation
- posterior_m = self.model.posterior(X, observation_noise=True)
+ posterior_m = self.model.posterior(
+     X, observation_noise=True, posterior_transform=self.posterior_transform
+ )
mean_m = self.weight * posterior_m.mean.squeeze(-1)
# batch_shape x 1
variance_m = posterior_m.variance.clamp_min(CLAMP_LB).squeeze(-1)
@@ -584,17 +610,29 @@ def _compute_information_gain(
# it provides only a translation of the acquisition function surface
# and can thus be ignored.

+ if self.posterior_transform is not None:
+     raise UnsupportedError(
+         "qLowerBoundMaxValueEntropy does not support PosteriorTransforms "
+         "when X_pending is not None."
+     )

X_batches = torch.cat(
[X, self.X_pending.unsqueeze(0).repeat(X.shape[0], 1, 1)], 1
)
# batch_shape x (1 + m) x d
+ # NOTE: This is the blocker for supporting posterior transforms.
+ # We would have to process this MVN, applying whatever operations
+ # are typically applied for the corresponding posterior, then applying
+ # the posterior transform onto the resulting object.
V = self.model(X_batches)
# Evaluate terms required for A
A = V.lazy_covariance_matrix[:, 0, 1:].unsqueeze(1)
# batch_shape x 1 x m
# Evaluate terms required for B
B = self.model.posterior(
-     self.X_pending, observation_noise=True
+     self.X_pending,
+     observation_noise=True,
+     posterior_transform=self.posterior_transform,
).mvn.covariance_matrix.unsqueeze(0)
# 1 x m x m

@@ -616,8 +654,8 @@ class qMultiFidelityMaxValueEntropy(qMaxValueEntropy):
for a detailed discussion of the basic ideas on multi-fidelity MES
(note that this implementation is somewhat different).
- The model must be single-outcome. The batch case `q > 1` is supported
- through cyclic optimization and fantasies.
+ The model must be single-outcome, unless using a PosteriorTransform.
+ The batch case `q > 1` is supported through cyclic optimization and fantasies.
Example:
>>> model = SingleTaskGP(train_X, train_Y)
@@ -634,6 +672,7 @@ def __init__(
num_fantasies: int = 16,
num_mv_samples: int = 10,
num_y_samples: int = 128,
+ posterior_transform: Optional[PosteriorTransform] = None,
use_gumbel: bool = True,
maximize: bool = True,
X_pending: Optional[Tensor] = None,
@@ -657,6 +696,9 @@
is not `None`.
num_mv_samples: Number of max value samples.
num_y_samples: Number of posterior samples at specific design point `X`.
+ posterior_transform: A PosteriorTransform. If using a multi-output model,
+     a PosteriorTransform that transforms the multi-output posterior into a
+     single-output posterior is required.
use_gumbel: If True, use Gumbel approximation to sample the max values.
maximize: If True, consider the problem a maximization problem.
X_pending: A `m x d`-dim Tensor of `m` design points that have been
@@ -678,9 +720,10 @@ def __init__(
num_fantasies=num_fantasies,
num_mv_samples=num_mv_samples,
num_y_samples=num_y_samples,
- X_pending=X_pending,
+ posterior_transform=posterior_transform,
use_gumbel=use_gumbel,
maximize=maximize,
+ X_pending=X_pending,
)

if cost_aware_utility is None:
@@ -731,7 +774,9 @@ def forward(self, X: Tensor) -> Tensor:

# Compute the posterior, posterior mean, variance without noise
# `_m` and `_M` in the var names means the current and the max fidelity.
- posterior = self.model.posterior(X_all, observation_noise=False)
+ posterior = self.model.posterior(
+     X_all, observation_noise=False, posterior_transform=self.posterior_transform
+ )
mean_M = self.weight * posterior.mean[..., -1, 0] # batch_shape x num_fantasies
variance_M = posterior.variance[..., -1, 0].clamp_min(CLAMP_LB)
# get the covariance between the low fidelities and max fidelity
@@ -751,7 +796,11 @@ def forward(self, X: Tensor) -> Tensor:


def _sample_max_value_Thompson(
- model: Model, candidate_set: Tensor, num_samples: int, maximize: bool = True
+ model: Model,
+ candidate_set: Tensor,
+ num_samples: int,
+ posterior_transform: Optional[PosteriorTransform] = None,
+ maximize: bool = True,
) -> Tensor:
"""Samples the max values by discrete Thompson sampling.
@@ -762,12 +811,15 @@ def _sample_max_value_Thompson(
candidate_set: A `n x d` Tensor including `n` candidate points to
discretize the design space.
num_samples: Number of max value samples.
+ posterior_transform: A PosteriorTransform. If using a multi-output model,
+     a PosteriorTransform that transforms the multi-output posterior into a
+     single-output posterior is required.
maximize: If True, consider the problem a maximization problem.
Returns:
A `num_samples x num_fantasies` Tensor of posterior max value samples.
"""
- posterior = model.posterior(candidate_set)
+ posterior = model.posterior(candidate_set, posterior_transform=posterior_transform)
weight = 1.0 if maximize else -1.0
samples = weight * posterior.rsample(torch.Size([num_samples])).squeeze(-1)
# samples is num_samples x (num_fantasies) x n
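To make the sampling step concrete, a self-contained sketch of discrete Thompson sampling of the max under the same assumptions (illustrative model and transform; shapes in comments):

```python
import torch
from botorch.models import SingleTaskGP
from botorch.acquisition.objective import ScalarizedPosteriorTransform

model = SingleTaskGP(torch.rand(20, 2), torch.rand(20, 2))
pt = ScalarizedPosteriorTransform(weights=torch.ones(2))
candidate_set = torch.rand(100, 2)

# Joint (correlated) posterior over the candidate set, scalarized to one output.
posterior = model.posterior(candidate_set, posterior_transform=pt)
samples = posterior.rsample(torch.Size([10])).squeeze(-1)  # 10 x 100 joint draws
max_values = samples.max(dim=-1).values  # one sampled maximum per draw
```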
@@ -779,7 +831,11 @@ def _sample_max_value_Thompson(


def _sample_max_value_Gumbel(
- model: Model, candidate_set: Tensor, num_samples: int, maximize: bool = True
+ model: Model,
+ candidate_set: Tensor,
+ num_samples: int,
+ posterior_transform: Optional[PosteriorTransform] = None,
+ maximize: bool = True,
) -> Tensor:
"""Samples the max values by Gumbel approximation.
Expand All @@ -790,13 +846,16 @@ def _sample_max_value_Gumbel(
candidate_set: A `n x d` Tensor including `n` candidate points to
discretize the design space.
num_samples: Number of max value samples.
+ posterior_transform: A PosteriorTransform. If using a multi-output model,
+     a PosteriorTransform that transforms the multi-output posterior into a
+     single-output posterior is required.
maximize: If True, consider the problem a maximization problem.
Returns:
A `num_samples x num_fantasies` Tensor of posterior max value samples.
"""
# define the approximate CDF for the max value under the independence assumption
- posterior = model.posterior(candidate_set)
+ posterior = model.posterior(candidate_set, posterior_transform=posterior_transform)
weight = 1.0 if maximize else -1.0
mu = weight * posterior.mean
sigma = posterior.variance.clamp_min(1e-8).sqrt()
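For intuition about the Gumbel path, a self-contained sketch of the approximation under the independence assumption. This illustrates the idea (fit a Gumbel distribution by matching quantiles of the approximate max CDF); BoTorch's exact quantile-matching procedure may differ.

```python
import math

import torch
from torch.distributions import Gumbel, Normal


def gumbel_max_value_samples(mu, sigma, num_samples):
    """Approximate samples of max_i f(x_i), treating the marginals as independent.

    Under independence, P(max <= y) = prod_i Phi((y - mu_i) / sigma_i); fit a
    Gumbel distribution by matching this CDF at two quantiles, then sample it.
    """
    normal = Normal(mu, sigma)
    y = torch.linspace(
        (mu - 5 * sigma).min().item(), (mu + 5 * sigma).max().item(), 4096,
        dtype=mu.dtype,
    )
    # CDF of the max under independence (sum of log-CDFs for stability)
    cdf = normal.cdf(y.unsqueeze(-1)).clamp_min(1e-12).log().sum(dim=-1).exp()
    y25 = y[torch.searchsorted(cdf, torch.tensor(0.25, dtype=cdf.dtype))]
    y75 = y[torch.searchsorted(cdf, torch.tensor(0.75, dtype=cdf.dtype))]
    # Gumbel quantile function: y_p = loc - scale * log(-log(p))
    c25, c75 = math.log(-math.log(0.25)), math.log(-math.log(0.75))
    scale = (y75 - y25) / (c25 - c75)
    loc = y25 + scale * c25
    return Gumbel(loc, scale).sample(torch.Size([num_samples]))
```

Here `mu` and `sigma` would come from the (possibly transformed) posterior over the candidate set, e.g. `posterior.mean.squeeze(-1)` and `posterior.variance.clamp_min(1e-8).sqrt().squeeze(-1)`.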
botorch/acquisition/multi_objective/max_value_entropy_search.py
@@ -113,6 +113,8 @@ def __init__(
self._sample_max_values()
else:
self.set_X_pending(X_pending)
+ # This avoids attribute errors in qMaxValueEntropy code.
+ self.posterior_transform = None

def set_X_pending(self, X_pending: Optional[Tensor] = None) -> None:
r"""Set pending points.