
Commit 7c28a20

SebastianAment authored and facebook-github-bot committed
ESS: Allowing diagonal covariance root with fixed indices (#2283)
Summary:
Pull Request resolved: #2283

This commit adds support for a diagonal covariance root in conjunction with fixed indices for ESS. This is not supported in general, since an arbitrary root would have to be re-factorized once the fixed dimensions are removed; the diagonal case admits an efficient implementation without re-factorization.

Reviewed By: Balandat

Differential Revision: D55808235

fbshipit-source-id: d9403ceede26d24340964e9b5a06586c64144506
1 parent 0d2f926 commit 7c28a20
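In practice, the new capability looks roughly as follows. This is a minimal usage sketch mirroring the new test case further below, not part of the commit; the constraint system (encoding sum(x) <= 1) and the interior point are hypothetical illustration values.

import torch
from botorch.utils.probability.lin_ess import LinearEllipticalSliceSampler
from linear_operator.operators import DiagLinearOperator

d = 3
tkwargs = {"dtype": torch.double}
# Hypothetical feasible system A @ x <= b, here encoding sum(x) <= 1.
A = torch.ones(1, d, **tkwargs)
b = torch.ones(1, 1, **tkwargs)
interior_point = torch.zeros(d, 1, **tkwargs)  # strictly feasible

# A diagonal covariance root can now be combined with fixed_indices;
# a general (dense, lower-triangular) root still raises a ValueError.
sampler = LinearEllipticalSliceSampler(
    inequality_constraints=(A, b),
    interior_point=interior_point,
    fixed_indices=[0],
    covariance_root=DiagLinearOperator(torch.full((d,), 100.0, **tkwargs)),
)
X = sampler.draw(n=16)  # X[:, 0] stays at interior_point[0] in every sample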

File tree: 2 files changed, +55 -19 lines

botorch/utils/probability/lin_ess.py (+38 -18)
@@ -37,6 +37,7 @@
 
 import torch
 from botorch.utils.sampling import PolytopeSampler
+from linear_operator.operators import DiagLinearOperator, LinearOperator
 from torch import Tensor
 
 _twopi = 2.0 * math.pi
@@ -58,8 +59,8 @@ def __init__(
         interior_point: Optional[Tensor] = None,
         fixed_indices: Optional[Union[List[int], Tensor]] = None,
         mean: Optional[Tensor] = None,
-        covariance_matrix: Optional[Tensor] = None,
-        covariance_root: Optional[Tensor] = None,
+        covariance_matrix: Optional[Union[Tensor, LinearOperator]] = None,
+        covariance_root: Optional[Union[Tensor, LinearOperator]] = None,
         check_feasibility: bool = False,
         burnin: int = 0,
         thinning: int = 0,
@@ -88,7 +89,10 @@ def __init__(
                 distribution (if omitted, use the identity).
             covariance_root: A `d x d`-dim root of the covariance matrix such that
                 covariance_root @ covariance_root.T = covariance_matrix. NOTE: This
-                matrix is assumed to be lower triangular.
+                matrix is assumed to be lower triangular. covariance_root can only be
+                passed in conjunction with fixed_indices if covariance_root is a
+                DiagLinearOperator. Otherwise the factorization would need to be re-
+                computed, as we need to solve in `standardize`.
             check_feasibility: If True, raise an error if the sampling results in an
                 infeasible sample. This creates some overhead and so is switched off
                 by default.
@@ -123,14 +127,16 @@ def __init__(
         self._Az, self._bz = A, b
         self._is_fixed, self._not_fixed = None, None
         if fixed_indices is not None:
-            mean, covariance_matrix = self._fixed_features_initialization(
-                A=A,
-                b=b,
-                interior_point=interior_point,
-                fixed_indices=fixed_indices,
-                mean=mean,
-                covariance_matrix=covariance_matrix,
-                covariance_root=covariance_root,
+            mean, covariance_matrix, covariance_root = (
+                self._fixed_features_initialization(
+                    A=A,
+                    b=b,
+                    interior_point=interior_point,
+                    fixed_indices=fixed_indices,
+                    mean=mean,
+                    covariance_matrix=covariance_matrix,
+                    covariance_root=covariance_root,
+                )
             )
 
         self._mean = mean
@@ -176,6 +182,9 @@ def _fixed_features_initialization(
         """Modifies the constraint system (A, b) due to fixed indices and assigns
         the modified constraints system to `self._Az`, `self._bz`. NOTE: Needs to be
         called prior to `self._standardization_initialization` in the constructor.
+        covariance_root and fixed_indices can both not be None only if covariance_root
+        is a DiagLinearOperator. Otherwise, the covariance matrix would need to be
+        refactorized.
 
         Returns:
             Tuple of `mean` and `covariance_matrix` tensors of the non-fixed dimensions.
@@ -185,10 +194,16 @@ def _fixed_features_initialization(
                 "If `fixed_indices` are provided, an interior point must also be "
                 "provided in order to infer feasible values of the fixed features."
             )
-        if covariance_root is not None:
-            raise ValueError(
-                "Provide either covariance_root or fixed_indices, not both."
-            )
+
+        root_is_diag = isinstance(covariance_root, DiagLinearOperator)
+        if covariance_root is not None and not root_is_diag:
+            root_is_diag = (covariance_root.diag().diag() == covariance_root).all()
+            if root_is_diag:  # convert the diagonal root to a DiagLinearOperator
+                covariance_root = DiagLinearOperator(covariance_root.diagonal())
+            else:  # otherwise, fail
+                raise ValueError(
+                    "Provide either covariance_root or fixed_indices, not both."
+                )
         d = interior_point.shape[0]
         is_fixed, not_fixed = get_index_tensors(fixed_indices=fixed_indices, d=d)
         self._is_fixed = is_fixed
@@ -205,7 +220,10 @@ def _fixed_features_initialization(
             covariance_matrix = covariance_matrix[
                 not_fixed.unsqueeze(-1), not_fixed.unsqueeze(0)
             ]
-        return mean, covariance_matrix
+        if root_is_diag:  # in the special case of diagonal root, can subselect
+            covariance_root = DiagLinearOperator(covariance_root.diagonal()[not_fixed])
+
+        return mean, covariance_matrix, covariance_root
 
     def _standardization_initialization(self) -> None:
         """For non-standard mean and covariance, we're going to rewrite the problem as
@@ -482,8 +500,10 @@ def _standardize(self, x: Tensor) -> Tensor:
         z = x
         if self._mean is not None:
             z = z - self._mean
-        if self._covariance_root is not None:
-            z = torch.linalg.solve_triangular(self._covariance_root, z, upper=False)
+        root = self._covariance_root
+        if root is not None:
+            z = torch.linalg.solve_triangular(root, z, upper=False)
+
         return z
 
     def _unstandardize(self, z: Tensor) -> Tensor:
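Why the diagonal case avoids re-factorization: if L is a dense lower-triangular root of the covariance, the sub-block of L on the non-fixed indices is not a root of the correspondingly subselected covariance (the dropped columns still contribute cross terms), so a fresh Cholesky factorization would be required. For a diagonal root, subselecting the diagonal is exact. A small illustrative check, not part of the commit:

import torch

torch.manual_seed(0)
d = 4
not_fixed = torch.tensor([1, 2, 3])  # suppose index 0 is fixed

# Dense case: the sub-block of a Cholesky root is NOT a root of the
# subselected covariance matrix.
S = torch.randn(d, d, dtype=torch.double)
cov = S @ S.T + d * torch.eye(d, dtype=torch.double)
L = torch.linalg.cholesky(cov)
L_sub = L[not_fixed.unsqueeze(-1), not_fixed.unsqueeze(0)]
cov_sub = cov[not_fixed.unsqueeze(-1), not_fixed.unsqueeze(0)]
print(torch.allclose(L_sub @ L_sub.T, cov_sub))  # False -> would need re-factorization

# Diagonal case: subselecting the diagonal root is exact.
diag = torch.rand(d, dtype=torch.double) + 1.0
cov_diag_sub = (diag**2).diag()[not_fixed.unsqueeze(-1), not_fixed.unsqueeze(0)]
D_sub = diag[not_fixed].diag()
print(torch.allclose(D_sub @ D_sub.T, cov_diag_sub))  # True -> subselection suffices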

test/utils/probability/test_lin_ess.py (+17 -1)
@@ -17,6 +17,7 @@
 from botorch.utils.constraints import get_monotonicity_constraints
 from botorch.utils.probability.lin_ess import LinearEllipticalSliceSampler
 from botorch.utils.testing import BotorchTestCase
+from linear_operator.operators import DiagLinearOperator
 from torch import Tensor
 
 
@@ -428,9 +429,24 @@ def test_multivariate(self):
                 inequality_constraints=(A, b),
                 interior_point=interior_point,
                 fixed_indices=[0],
-                covariance_root=torch.eye(d, **tkwargs),
+                covariance_root=torch.full((d, d), 100, **tkwargs),
             )
 
+        # providing a diagonal covariance_root should work with fixed indices
+        diag_root = torch.full((d,), 100, **tkwargs)
+        for covariance_root in [DiagLinearOperator(diag_root), diag_root.diag()]:
+            torch.manual_seed(1234)
+            sampler = LinearEllipticalSliceSampler(
+                inequality_constraints=(A, b),
+                interior_point=interior_point,
+                fixed_indices=[0],
+                covariance_root=covariance_root,
+            )
+            num_samples = 16
+            X_fixed = sampler.draw(n=num_samples)
+            self.assertTrue((X_fixed[:, 0] == interior_point[0]).all())
+            self.assertGreater(X_fixed.std().item(), 10.0)  # false if sigma = 1
+
         # high dimensional test case
         # Encodes order constraints on all d variables: Ax < b <-> x[i] < x[i + 1]
         d = 128
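On the "we need to solve in `standardize`" remark in the docstring: the whitening step solves against the covariance root, and for a diagonal root that solve degenerates to elementwise division by the diagonal, which is why the implementation stays cheap without any re-factorization. An illustrative check, not part of the commit:

import torch

diag = torch.tensor([2.0, 4.0], dtype=torch.double)
x = torch.tensor([[1.0], [2.0]], dtype=torch.double)
# A triangular solve against a diagonal matrix is just division by the diagonal.
z = torch.linalg.solve_triangular(diag.diag(), x, upper=False)
print(torch.allclose(z, x / diag.unsqueeze(-1)))  # True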
