Skip to content

Commit a1fc0c3

Browse files
committed
cleanup reset functions in both samplers
also i don't explicitely reset in sample_initial in dynamic_sampler that recipe for bugs
1 parent af40047 commit a1fc0c3

2 files changed

Lines changed: 53 additions & 52 deletions

File tree

py/dynesty/dynamicsampler.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -716,7 +716,7 @@ def __init__(self,
716716
bound_bootstrap)
717717

718718
# TODO FIX
719-
# self.cite = self.kwargs.get('cite')
719+
self.cite = cite
720720

721721
# random state
722722
self.rstate = rstate
@@ -835,17 +835,25 @@ def reset(self):
835835
"""Re-initialize the sampler."""
836836

837837
# sampling
838-
self.it = 1
839-
self.batch = 0
840-
self.ncall = 0
841-
self.bound_list = []
842-
self.eff = 1.
843-
self.base = False
844-
845-
self.saved_run = RunRecord(dynamic=True)
846-
self.base_run = RunRecord(dynamic=True)
847-
self.new_run = None
848-
self.new_logl_min, self.new_logl_max = -np.inf, np.inf
838+
DynamicSampler.__init__(
839+
self,
840+
self.loglikelihood,
841+
self.prior_transform,
842+
self.ndim,
843+
self.sampling,
844+
self.bounding,
845+
nlive0=self.nlive0,
846+
ncdim=self.ncdim,
847+
rstate=self.rstate,
848+
pool=self.pool,
849+
use_pool=self.use_pool,
850+
queue_size=self.queue_size,
851+
bound_update_interval_ratio=self.bound_update_interval_ratio,
852+
first_bound_update=self.first_bound_update,
853+
bound_bootstrap=self.bound_bootstrap,
854+
bound_enlarge=self.bound_enlarge,
855+
blob=self.blob,
856+
cite=self.cite)
849857

850858
@property
851859
def results(self):
@@ -1042,9 +1050,6 @@ def sample_initial(self,
10421050
warnings.warn("Beware: `nlive_init <= 2 * ndim`!")
10431051

10441052
if not resume:
1045-
# Reset saved results to avoid any possible conflicts.
1046-
self.reset()
1047-
10481053
(self.live_u, self.live_v, self.live_logl,
10491054
blobs), logvol_init, init_ncalls = _initialize_live_points(
10501055
live_points,

py/dynesty/sampler.py

Lines changed: 33 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,8 @@ class Sampler:
270270
271271
live_points : list of 3 or 4 `~numpy.ndarray`
272272
Each with shape (nlive, ndim) for the first three arrays.
273-
If `blob=True`, a fourth array of blobs (arbitrary shape) may be included.
273+
If `blob=True`, a fourth array of blobs (arbitrary shape) may be
274+
included.
274275
275276
sampling : {`'unif'`, `'rwalk'`, `'slice'`, `'rslice'`}
276277
Sampling Method used to sample uniformly within the likelihood
@@ -390,6 +391,7 @@ def __init__(self,
390391

391392
# bounding updates
392393
self.bound_update_interval = bound_update_interval
394+
self.first_bound_update = first_bound_update
393395
self.first_bound_update_ncall = first_bound_update.get(
394396
'min_ncall', 2 * self.nlive)
395397
self.first_bound_update_eff = first_bound_update.get('min_eff', 10.)
@@ -512,41 +514,36 @@ def __getstate__(self):
512514
def reset(self):
513515
"""Re-initialize the sampler."""
514516

515-
(self.live_u, self.live_v, self.live_logl,
516-
self.live_blobs), logvol_init, init_ncalls = _initialize_live_points(
517-
None,
518-
self.prior_transform,
519-
self.loglikelihood,
520-
self.mapper,
521-
nlive=self.nlive,
522-
ndim=self.ndim,
523-
rstate=self.rstate,
524-
blob=self.blob,
525-
use_pool_ptform=self.use_pool_ptform)
526-
self.logvol_init = logvol_init
527-
self.live_bound = np.zeros(self.nlive, dtype=int)
528-
self.live_it = np.zeros(self.nlive, dtype=int)
529-
530-
# parallelism
531-
self.queue = []
532-
self.nqueue = 0
533-
534-
# sampling
535-
self.it = 1
536-
self.ncall = init_ncalls
537-
self.bound = UnitCube(self.ncdim)
538-
self.bound_list = [self.bound]
539-
self.internal_sampler = UnitCubeSampler(ndim=self.ndim)
540-
self.nbound = 1
541-
self.unit_cube_sampling = True
542-
self.added_live = False
543-
544-
self.plateau_mode = False
545-
self.plateau_counter = None
546-
self.plateau_logdvol = None
547-
548-
# results
549-
self.saved_run = RunRecord()
517+
# (self.live_u, self.live_v, self.live_logl, self.live_blobs)
518+
live_points, logvol_init, init_ncalls = _initialize_live_points(
519+
None,
520+
self.prior_transform,
521+
self.loglikelihood,
522+
self.mapper,
523+
nlive=self.nlive,
524+
ndim=self.ndim,
525+
rstate=self.rstate,
526+
blob=self.blob,
527+
use_pool_ptform=self.use_pool_ptform)
528+
529+
self.__init__(self.loglikelihood,
530+
self.prior_transform,
531+
self.ndim,
532+
live_points,
533+
self.sampling,
534+
self.bounding,
535+
ncdim=self.ncdim,
536+
rstate=self.rstate,
537+
pool=self.pool,
538+
use_pool=self.use_pool,
539+
queue_size=self.queue_size,
540+
bound_update_interval=self.bound_update_interval,
541+
first_bound_update=self.first_bound_update,
542+
bound_bootstrap=self.bound_bootstrap,
543+
bound_enlarge=self.bound_enlarge,
544+
blob=self.blob,
545+
cite=self.cite,
546+
logvol_init=logvol_init)
550547

551548
@property
552549
def results(self):
@@ -1259,7 +1256,6 @@ def run_nested(self,
12591256
saved in the end of the run irrespective of checkpoint_every.
12601257
"""
12611258

1262-
12631259
# Define our stopping criteria.
12641260
if dlogz is None:
12651261
if add_live:

0 commit comments

Comments
 (0)