diff --git a/py/dynesty/dynamicsampler.py b/py/dynesty/dynamicsampler.py index 90242bce8..802801f01 100644 --- a/py/dynesty/dynamicsampler.py +++ b/py/dynesty/dynamicsampler.py @@ -606,7 +606,8 @@ def results(self): d = {} for k in [ 'nc', 'v', 'id', 'batch', 'it', 'u', 'n', 'logwt', 'logl', - 'logvol', 'logz', 'logzvar', 'h', 'batch_nlive', 'batch_bounds' + 'logvol', 'logz', 'logzvar', 'h', 'batch_nlive', 'batch_bounds', + 'scale', 'walks', ]: d[k] = np.array(self.saved_run.D[k]) @@ -633,6 +634,7 @@ def results(self): results.append( ('samples_bound', np.array(self.saved_run.D['boundidx']))) results.append(('scale', np.array(self.saved_run.D['scale']))) + results.append(('walks', np.array(self.saved_run.D['walks']))) return Results(results) @@ -865,7 +867,9 @@ def sample_initial(self, n=self.nlive_init, boundidx=results.boundidx, bounditer=results.bounditer, - scale=self.sampler.scale) + scale=self.sampler.scale, + walks=self.sampler.walks, + ) self.base_run.append(add_info) self.saved_run.append(add_info) @@ -907,7 +911,9 @@ def sample_initial(self, n=self.nlive_init - it, boundidx=results.boundidx, bounditer=results.bounditer, - scale=self.sampler.scale) + scale=self.sampler.scale, + walks=self.sampler.walks, + ) self.base_run.append(add_info) self.saved_run.append(add_info) @@ -1048,6 +1054,7 @@ def sample_batch(self, saved_logl = np.array(self.saved_run.D['logl']) saved_logvol = np.array(self.saved_run.D['logvol']) saved_scale = np.array(self.saved_run.D['scale']) + saved_walks = np.array(self.saved_run.D['walks']) nblive = self.nlive_init update_interval = self.__get_update_interval(update_interval, @@ -1141,6 +1148,7 @@ def sample_batch(self, self.new_logl_min = logl_min live_scale = saved_scale[subset0[0]] + live_walks = saved_walks[subset0[0]] # set the scale based on the lowest point # we are weighting each point by X_i to ensure @@ -1188,6 +1196,7 @@ def sample_batch(self, batch_sampler.live_v = live_v batch_sampler.live_logl = live_logl batch_sampler.scale = live_scale + batch_sampler.walks = live_walks # Trigger an update of the internal bounding distribution based # on the "new" set of live points. @@ -1292,7 +1301,9 @@ def sample_batch(self, n=nlive_new, boundidx=results.boundidx, bounditer=results.bounditer, - scale=batch_sampler.scale) + scale=batch_sampler.scale, + walks=batch_sampler.walks, + ) self.new_run.append(D) # Increment relevant counters. @@ -1328,7 +1339,9 @@ def sample_batch(self, n=nlive_new - it, boundidx=results.boundidx, bounditer=results.bounditer, - scale=batch_sampler.scale) + scale=batch_sampler.scale, + walks=batch_sampler.walks, + ) self.new_run.append(D) # Increment relevant counters. @@ -1358,7 +1371,7 @@ def combine_runs(self): for k in [ 'id', 'u', 'v', 'logl', 'nc', 'boundidx', 'it', 'bounditer', - 'n', 'scale' + 'n', 'scale', 'walks', ]: saved_d[k] = np.array(self.saved_run.D[k]) new_d[k] = np.array(self.new_run.D[k]) @@ -1411,7 +1424,7 @@ def combine_runs(self): for k in [ 'id', 'u', 'v', 'logl', 'nc', 'boundidx', 'it', - 'bounditer', 'scale' + 'bounditer', 'scale', 'walks', ]: add_info[k] = add_source[k][add_idx] self.saved_run.append(add_info) diff --git a/py/dynesty/dynesty.py b/py/dynesty/dynesty.py index 7762a5b27..f41671675 100644 --- a/py/dynesty/dynesty.py +++ b/py/dynesty/dynesty.py @@ -252,7 +252,13 @@ def NestedSampler(loglikelihood, update_func=None, ncdim=None, save_history=False, - history_filename=None): + history_filename=None, + adapt_scale=True, + adapt_walks=False, + adapt_time=None, + max_walks=1000, + target_accept=None, + ): """ Initializes and returns a sampler object for Static Nested Sampling. @@ -506,6 +512,10 @@ def prior_transform(u): if update_func is not None and not callable(update_func): raise ValueError("Unknown update function: '{0}'".format(update_func)) kwargs['update_func'] = update_func + kwargs['adapt_scale'] = adapt_scale + kwargs['adapt_walks'] = adapt_walks + kwargs['adapt_time'] = adapt_time + kwargs['max_walks'] = max_walks # Citation generator. kwargs['cite'] = _get_citations('static', bound, sample) @@ -561,6 +571,14 @@ def prior_transform(u): kwargs['fmove'] = fmove if max_move is not None: kwargs['max_move'] = max_move + if adapt_walks is not None: + kwargs['adapt_time'] = adapt_time + if max_walks is not None: + kwargs['max_walks'] = max_walks + if target_accept is not None: + kwargs['target_accept'] = target_accept + kwargs['adapt_scale'] = adapt_scale + kwargs['adapt_walks'] = adapt_walks update_interval_ratio = _get_update_interval_ratio(update_interval, sample, bound, ndim, nlive, @@ -662,7 +680,13 @@ def DynamicNestedSampler(loglikelihood, update_func=None, ncdim=None, save_history=False, - history_filename=None): + history_filename=None, + adapt_scale=True, + adapt_walks=False, + adapt_time=None, + max_walks=None, + target_accept=None, + ): """ Initializes and returns a sampler object for Dynamic Nested Sampling. @@ -959,6 +983,14 @@ def prior_transform(u): kwargs['fmove'] = fmove if max_move is not None: kwargs['max_move'] = max_move + if adapt_walks is not None: + kwargs['adapt_time'] = adapt_time + if max_walks is not None: + kwargs['max_walks'] = max_walks + if target_accept is not None: + kwargs['target_accept'] = target_accept + kwargs['adapt_scale'] = adapt_scale + kwargs['adapt_walks'] = adapt_walks # Set up parallel (or serial) evaluation. queue_size = _parse_pool_queue(pool, queue_size)[1] diff --git a/py/dynesty/nestedsamplers.py b/py/dynesty/nestedsamplers.py index 11bda465b..0b282fdd1 100644 --- a/py/dynesty/nestedsamplers.py +++ b/py/dynesty/nestedsamplers.py @@ -128,9 +128,18 @@ def __init__(self, self.compute_jac = self.kwargs.get('compute_jac', False) # Initialize random walk parameters. + self.adapt_walks = self.kwargs.get("adapt_walks", True) self.walks = max(2, self.kwargs.get('walks', 25)) + self.max_walks = self.kwargs.get("max_walks", 1000) + self.adapt_scale = self.kwargs.get("adapt_scale", True) self.facc = self.kwargs.get('facc', 0.5) self.facc = min(1., max(1. / self.walks, self.facc)) + self.adapt_time = self.kwargs.get("adapt_time", None) + if self.adapt_time is None: + self.adapt_time = self.nlive / 5 + self.target_accept = self.kwargs.get("target_accept", None) + if self.target_accept is None: + self.target_accept = self.walks * self.facc # Initialize slice parameters. self.slices = self.kwargs.get('slices', 5) @@ -148,6 +157,12 @@ def update_unif(self, blob): pass def update_rwalk(self, blob): + if self.adapt_scale: + self.update_rwalk_scale(blob) + if self.adapt_walks: + self.update_rwalk_walks(blob) + + def update_rwalk_scale(self, blob): """Update the random walk proposal scale based on the current number of accepted/rejected steps. For rwalk the scale is important because it @@ -173,6 +188,20 @@ def update_rwalk(self, blob): # here because our coefficients a_k do not obey \sum a_k^2 = \infty self.scale *= math.exp((facc - self.facc) / self.ncdim / self.facc) + def update_rwalk_walks(self, blob): + """Update the number of MCMC steps taken with the rwalk method. + This tries to keep the number of accepted steps at each iteration + approximately constant. + """ + accept = blob["accept"] + if accept == 0: + factor = 1.25 + else: + factor = (self.target_accept / accept) ** (1 / self.adapt_time) + estimated_steps = self.walks * factor + self.walks = max(min([self.max_walks, estimated_steps]), self.target_accept) + self.kwargs["walks"] = int(self.walks) + def update_slice(self, blob): """Update the slice proposal scale based on the relative size of the slices compared to our initial guess. diff --git a/py/dynesty/results.py b/py/dynesty/results.py index ca3686ab3..a0187618f 100644 --- a/py/dynesty/results.py +++ b/py/dynesty/results.py @@ -275,7 +275,8 @@ def print_fn_fallback(results, ('batch_nlive', 'array[int]', "The number of live points added in a given batch ???" "How is it different from samples_n", 'nbatch???'), - ('scale', 'array[float]', "Scalar scale applied for proposals", 'niter') + ('scale', 'array[float]', "Scalar scale applied for proposals", 'niter'), + ('walks', 'array[float]', "MCMC chain length for rwalk", 'niter'), ] diff --git a/py/dynesty/sampler.py b/py/dynesty/sampler.py index dfc056f84..9aaa91242 100644 --- a/py/dynesty/sampler.py +++ b/py/dynesty/sampler.py @@ -98,6 +98,7 @@ def __init__(self, loglikelihood, prior_transform, npdim, live_points, # set to none just for qa self.scale = None + self.walks = None self.method = None self.kwargs = {} @@ -466,7 +467,9 @@ def add_live_points(self): boundidx=boundidx, it=point_it, bounditer=bounditer, - scale=self.scale)) + scale=self.scale, + walks=self.walks, + )) self.eff = 100. * (self.it + i) / self.ncall # efficiency # Return our new "dead" point and ancillary quantities. @@ -768,7 +771,9 @@ def sample(self, nc=nc, it=worst_it, bounditer=bounditer, - scale=self.scale)) + scale=self.scale, + walks=self.walks, + )) # Update the live point (previously our "worst" point). self.live_u[worst] = u diff --git a/py/dynesty/utils.py b/py/dynesty/utils.py index f27be7a0e..e74985a26 100644 --- a/py/dynesty/utils.py +++ b/py/dynesty/utils.py @@ -188,7 +188,8 @@ def __init__(self, dynamic=False): 'it', # iteration the live (now dead) point was proposed 'n', # number of live points interior to dead point 'bounditer', # active bound at a specific iteration - 'scale' # scale factor at each iteration + 'scale', # scale factor at each iteration + 'walks', # number of steps taken at each iteration ] if dynamic: keys.extend([ diff --git a/tests/test_adapt.py b/tests/test_adapt.py new file mode 100644 index 000000000..cc797e7f6 --- /dev/null +++ b/tests/test_adapt.py @@ -0,0 +1,47 @@ +import numpy as np +import dynesty +import pytest +import itertools +from utils import get_rstate, get_printing +""" +Run a series of basic tests of the 2d eggbox +""" + +nlive = 1000 +printing = get_printing() + +# EGGBOX + + +# see 1306.2144 +def loglike_egg(x): + logl = ((2 + np.cos(x[0] / 2) * np.cos(x[1] / 2))**5) + return logl + + +def prior_transform_egg(x): + return x * 10 * np.pi + + +@pytest.mark.parametrize( + "scale,walks", + itertools.product([True, False], [True, False]) +) +def test_adapt(scale, walks): + # stress test various boundaries + ndim = 2 + rstate = get_rstate() + sampler = dynesty.NestedSampler(loglike_egg, + prior_transform_egg, + ndim, + nlive=nlive, + bound="single", + sample="rwalk", + rstate=rstate, + adapt_scale=scale, + adapt_walks=walks, + ) + sampler.run_nested(dlogz=0.01, print_progress=printing) + logz_truth = 235.856 + assert (abs(logz_truth - sampler.results.logz[-1]) < + 5. * sampler.results.logzerr[-1])