@@ -973,7 +973,11 @@ <h1>Source code for botorch.optim.optimize</h1><div class="highlight"><pre>
973
973
< span class ="n "> nonlinear_inequality_constraints</ span > < span class ="p "> :</ span > < span class ="nb "> list</ span > < span class ="p "> [</ span > < span class ="nb "> tuple</ span > < span class ="p "> [</ span > < span class ="n "> Callable</ span > < span class ="p "> ,</ span > < span class ="nb "> bool</ span > < span class ="p "> ]]</ span > < span class ="o "> |</ span > < span class ="kc "> None</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span > < span class ="p "> ,</ span >
974
974
< span class ="n "> post_processing_func</ span > < span class ="p "> :</ span > < span class ="n "> Callable</ span > < span class ="p "> [[</ span > < span class ="n "> Tensor</ span > < span class ="p "> ],</ span > < span class ="n "> Tensor</ span > < span class ="p "> ]</ span > < span class ="o "> |</ span > < span class ="kc "> None</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span > < span class ="p "> ,</ span >
975
975
< span class ="n "> batch_initial_conditions</ span > < span class ="p "> :</ span > < span class ="n "> Tensor</ span > < span class ="o "> |</ span > < span class ="kc "> None</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span > < span class ="p "> ,</ span >
976
+ < span class ="n "> return_best_only</ span > < span class ="p "> :</ span > < span class ="nb "> bool</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> ,</ span >
977
+ < span class ="n "> gen_candidates</ span > < span class ="p "> :</ span > < span class ="n "> TGenCandidates</ span > < span class ="o "> |</ span > < span class ="kc "> None</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span > < span class ="p "> ,</ span >
976
978
< span class ="n "> ic_generator</ span > < span class ="p "> :</ span > < span class ="n "> TGenInitialConditions</ span > < span class ="o "> |</ span > < span class ="kc "> None</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span > < span class ="p "> ,</ span >
979
+ < span class ="n "> timeout_sec</ span > < span class ="p "> :</ span > < span class ="nb "> float</ span > < span class ="o "> |</ span > < span class ="kc "> None</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span > < span class ="p "> ,</ span >
980
+ < span class ="n "> retry_on_optimization_warning</ span > < span class ="p "> :</ span > < span class ="nb "> bool</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> ,</ span >
977
981
< span class ="n "> ic_gen_kwargs</ span > < span class ="p "> :</ span > < span class ="nb "> dict</ span > < span class ="o "> |</ span > < span class ="kc "> None</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span > < span class ="p "> ,</ span >
978
982
< span class ="p "> )</ span > < span class ="o "> -></ span > < span class ="nb "> tuple</ span > < span class ="p "> [</ span > < span class ="n "> Tensor</ span > < span class ="p "> ,</ span > < span class ="n "> Tensor</ span > < span class ="p "> ]:</ span >
979
983
< span class ="w "> </ span > < span class ="sa "> r</ span > < span class ="sd "> """Optimize over a list of fixed_features and returns the best solution.</ span >
@@ -1022,20 +1026,38 @@ <h1>Source code for botorch.optim.optimize</h1><div class="highlight"><pre>
1022
1026
< span class ="sd "> transformations).</ span >
1023
1027
< span class ="sd "> batch_initial_conditions: A tensor to specify the initial conditions. Set</ span >
1024
1028
< span class ="sd "> this if you do not want to use default initialization strategy.</ span >
1029
+ < span class ="sd "> return_best_only: If False, outputs the solutions corresponding to all</ span >
1030
+ < span class ="sd "> random restart initializations of the optimization. Setting this keyword</ span >
1031
+ < span class ="sd "> to False is only allowed for `q=1`. Defaults to True.</ span >
1032
+ < span class ="sd "> gen_candidates: A callable for generating candidates (and their associated</ span >
1033
+ < span class ="sd "> acquisition values) given a tensor of initial conditions and an</ span >
1034
+ < span class ="sd "> acquisition function. Other common inputs include lower and upper bounds</ span >
1035
+ < span class ="sd "> and a dictionary of options, but refer to the documentation of specific</ span >
1036
+ < span class ="sd "> generation functions (e.g gen_candidates_scipy and gen_candidates_torch)</ span >
1037
+ < span class ="sd "> for method-specific inputs. Default: `gen_candidates_scipy`</ span >
1025
1038
< span class ="sd "> ic_generator: Function for generating initial conditions. Not needed when</ span >
1026
1039
< span class ="sd "> `batch_initial_conditions` are provided. Defaults to</ span >
1027
1040
< span class ="sd "> `gen_one_shot_kg_initial_conditions` for `qKnowledgeGradient` acquisition</ span >
1028
1041
< span class ="sd "> functions and `gen_batch_initial_conditions` otherwise. Must be specified</ span >
1029
1042
< span class ="sd "> for nonlinear inequality constraints.</ span >
1043
+ < span class ="sd "> timeout_sec: Max amount of time optimization can run for.</ span >
1044
+ < span class ="sd "> retry_on_optimization_warning: Whether to retry candidate generation with a new</ span >
1045
+ < span class ="sd "> set of initial conditions when it fails with an `OptimizationWarning`.</ span >
1030
1046
< span class ="sd "> ic_gen_kwargs: Additional keyword arguments passed to function specified by</ span >
1031
1047
< span class ="sd "> `ic_generator`</ span >
1032
1048
1033
1049
< span class ="sd "> Returns:</ span >
1034
1050
< span class ="sd "> A two-element tuple containing</ span >
1035
1051
1036
- < span class ="sd "> - a `q x d`-dim tensor of generated candidates.</ span >
1037
- < span class ="sd "> - an associated acquisition value.</ span >
1052
+ < span class ="sd "> - A tensor of generated candidates. The shape is</ span >
1053
+ < span class ="sd "> -- `q x d` if `return_best_only` is True (default)</ span >
1054
+ < span class ="sd "> -- `num_restarts x q x d` if `return_best_only` is False</ span >
1055
+ < span class ="sd "> - a tensor of associated acquisition values of dim `num_restarts`</ span >
1056
+ < span class ="sd "> if `return_best_only=False` else a scalar acquisition value.</ span >
1038
1057
< span class ="sd "> """</ span >
1058
+ < span class ="k "> if</ span > < span class ="ow "> not</ span > < span class ="n "> return_best_only</ span > < span class ="ow "> and</ span > < span class ="n "> q</ span > < span class ="o "> ></ span > < span class ="mi "> 1</ span > < span class ="p "> :</ span >
1059
+ < span class ="k "> raise</ span > < span class ="ne "> NotImplementedError</ span > < span class ="p "> (</ span > < span class ="s2 "> "`return_best_only=False` is only supported for q=1."</ span > < span class ="p "> )</ span >
1060
+
1039
1061
< span class ="k "> if</ span > < span class ="ow "> not</ span > < span class ="n "> fixed_features_list</ span > < span class ="p "> :</ span >
1040
1062
< span class ="k "> raise</ span > < span class ="ne "> ValueError</ span > < span class ="p "> (</ span > < span class ="s2 "> "fixed_features_list must be non-empty."</ span > < span class ="p "> )</ span >
1041
1063
@@ -1050,11 +1072,12 @@ <h1>Source code for botorch.optim.optimize</h1><div class="highlight"><pre>
1050
1072
< span class ="n "> ic_gen_kwargs</ span > < span class ="o "> =</ span > < span class ="n "> ic_gen_kwargs</ span > < span class ="ow "> or</ span > < span class ="p "> {}</ span >
1051
1073
1052
1074
< span class ="k "> if</ span > < span class ="n "> q</ span > < span class ="o "> ==</ span > < span class ="mi "> 1</ span > < span class ="p "> :</ span >
1075
+ < span class ="n "> timeout_sec</ span > < span class ="o "> =</ span > < span class ="n "> timeout_sec</ span > < span class ="o "> /</ span > < span class ="nb "> len</ span > < span class ="p "> (</ span > < span class ="n "> fixed_features_list</ span > < span class ="p "> )</ span > < span class ="k "> if</ span > < span class ="n "> timeout_sec</ span > < span class ="k "> else</ span > < span class ="kc "> None</ span >
1053
1076
< span class ="n "> ff_candidate_list</ span > < span class ="p "> ,</ span > < span class ="n "> ff_acq_value_list</ span > < span class ="o "> =</ span > < span class ="p "> [],</ span > < span class ="p "> []</ span >
1054
1077
< span class ="n "> num_candidate_generation_failures</ span > < span class ="o "> =</ span > < span class ="mi "> 0</ span >
1055
1078
< span class ="k "> for</ span > < span class ="n "> fixed_features</ span > < span class ="ow "> in</ span > < span class ="n "> fixed_features_list</ span > < span class ="p "> :</ span >
1056
1079
< span class ="k "> try</ span > < span class ="p "> :</ span >
1057
- < span class ="n "> candidate </ span > < span class ="p "> ,</ span > < span class ="n "> acq_value </ span > < span class ="o "> =</ span > < span class ="n "> optimize_acqf</ span > < span class ="p "> (</ span >
1080
+ < span class ="n "> candidates </ span > < span class ="p "> ,</ span > < span class ="n "> acq_values </ span > < span class ="o "> =</ span > < span class ="n "> optimize_acqf</ span > < span class ="p "> (</ span >
1058
1081
< span class ="n "> acq_function</ span > < span class ="o "> =</ span > < span class ="n "> acq_function</ span > < span class ="p "> ,</ span >
1059
1082
< span class ="n "> bounds</ span > < span class ="o "> =</ span > < span class ="n "> bounds</ span > < span class ="p "> ,</ span >
1060
1083
< span class ="n "> q</ span > < span class ="o "> =</ span > < span class ="n "> q</ span > < span class ="p "> ,</ span >
@@ -1068,15 +1091,19 @@ <h1>Source code for botorch.optim.optimize</h1><div class="highlight"><pre>
1068
1091
< span class ="n "> post_processing_func</ span > < span class ="o "> =</ span > < span class ="n "> post_processing_func</ span > < span class ="p "> ,</ span >
1069
1092
< span class ="n "> batch_initial_conditions</ span > < span class ="o "> =</ span > < span class ="n "> batch_initial_conditions</ span > < span class ="p "> ,</ span >
1070
1093
< span class ="n "> ic_generator</ span > < span class ="o "> =</ span > < span class ="n "> ic_generator</ span > < span class ="p "> ,</ span >
1071
- < span class ="n "> return_best_only</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> ,</ span >
1094
+ < span class ="n "> return_best_only</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> ,</ span > < span class ="c1 "> # here we always return all candidates</ span >
1095
+ < span class ="c1 "> # and filter later</ span >
1096
+ < span class ="n "> gen_candidates</ span > < span class ="o "> =</ span > < span class ="n "> gen_candidates</ span > < span class ="p "> ,</ span >
1097
+ < span class ="n "> timeout_sec</ span > < span class ="o "> =</ span > < span class ="n "> timeout_sec</ span > < span class ="p "> ,</ span >
1098
+ < span class ="n "> retry_on_optimization_warning</ span > < span class ="o "> =</ span > < span class ="n "> retry_on_optimization_warning</ span > < span class ="p "> ,</ span >
1072
1099
< span class ="o "> **</ span > < span class ="n "> ic_gen_kwargs</ span > < span class ="p "> ,</ span >
1073
1100
< span class ="p "> )</ span >
1074
1101
< span class ="k "> except</ span > < span class ="n "> CandidateGenerationError</ span > < span class ="p "> :</ span >
1075
1102
< span class ="c1 "> # if candidate generation fails, we skip this candidate</ span >
1076
1103
< span class ="n "> num_candidate_generation_failures</ span > < span class ="o "> +=</ span > < span class ="mi "> 1</ span >
1077
1104
< span class ="k "> continue</ span >
1078
- < span class ="n "> ff_candidate_list</ span > < span class ="o "> .</ span > < span class ="n "> append</ span > < span class ="p "> (</ span > < span class ="n "> candidate </ span > < span class ="p "> )</ span >
1079
- < span class ="n "> ff_acq_value_list</ span > < span class ="o "> .</ span > < span class ="n "> append</ span > < span class ="p "> (</ span > < span class ="n "> acq_value </ span > < span class ="p "> )</ span >
1105
+ < span class ="n "> ff_candidate_list</ span > < span class ="o "> .</ span > < span class ="n "> append</ span > < span class ="p "> (</ span > < span class ="n "> candidates </ span > < span class ="p "> )</ span >
1106
+ < span class ="n "> ff_acq_value_list</ span > < span class ="o "> .</ span > < span class ="n "> append</ span > < span class ="p "> (</ span > < span class ="n "> acq_values </ span > < span class ="p "> )</ span >
1080
1107
1081
1108
< span class ="k "> if</ span > < span class ="nb "> len</ span > < span class ="p "> (</ span > < span class ="n "> ff_candidate_list</ span > < span class ="p "> )</ span > < span class ="o "> ==</ span > < span class ="mi "> 0</ span > < span class ="p "> :</ span >
1082
1109
< span class ="k "> raise</ span > < span class ="n "> CandidateGenerationError</ span > < span class ="p "> (</ span >
@@ -1091,16 +1118,25 @@ <h1>Source code for botorch.optim.optimize</h1><div class="highlight"><pre>
1091
1118
< span class ="n "> OptimizationWarning</ span > < span class ="p "> ,</ span >
1092
1119
< span class ="n "> stacklevel</ span > < span class ="o "> =</ span > < span class ="mi "> 3</ span > < span class ="p "> ,</ span >
1093
1120
< span class ="p "> )</ span >
1121
+
1094
1122
< span class ="n "> ff_acq_values</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> stack</ span > < span class ="p "> (</ span > < span class ="n "> ff_acq_value_list</ span > < span class ="p "> )</ span >
1095
- < span class ="n "> best</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> argmax</ span > < span class ="p "> (</ span > < span class ="n "> ff_acq_values</ span > < span class ="p "> )</ span >
1096
- < span class ="k "> return</ span > < span class ="n "> ff_candidate_list</ span > < span class ="p "> [</ span > < span class ="n "> best</ span > < span class ="p "> ],</ span > < span class ="n "> ff_acq_values</ span > < span class ="p "> [</ span > < span class ="n "> best</ span > < span class ="p "> ]</ span >
1123
+ < span class ="n "> max_res</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> max</ span > < span class ="p "> (</ span > < span class ="n "> ff_acq_values</ span > < span class ="p "> ,</ span > < span class ="n "> dim</ span > < span class ="o "> =-</ span > < span class ="mi "> 1</ span > < span class ="p "> )</ span >
1124
+ < span class ="n "> best_batch_idx</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> argmax</ span > < span class ="p "> (</ span > < span class ="n "> max_res</ span > < span class ="o "> .</ span > < span class ="n "> values</ span > < span class ="p "> )</ span >
1125
+ < span class ="n "> best_batch_candidates</ span > < span class ="o "> =</ span > < span class ="n "> ff_candidate_list</ span > < span class ="p "> [</ span > < span class ="n "> best_batch_idx</ span > < span class ="p "> ]</ span >
1126
+ < span class ="n "> best_acq_values</ span > < span class ="o "> =</ span > < span class ="n "> ff_acq_value_list</ span > < span class ="p "> [</ span > < span class ="n "> best_batch_idx</ span > < span class ="p "> ]</ span >
1127
+ < span class ="k "> if</ span > < span class ="ow "> not</ span > < span class ="n "> return_best_only</ span > < span class ="p "> :</ span >
1128
+ < span class ="k "> return</ span > < span class ="n "> best_batch_candidates</ span > < span class ="p "> ,</ span > < span class ="n "> best_acq_values</ span >
1129
+
1130
+ < span class ="n "> best_idx</ span > < span class ="o "> =</ span > < span class ="n "> max_res</ span > < span class ="o "> .</ span > < span class ="n "> indices</ span > < span class ="p "> [</ span > < span class ="n "> best_batch_idx</ span > < span class ="p "> ]</ span >
1131
+ < span class ="k "> return</ span > < span class ="n "> best_batch_candidates</ span > < span class ="p "> [</ span > < span class ="n "> best_idx</ span > < span class ="p "> ],</ span > < span class ="n "> best_acq_values</ span > < span class ="p "> [</ span > < span class ="n "> best_idx</ span > < span class ="p "> ]</ span >
1097
1132
1098
1133
< span class ="c1 "> # For batch optimization with q > 1 we do not want to enumerate all n_combos^n</ span >
1099
1134
< span class ="c1 "> # possible combinations of discrete choices. Instead, we use sequential greedy</ span >
1100
1135
< span class ="c1 "> # optimization.</ span >
1101
1136
< span class ="n "> base_X_pending</ span > < span class ="o "> =</ span > < span class ="n "> acq_function</ span > < span class ="o "> .</ span > < span class ="n "> X_pending</ span >
1102
1137
< span class ="n "> candidates</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> tensor</ span > < span class ="p "> ([],</ span > < span class ="n "> device</ span > < span class ="o "> =</ span > < span class ="n "> bounds</ span > < span class ="o "> .</ span > < span class ="n "> device</ span > < span class ="p "> ,</ span > < span class ="n "> dtype</ span > < span class ="o "> =</ span > < span class ="n "> bounds</ span > < span class ="o "> .</ span > < span class ="n "> dtype</ span > < span class ="p "> )</ span >
1103
1138
1139
+ < span class ="n "> timeout_sec</ span > < span class ="o "> =</ span > < span class ="n "> timeout_sec</ span > < span class ="o "> /</ span > < span class ="n "> q</ span > < span class ="k "> if</ span > < span class ="n "> timeout_sec</ span > < span class ="k "> else</ span > < span class ="kc "> None</ span >
1104
1140
< span class ="k "> for</ span > < span class ="n "> _</ span > < span class ="ow "> in</ span > < span class ="nb "> range</ span > < span class ="p "> (</ span > < span class ="n "> q</ span > < span class ="p "> ):</ span >
1105
1141
< span class ="n "> candidate</ span > < span class ="p "> ,</ span > < span class ="n "> acq_value</ span > < span class ="o "> =</ span > < span class ="n "> optimize_acqf_mixed</ span > < span class ="p "> (</ span >
1106
1142
< span class ="n "> acq_function</ span > < span class ="o "> =</ span > < span class ="n "> acq_function</ span > < span class ="p "> ,</ span >
@@ -1115,8 +1151,12 @@ <h1>Source code for botorch.optim.optimize</h1><div class="highlight"><pre>
1115
1151
< span class ="n "> nonlinear_inequality_constraints</ span > < span class ="o "> =</ span > < span class ="n "> nonlinear_inequality_constraints</ span > < span class ="p "> ,</ span >
1116
1152
< span class ="n "> post_processing_func</ span > < span class ="o "> =</ span > < span class ="n "> post_processing_func</ span > < span class ="p "> ,</ span >
1117
1153
< span class ="n "> batch_initial_conditions</ span > < span class ="o "> =</ span > < span class ="n "> batch_initial_conditions</ span > < span class ="p "> ,</ span >
1154
+ < span class ="n "> gen_candidates</ span > < span class ="o "> =</ span > < span class ="n "> gen_candidates</ span > < span class ="p "> ,</ span >
1118
1155
< span class ="n "> ic_generator</ span > < span class ="o "> =</ span > < span class ="n "> ic_generator</ span > < span class ="p "> ,</ span >
1119
1156
< span class ="n "> ic_gen_kwargs</ span > < span class ="o "> =</ span > < span class ="n "> ic_gen_kwargs</ span > < span class ="p "> ,</ span >
1157
+ < span class ="n "> timeout_sec</ span > < span class ="o "> =</ span > < span class ="n "> timeout_sec</ span > < span class ="p "> ,</ span >
1158
+ < span class ="n "> retry_on_optimization_warning</ span > < span class ="o "> =</ span > < span class ="n "> retry_on_optimization_warning</ span > < span class ="p "> ,</ span >
1159
+ < span class ="n "> return_best_only</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> ,</ span >
1120
1160
< span class ="p "> )</ span >
1121
1161
< span class ="n "> candidates</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> cat</ span > < span class ="p "> ([</ span > < span class ="n "> candidates</ span > < span class ="p "> ,</ span > < span class ="n "> candidate</ span > < span class ="p "> ],</ span > < span class ="n "> dim</ span > < span class ="o "> =-</ span > < span class ="mi "> 2</ span > < span class ="p "> )</ span >
1122
1162
< span class ="n "> acq_function</ span > < span class ="o "> .</ span > < span class ="n "> set_X_pending</ span > < span class ="p "> (</ span >
0 commit comments