@@ -34,7 +34,11 @@ def maxima(self, bounds, num_starts=5, num_samples=1024, method="L-BFGS-B",
3434
3535 assert num_samples is not None , "`num_samples` must be specified!"
3636 assert num_samples > 0 , "`num_samples` must be positive integer!"
37- assert num_starts is None or num_samples >= num_starts , \
37+
38+ assert num_starts is not None , "`num_starts` must be specified!"
39+ assert num_starts >= 0 , "`num_starts` must be nonnegative integer!"
40+
41+ assert num_samples >= num_starts , \
3842 "number of random samples (`num_samples`) must be " \
3943 "greater than number of starting points (`num_starts`)"
4044
@@ -43,11 +47,13 @@ def maxima(self, bounds, num_starts=5, num_samples=1024, method="L-BFGS-B",
4347 # TODO(LT): Allow alternative arbitary generator function callbacks
4448 # to support e.g. Gaussian sampling, low-discrepancy sequences, etc.
4549 X_init = random_state .uniform (low = low , high = high , size = (num_samples , dim ))
46- y_init = self .predict (X_init ).squeeze (axis = - 1 )
50+ z_init = self .predict (X_init ).squeeze (axis = - 1 )
51+ # the function to minimize is negative of the classifier output
52+ f_init = - z_init
4753
4854 results = []
49- if num_starts is not None and num_starts > 0 :
50- ind = np .argpartition (y_init , kth = num_starts - 1 , axis = None )
55+ if num_starts > 0 :
56+ ind = np .argpartition (f_init , kth = num_starts - 1 , axis = None )
5157 for i in range (num_starts ):
5258 x0 = X_init [ind [i ]]
5359 result = minimize (self ._func_min , x0 = x0 , method = method ,
@@ -59,10 +65,9 @@ def maxima(self, bounds, num_starts=5, num_samples=1024, method="L-BFGS-B",
5965 f"iterations: { result .nit :02d} , "
6066 f"status: { result .status } ({ result .message } )" )
6167 else :
62- for i in range (num_samples ):
63- result = OptimizeResult (x = X_init [i ], fun = y_init [i ],
64- success = True )
65- results .append (result )
68+ i = np .argmin (f_init , axis = None )
69+ result = OptimizeResult (x = X_init [i ], fun = f_init [i ], success = True )
70+ results .append (result )
6671
6772 return results
6873
0 commit comments