@@ -53,6 +53,7 @@ def __init__(
5353 max_asks : int | None = None ,
5454 keep_most_recent : int | None = None ,
5555 min_post_range : float | None = None ,
56+ log_post_var : bool = False ,
5657 name : str = "" ,
5758 run_indefinitely : bool = False ,
5859 transforms : ChainedInputTransform = ChainedInputTransform (** {}),
@@ -82,6 +83,7 @@ def __init__(
8283 as data collected from later trials. When None, the model is fitted on all data.
8384 min_post_range (float, optional): Experimental. The required difference between the posterior's minimum and maximum value in
8485 probablity space before the strategy will finish. Ignored if None (default).
86+ log_post_var (bool): Whether to log the posterior prediction variance (as it would be used for min_post_range). Defaults to False.
8587 name (str): The name of the strategy. Defaults to the empty string.
8688 run_indefinitely (bool): If true, the strategy will run indefinitely until finish() is explicitly called. Other stopping criteria will
8789 be ignored. Defaults to False.
@@ -161,8 +163,11 @@ def __init__(
161163 self .transforms = transforms
162164
163165 self .min_post_range = min_post_range
164- if self .min_post_range is not None :
165- assert model is not None , "min_post_range must be None if model is None!"
166+ self .log_post_var = log_post_var
167+ if self .min_post_range is not None or self .log_post_var :
168+ assert (
169+ model is not None
170+ ), "posterior range cannot be evaluated if model is None!"
166171 self .eval_grid = make_scaled_sobol (
167172 lb = self .lb , ub = self .ub , size = self ._n_eval_points
168173 )
@@ -429,16 +434,23 @@ def finished(self) -> bool:
429434 else :
430435 sufficient_outcomes = True
431436
432- if self .min_post_range is not None :
437+ if self .min_post_range is not None or self . log_post_var :
433438 assert (
434439 self .model is not None
435440 ), "model is None! Cannot predict without a model!"
436- fmean , _ = self .model .predict (self .eval_grid , probability_space = True )
437- meets_post_range = bool (
438- ((fmean .max () - fmean .min ()) >= self .min_post_range ).item ()
439- )
441+ fmean , fvar = self .model .predict (self .eval_grid , probability_space = True )
442+ post_range = fmean .max () - fmean .min ()
443+ else :
444+ post_range = None
445+
446+ if post_range is not None :
447+ logger .info (f"Mean posterior variance = { fvar .mean ().item ()} " )
448+
449+ if self .min_post_range :
450+ meets_post_range = bool ((post_range >= self .min_post_range ).item ())
440451 else :
441452 meets_post_range = True
453+
442454 finished = (
443455 self ._count >= self .min_asks
444456 and self .n >= self .min_total_tells
0 commit comments