@@ -999,10 +999,12 @@ def sparse_to_dense_constraints(
 def optimize_posterior_samples(
     paths: GenericDeterministicModel,
     bounds: Tensor,
-    raw_samples: int = 1024,
-    num_restarts: int = 20,
+    raw_samples: int = 2048,
+    num_restarts: int = 4,
     sample_transform: Callable[[Tensor], Tensor] | None = None,
     return_transformed: bool = False,
+    suggested_points: Tensor | None = None,
+    options: dict | None = None,
 ) -> tuple[Tensor, Tensor]:
     r"""Cheaply maximizes posterior samples by random querying followed by
     gradient-based optimization using SciPy's L-BFGS-B routine.
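For context, here is a minimal usage sketch of the updated signature. Everything outside the diff is an assumption for illustration: the model setup is hypothetical, and the sample paths are drawn via `get_matheron_path_model`, which returns the `GenericDeterministicModel` that `paths` expects.

```python
# Hypothetical usage sketch; `model`, `train_X`, and `train_Y` are
# illustrative names, not part of this diff.
import torch
from botorch.models import SingleTaskGP
from botorch.sampling.pathwise import get_matheron_path_model
from botorch.utils.sampling import optimize_posterior_samples

train_X = torch.rand(20, 2, dtype=torch.float64)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)

# A path model holding num_optima posterior sample paths.
paths = get_matheron_path_model(model=model, sample_shape=torch.Size([8]))
bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.float64)

# The two new arguments: bias part of the raw candidate set toward known
# high-valued points, and tune the initialization via `options`.
X_opt, f_opt = optimize_posterior_samples(
    paths=paths,
    bounds=bounds,
    suggested_points=train_X[train_Y.squeeze(-1).topk(3).indices],
    options={"frac_random": 0.9, "eta": 5.0},
)
```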
@@ -1011,19 +1013,27 @@ def optimize_posterior_samples(
         paths: Random Fourier Feature-based sample paths from the GP
         bounds: The bounds on the search space.
         raw_samples: The number of samples with which to query the samples initially.
+            Raw samples are cheap to evaluate, so this should ideally be set much
+            higher than num_restarts.
         num_restarts: The number of points selected for gradient-based optimization.
+            Should be set low relative to the number of raw samples.
         sample_transform: A callable transform of the sample outputs (e.g.
             MCAcquisitionObjective or ScalarizedPosteriorTransform.evaluate) used to
             negate the objective or otherwise transform the output.
         return_transformed: A boolean indicating whether to return the transformed
             or non-transformed samples.
+        suggested_points: Tensor of suggested input locations that are high-valued.
+            These are more densely evaluated during the sampling phase of optimization.
+        options: Options for generation of initial candidates, passed to
+            gen_batch_initial_conditions.

     Returns:
         A two-element tuple containing:
             - X_opt: A `num_optima x [batch_size] x d`-dim tensor of optimal inputs x*.
             - f_opt: A `num_optima x [batch_size] x m`-dim, optionally
                 `num_optima x [batch_size] x 1`-dim, tensor of optimal outputs f*.
     """
+    options = {} if options is None else options

     def path_func(x) -> Tensor:
         res = paths(x)
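The `options` dict consumed by this function is small. Collecting the keys read via `options.get(...)` further down in this diff, with their in-code defaults:

```python
# Keys read from `options` in this function, with the defaults used by
# the options.get(...) calls below.
options = {
    "frac_random": 0.9,                # fraction of candidates drawn via Sobol
                                       # (forced to 1 when suggested_points is None)
    "sample_around_best_sigma": 1e-2,  # stddev of perturbations around suggested_points
    "eta": 5.0,                        # temperature for Boltzmann restart selection
}
```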
@@ -1032,21 +1042,35 @@ def path_func(x) -> Tensor:

         return res.squeeze(-1)

-    candidate_set = unnormalize(
-        SobolEngine(dimension=bounds.shape[1], scramble=True).draw(n=raw_samples),
-        bounds=bounds,
-    )
     # queries all samples on all candidates - output shape
     # raw_samples * num_optima * num_models
+    frac_random = 1 if suggested_points is None else options.get("frac_random", 0.9)
+    candidate_set = draw_sobol_samples(
+        bounds=bounds, n=round(raw_samples * frac_random), q=1
+    ).squeeze(-2)
+    if frac_random < 1:
+        perturbed_suggestions = sample_truncated_normal_perturbations(
+            X=suggested_points,
+            n_discrete_points=round(raw_samples * (1 - frac_random)),
+            sigma=options.get("sample_around_best_sigma", 1e-2),
+            bounds=bounds,
+        )
+        candidate_set = torch.cat((candidate_set, perturbed_suggestions))
+
     candidate_queries = path_func(candidate_set)
-    argtop_k = torch.topk(candidate_queries, num_restarts, dim=-1).indices
-    X_top_k = candidate_set[argtop_k, :]
+    idx = boltzmann_sample(
+        function_values=candidate_queries.unsqueeze(-1),
+        num_samples=num_restarts,
+        eta=options.get("eta", 5.0),
+        replacement=False,
+    )
+    ics = candidate_set[idx, :]

     # to avoid circular import, the import occurs here
     from botorch.generation.gen import gen_candidates_scipy

     X_top_k, f_top_k = gen_candidates_scipy(
-        X_top_k,
+        ics,
         path_func,
         lower_bounds=bounds[0],
         upper_bounds=bounds[1],
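The behavioral change in this hunk is twofold: the candidate set can now mix Sobol points with truncated-normal perturbations of `suggested_points`, and restart selection switches from a deterministic top-k to Boltzmann (softmax-weighted) sampling without replacement, which keeps some diversity among the initial conditions. A standalone sketch of the selection idea on toy values (simplified: it standardizes by hand and omits `boltzmann_sample`'s overflow-driven temperature decay):

```python
import torch

torch.manual_seed(0)
function_values = torch.randn(1024)  # toy evaluations of the sample path

# Old behavior: deterministic top-k; restarts may cluster in one basin.
top_k = torch.topk(function_values, k=4).indices

# New behavior, sketched: exponentiate standardized values with temperature
# eta and sample without replacement, favoring (not forcing) high values.
eta = 5.0
norm_weights = (function_values - function_values.mean()) / function_values.std()
weights = torch.exp(eta * norm_weights)
restarts = torch.multinomial(weights, num_samples=4, replacement=False)
```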
@@ -1101,8 +1125,9 @@ def boltzmann_sample(
         eta *= temp_decrease
         weights = torch.exp(eta * norm_weights)

+    # squeeze in case of m = 1 (single-output values provided as batch_size x N x 1)
     return batched_multinomial(
-        weights=weights, num_samples=num_samples, replacement=replacement
+        weights=weights.squeeze(-1), num_samples=num_samples, replacement=replacement
     )
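Why the squeeze matters, in shapes (assuming `batched_multinomial` treats the last dimension of `weights` as the categories to sample over, consistent with `torch.multinomial`):

```python
import torch

# Single-output function values often arrive as batch_size x N x 1; sampling
# over the trailing singleton dim would choose among a single category.
weights = torch.rand(3, 1024, 1)

# After the squeeze, the N candidates are the categories, as intended.
assert weights.squeeze(-1).shape == torch.Size([3, 1024])
```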