2020 get_random_generator )
2121
2222from .bounding import (UnitCube , Ellipsoid , MultiEllipsoid , RadFriends ,
23- SupFriends )
23+ SupFriends , Bound )
2424from .utils import (get_enlarge_bootstrap , save_sampler , restore_sampler )
2525
2626__all__ = ["Sampler" ]
2727
2828SAMPLER_LIST = ['rwalk' , 'unif' , 'rslice' , 'slice' ]
2929
3030
31+ def _get_bound (bounding , ndim ):
32+ if isinstance (bounding , str ):
33+ if bounding not in ['none' , 'single' , 'multi' , 'balls' , 'cubes' ]:
34+ raise ValueError ('Unsupported bounding type' )
35+ elif isinstance (bounding , Bound ):
36+ pass
37+ else :
38+ raise ValueError ('Unsupported bounding type' )
39+
40+ if bounding == 'none' :
41+ bound = UnitCube (ndim )
42+ elif bounding == 'single' :
43+ bound = Ellipsoid (np .zeros (ndim ) + .5 , np .identity (ndim ) * ndim / 4 )
44+ # this is ellipsoid in the center of the cube that contains
45+ # the whole cube
46+ elif bounding == 'multi' :
47+ bound = MultiEllipsoid (ctrs = [np .zeros (ndim ) + .5 ],
48+ covs = [np .identity (ndim ) * ndim / 4 ])
49+ # this is ellipsoid in the center of the cube that contains
50+ # the whole cube
51+ elif bounding == 'balls' :
52+ bound = RadFriends (ndim )
53+ elif bounding == 'cubes' :
54+ bound = SupFriends (ndim )
55+ return bound
56+
57+
3158class Sampler :
3259 """
3360 The basic sampler object that performs the actual nested sampling.
@@ -189,12 +216,13 @@ def __init__(self,
189216 self .first_bound_update_ncall = first_update .get (
190217 'min_ncall' , 2 * self .nlive )
191218 self .first_bound_update_eff = first_update .get ('min_eff' , 10. )
192-
193219 self .logl_first_update = None
220+ self .ncall_at_last_update = 0
221+
194222 self .unit_cube_sampling = True
195- self .bound_list = [UnitCube (self .ncdim )] # bounding distributions
223+ self .bound = UnitCube (self .ncdim )
224+ self .bound_list = [self .bound ] # bounding distributions
196225 self .nbound = 1 # total number of unique bounding distributions
197- self .ncall_at_last_update = 0
198226
199227 self .logvol_init = logvol_init
200228
@@ -212,27 +240,10 @@ def __init__(self,
212240
213241 self .cite = self .kwargs .get ('cite' )
214242
215- if bounding not in ['none' , 'single' , 'multi' , 'balls' , 'cubes' ]:
216- raise ValueError ('Unsupported bounding type' )
217243 self .bounding = bounding
218- if bounding == 'none' :
219- self .bound = UnitCube (self .ncdim )
220- elif bounding == 'single' :
221- self .bound = Ellipsoid (
222- np .zeros (self .ncdim ) + .5 ,
223- np .identity (self .ncdim ) * self .ncdim / 4 )
224- # this is ellipsoid in the center of the cube that contains
225- # the whole cube
226- elif bounding == 'multi' :
227- self .bound = MultiEllipsoid (
228- ctrs = [np .zeros (self .ncdim ) + .5 ],
229- covs = [np .identity (self .ncdim ) * self .ncdim / 4 ])
230- # this is ellipsoid in the center of the cube that contains
231- # the whole cube
232- elif bounding == 'balls' :
233- self .bound = RadFriends (self .ncdim )
234- elif bounding == 'cubes' :
235- self .bound = SupFriends (self .ncdim )
244+ self .bound_next = _get_bound (bounding , ndim )
245+ # the reason I do not set it as self.bound
246+ # because we start from unit cube
236247
237248 def save (self , fname ):
238249 """
@@ -467,12 +478,17 @@ def update_bound_if_needed(self, loglstar, ncall=None, force=False):
467478 else :
468479 subset = slice (None )
469480 if self .unit_cube_sampling :
481+ # done with unit cube
482+ # updating the bound and internal sampler
470483 self .unit_cube_sampling = False
471484 self .logl_first_update = loglstar
485+ self .bound = self .bound_next
472486 self .internal_sampler = self .internal_sampler_next
473- bound = self .update_bound (subset = subset )
487+ self .bound_next = None
488+ self .internal_sampler_next = None
489+ self .update_bound (subset = subset )
474490 if self .save_bounds :
475- self .bound_list .append (bound )
491+ self .bound_list .append (self . bound )
476492 self .nbound += 1
477493 self .ncall_at_last_update = ncall
478494
0 commit comments