Skip to content

Commit 37b31de

Browse files
authored
Unify sampler initialisation
Also there was bug fix in the boundary isinstance check, and change how blob is propagated to samplers
1 parent f9d797b commit 37b31de

3 files changed

Lines changed: 261 additions & 247 deletions

File tree

py/dynesty/dynamicsampler.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -683,22 +683,23 @@ def __init__(self,
683683
ndim,
684684
sampling,
685685
bounding,
686-
ncdim=None,
687686
nlive0=None,
688-
kwargs=None,
689-
queue_size=None,
687+
ncdim=None,
688+
rstate=None,
690689
pool=None,
691690
use_pool=None,
692-
rstate=None,
691+
queue_size=None,
693692
bound_update_interval_ratio=None,
694-
first_bound_update=None):
693+
first_bound_update=None,
694+
kwargs=None,
695+
blob=None):
695696

696697
# distributions
697698
self.loglikelihood = loglikelihood
698699
self.prior_transform = prior_transform
699700
self.ndim = ndim
700701
self.ncdim = ncdim
701-
self.blob = kwargs.get('blob') or False
702+
self.blob = blob or False
702703
# bounding/sampling
703704
self.bounding = bounding
704705
self.sampling = sampling
@@ -813,19 +814,18 @@ def restore(fname, pool=None):
813814
return restore_sampler(fname, pool=pool)
814815

815816
def __get_update_interval(self, update_interval, nlive):
816-
if not isinstance(update_interval, int):
817-
if isinstance(update_interval, float):
818-
cur_update_interval_ratio = update_interval
819-
elif update_interval is None:
820-
cur_update_interval_ratio = self.bound_update_interval_ratio
821-
else:
822-
raise RuntimeError(
823-
str.format('Weird update_interval value {}',
824-
update_interval))
825-
update_interval = int(
826-
max(
827-
min(np.round(cur_update_interval_ratio * nlive),
828-
sys.maxsize), 1))
817+
if update_interval is None:
818+
cur_update_interval_ratio = self.bound_update_interval_ratio
819+
elif isinstance(update_interval, int):
820+
cur_update_interval_ratio = update_interval / nlive
821+
elif isinstance(update_interval, float):
822+
cur_update_interval_ratio = update_interval
823+
else:
824+
raise RuntimeError(
825+
str.format('Weird update_interval value {}', update_interval))
826+
update_interval = int(
827+
max(min(np.round(cur_update_interval_ratio * nlive), sys.maxsize),
828+
1))
829829
return update_interval
830830

831831
def reset(self):

0 commit comments

Comments
 (0)