37
37
38
38
import torch
39
39
from botorch .utils .sampling import PolytopeSampler
40
+ from linear_operator .operators import DiagLinearOperator , LinearOperator
40
41
from torch import Tensor
41
42
42
43
_twopi = 2.0 * math .pi
@@ -58,8 +59,8 @@ def __init__(
58
59
interior_point : Optional [Tensor ] = None ,
59
60
fixed_indices : Optional [Union [List [int ], Tensor ]] = None ,
60
61
mean : Optional [Tensor ] = None ,
61
- covariance_matrix : Optional [Tensor ] = None ,
62
- covariance_root : Optional [Tensor ] = None ,
62
+ covariance_matrix : Optional [Union [ Tensor , LinearOperator ] ] = None ,
63
+ covariance_root : Optional [Union [ Tensor , LinearOperator ] ] = None ,
63
64
check_feasibility : bool = False ,
64
65
burnin : int = 0 ,
65
66
thinning : int = 0 ,
@@ -88,7 +89,10 @@ def __init__(
88
89
distribution (if omitted, use the identity).
89
90
covariance_root: A `d x d`-dim root of the covariance matrix such that
90
91
covariance_root @ covariance_root.T = covariance_matrix. NOTE: This
91
- matrix is assumed to be lower triangular.
92
+ matrix is assumed to be lower triangular. covariance_root can only be
93
+ passed in conjunction with fixed_indices if covariance_root is a
94
+ DiagLinearOperator. Otherwise the factorization would need to be re-
95
+ computed, as we need to solve in `standardize`.
92
96
check_feasibility: If True, raise an error if the sampling results in an
93
97
infeasible sample. This creates some overhead and so is switched off
94
98
by default.
@@ -123,14 +127,16 @@ def __init__(
123
127
self ._Az , self ._bz = A , b
124
128
self ._is_fixed , self ._not_fixed = None , None
125
129
if fixed_indices is not None :
126
- mean , covariance_matrix = self ._fixed_features_initialization (
127
- A = A ,
128
- b = b ,
129
- interior_point = interior_point ,
130
- fixed_indices = fixed_indices ,
131
- mean = mean ,
132
- covariance_matrix = covariance_matrix ,
133
- covariance_root = covariance_root ,
130
+ mean , covariance_matrix , covariance_root = (
131
+ self ._fixed_features_initialization (
132
+ A = A ,
133
+ b = b ,
134
+ interior_point = interior_point ,
135
+ fixed_indices = fixed_indices ,
136
+ mean = mean ,
137
+ covariance_matrix = covariance_matrix ,
138
+ covariance_root = covariance_root ,
139
+ )
134
140
)
135
141
136
142
self ._mean = mean
@@ -176,6 +182,9 @@ def _fixed_features_initialization(
176
182
"""Modifies the constraint system (A, b) due to fixed indices and assigns
177
183
the modified constraints system to `self._Az`, `self._bz`. NOTE: Needs to be
178
184
called prior to `self._standardization_initialization` in the constructor.
185
+ covariance_root and fixed_indices may only both be provided if covariance_root
186
+ is a DiagLinearOperator. Otherwise, the covariance matrix would need to be
187
+ refactorized.
179
188
180
189
Returns:
181
190
Tuple of `mean`, `covariance_matrix`, and `covariance_root` of the non-fixed dimensions.
@@ -185,10 +194,16 @@ def _fixed_features_initialization(
185
194
"If `fixed_indices` are provided, an interior point must also be "
186
195
"provided in order to infer feasible values of the fixed features."
187
196
)
188
- if covariance_root is not None :
189
- raise ValueError (
190
- "Provide either covariance_root or fixed_indices, not both."
191
- )
197
+
198
+ root_is_diag = isinstance (covariance_root , DiagLinearOperator )
199
+ if covariance_root is not None and not root_is_diag :
200
+ root_is_diag = (covariance_root .diag ().diag () == covariance_root ).all ()
201
+ if root_is_diag : # convert the diagonal root to a DiagLinearOperator
202
+ covariance_root = DiagLinearOperator (covariance_root .diagonal ())
203
+ else : # otherwise, fail
204
+ raise ValueError (
205
+ "Provide either covariance_root or fixed_indices, not both."
206
+ )
192
207
d = interior_point .shape [0 ]
193
208
is_fixed , not_fixed = get_index_tensors (fixed_indices = fixed_indices , d = d )
194
209
self ._is_fixed = is_fixed
@@ -205,7 +220,10 @@ def _fixed_features_initialization(
205
220
covariance_matrix = covariance_matrix [
206
221
not_fixed .unsqueeze (- 1 ), not_fixed .unsqueeze (0 )
207
222
]
208
- return mean , covariance_matrix
223
+ if root_is_diag : # in the special case of diagonal root, can subselect
224
+ covariance_root = DiagLinearOperator (covariance_root .diagonal ()[not_fixed ])
225
+
226
+ return mean , covariance_matrix , covariance_root
209
227
210
228
def _standardization_initialization (self ) -> None :
211
229
"""For non-standard mean and covariance, we're going to rewrite the problem as
@@ -482,8 +500,10 @@ def _standardize(self, x: Tensor) -> Tensor:
482
500
z = x
483
501
if self ._mean is not None :
484
502
z = z - self ._mean
485
- if self ._covariance_root is not None :
486
- z = torch .linalg .solve_triangular (self ._covariance_root , z , upper = False )
503
+ root = self ._covariance_root
504
+ if root is not None :
505
+ z = torch .linalg .solve_triangular (root , z , upper = False )
506
+
487
507
return z
488
508
489
509
def _unstandardize (self , z : Tensor ) -> Tensor :
0 commit comments