Skip to content

Commit a6dc087

Browse files
committed
update bounding logic
1 parent 9190552 commit a6dc087

1 file changed

Lines changed: 42 additions & 26 deletions

File tree

py/dynesty/sampler.py

Lines changed: 42 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,41 @@
2020
get_random_generator)
2121

2222
from .bounding import (UnitCube, Ellipsoid, MultiEllipsoid, RadFriends,
23-
SupFriends)
23+
SupFriends, Bound)
2424
from .utils import (get_enlarge_bootstrap, save_sampler, restore_sampler)
2525

2626
__all__ = ["Sampler"]
2727

2828
SAMPLER_LIST = ['rwalk', 'unif', 'rslice', 'slice']
2929

3030

31+
def _get_bound(bounding, ndim):
32+
if isinstance(bounding, str):
33+
if bounding not in ['none', 'single', 'multi', 'balls', 'cubes']:
34+
raise ValueError('Unsupported bounding type')
35+
elif isinstance(bounding, Bound):
36+
pass
37+
else:
38+
raise ValueError('Unsupported bounding type')
39+
40+
if bounding == 'none':
41+
bound = UnitCube(ndim)
42+
elif bounding == 'single':
43+
bound = Ellipsoid(np.zeros(ndim) + .5, np.identity(ndim) * ndim / 4)
44+
# this is ellipsoid in the center of the cube that contains
45+
# the whole cube
46+
elif bounding == 'multi':
47+
bound = MultiEllipsoid(ctrs=[np.zeros(ndim) + .5],
48+
covs=[np.identity(ndim) * ndim / 4])
49+
# this is ellipsoid in the center of the cube that contains
50+
# the whole cube
51+
elif bounding == 'balls':
52+
bound = RadFriends(ndim)
53+
elif bounding == 'cubes':
54+
bound = SupFriends(ndim)
55+
return bound
56+
57+
3158
class Sampler:
3259
"""
3360
The basic sampler object that performs the actual nested sampling.
@@ -189,12 +216,13 @@ def __init__(self,
189216
self.first_bound_update_ncall = first_update.get(
190217
'min_ncall', 2 * self.nlive)
191218
self.first_bound_update_eff = first_update.get('min_eff', 10.)
192-
193219
self.logl_first_update = None
220+
self.ncall_at_last_update = 0
221+
194222
self.unit_cube_sampling = True
195-
self.bound_list = [UnitCube(self.ncdim)] # bounding distributions
223+
self.bound = UnitCube(self.ncdim)
224+
self.bound_list = [self.bound] # bounding distributions
196225
self.nbound = 1 # total number of unique bounding distributions
197-
self.ncall_at_last_update = 0
198226

199227
self.logvol_init = logvol_init
200228

@@ -212,27 +240,10 @@ def __init__(self,
212240

213241
self.cite = self.kwargs.get('cite')
214242

215-
if bounding not in ['none', 'single', 'multi', 'balls', 'cubes']:
216-
raise ValueError('Unsupported bounding type')
217243
self.bounding = bounding
218-
if bounding == 'none':
219-
self.bound = UnitCube(self.ncdim)
220-
elif bounding == 'single':
221-
self.bound = Ellipsoid(
222-
np.zeros(self.ncdim) + .5,
223-
np.identity(self.ncdim) * self.ncdim / 4)
224-
# this is ellipsoid in the center of the cube that contains
225-
# the whole cube
226-
elif bounding == 'multi':
227-
self.bound = MultiEllipsoid(
228-
ctrs=[np.zeros(self.ncdim) + .5],
229-
covs=[np.identity(self.ncdim) * self.ncdim / 4])
230-
# this is ellipsoid in the center of the cube that contains
231-
# the whole cube
232-
elif bounding == 'balls':
233-
self.bound = RadFriends(self.ncdim)
234-
elif bounding == 'cubes':
235-
self.bound = SupFriends(self.ncdim)
244+
self.bound_next = _get_bound(bounding, ndim)
245+
# the reason I do not set it as self.bound
246+
# because we start from unit cube
236247

237248
def save(self, fname):
238249
"""
@@ -467,12 +478,17 @@ def update_bound_if_needed(self, loglstar, ncall=None, force=False):
467478
else:
468479
subset = slice(None)
469480
if self.unit_cube_sampling:
481+
# done with unit cube
482+
# updating the bound and internal sampler
470483
self.unit_cube_sampling = False
471484
self.logl_first_update = loglstar
485+
self.bound = self.bound_next
472486
self.internal_sampler = self.internal_sampler_next
473-
bound = self.update_bound(subset=subset)
487+
self.bound_next = None
488+
self.internal_sampler_next = None
489+
self.update_bound(subset=subset)
474490
if self.save_bounds:
475-
self.bound_list.append(bound)
491+
self.bound_list.append(self.bound)
476492
self.nbound += 1
477493
self.ncall_at_last_update = ncall
478494

0 commit comments

Comments
 (0)