Skip to content

Commit f770e5b

Browse files
committed
get rid of hslice sampler as it was failing the tests and it just cannot be trusted
updated the interface, so that the bound must now provide the interface get_random_axes() that simplifies the logic in the sampler
1 parent 1723bcc commit f770e5b

7 files changed

Lines changed: 90 additions & 623 deletions

File tree

py/dynesty/bounding.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,50 @@
4141
]
4242

4343

44+
class Bound:
45+
"""
46+
Parameters
47+
----------
48+
ndim : int
49+
The number of dimensions of the unit cube.
50+
51+
"""
52+
53+
def __init__(self):
54+
pass
55+
56+
def contains(self, x):
57+
"""Checks if unit cube contains the point `x`."""
58+
pass
59+
60+
def sample(self, rstate=None):
61+
"""
62+
Draw a sample uniformly distributed within the unit cube.
63+
64+
Returns
65+
-------
66+
x : `~numpy.ndarray` with shape (ndim,)
67+
A coordinate within the unit cube.
68+
69+
"""
70+
pass
71+
72+
def samples(self, nsamples, rstate=None):
73+
"""
74+
Draw `nsamples` samples randomly distributed within the unit cube.
75+
76+
Returns
77+
-------
78+
x : `~numpy.ndarray` with shape (nsamples, ndim)
79+
A collection of coordinates within the unit cube.
80+
81+
"""
82+
pass
83+
84+
def get_random_axes(self, rstate):
85+
pass
86+
87+
4488
class UnitCube:
4589
"""
4690
An N-dimensional unit cube.
@@ -92,6 +136,9 @@ def update(self, points, rstate=None, bootstrap=0, pool=None):
92136
"""Filler function."""
93137
pass
94138

139+
def get_random_axes(self, rstate):
140+
return np.eye(self.n)
141+
95142

96143
class Ellipsoid:
97144
"""
@@ -325,6 +372,9 @@ def update(self,
325372
if mc_integrate:
326373
self.funit = self.unitcube_overlap(rstate=rstate)
327374

375+
def get_random_axes(self, rstate):
376+
return self.axes
377+
328378

329379
class MultiEllipsoid:
330380
"""
@@ -634,6 +684,13 @@ def update(self,
634684
self.logvol_tot, self.funit = self.monte_carlo_logvol(
635685
rstate=rstate, return_overlap=True)
636686

687+
def get_random_axes(self, rstate):
688+
probs = np.exp(self.logvols - self.logvol_tot)
689+
ell_idx = rand_choice(probs, rstate)
690+
# Choose axes.
691+
ax = self.ells[ell_idx].axes
692+
return ax
693+
637694

638695
class RadFriends:
639696
"""
@@ -899,6 +956,9 @@ def _get_covariance_from_clusters(self, points):
899956
i = j
900957
return self._get_covariance_from_all_points(overlapped_points)
901958

959+
def get_random_axes(self, rstate):
960+
return self.axes
961+
902962

903963
class SupFriends:
904964
"""
@@ -1165,6 +1225,9 @@ def _get_covariance_from_clusters(self, points):
11651225
i = j
11661226
return self._get_covariance_from_all_points(overlapped_points)
11671227

1228+
def get_random_axes(self, rstate):
1229+
return self.axes
1230+
11681231

11691232
##################
11701233
# HELPER FUNCTIONS

py/dynesty/dynamicsampler.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def compute_weights(results):
5656
logwt = results.logwt
5757
samples_n = results.samples_n
5858

59-
if logz.ptp() == 0:
59+
if np.ptp(logz) == 0:
6060
# this pathological case can happen if all logl are very small
6161
# and all logz are very small and the same
6262
# then the calculation below failse
@@ -185,8 +185,6 @@ def _get_update_interval_ratio(update_interval, sample, bound, ndim, nlive,
185185
update_interval_frac = 0.9 * ndim * slices
186186
elif sample == 'rslice':
187187
update_interval_frac = 2.0 * slices
188-
elif sample == 'hslice':
189-
update_interval_frac = 25.0 * slices
190188
else:
191189
update_interval_frac = np.inf
192190
warnings.warn(

py/dynesty/dynesty.py

Lines changed: 9 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,7 @@ def _get_citations(nested_type, bound, sampler):
9292
("Handley, Hobson & Lasenby (2015a)",
9393
"ui.adsabs.harvard.edu/abs/2015MNRAS.450L..61H"),
9494
("Handley, Hobson & Lasenby (2015b)",
95-
"ui.adsabs.harvard.edu/abs/2015MNRAS.453.4384H")],
96-
'hslice':
97-
[("Neal (2003)", "projecteuclid.org/euclid.aos/1056562461"),
98-
("Skilling (2012)", "aip.scitation.org/doi/abs/10.1063/1.3703630"),
99-
("Feroz & Skilling (2013)",
100-
"ui.adsabs.harvard.edu/abs/2013AIPC.1553..106F"),
101-
("Speagle (2020)", "ui.adsabs.harvard.edu/abs/2020MNRAS.493.3132S")]
95+
"ui.adsabs.harvard.edu/abs/2015MNRAS.453.4384H")]
10296
}
10397

10498
def reflist_tostring(x):
@@ -144,7 +138,7 @@ def reflist_tostring(x):
144138
return citations
145139

146140

147-
def _get_auto_sample(ndim, gradient):
141+
def _get_auto_sample(ndim):
148142
""" Decode which sampling method to use
149143
150144
Arguments:
@@ -157,10 +151,7 @@ def _get_auto_sample(ndim, gradient):
157151
elif 10 <= ndim <= 20:
158152
sample = 'rwalk'
159153
else:
160-
if gradient is None:
161-
sample = 'rslice'
162-
else:
163-
sample = 'hslice'
154+
sample = 'rslice'
164155
return sample
165156

166157

@@ -178,7 +169,7 @@ def _get_walks_slices(walks0, slices0, sample, ndim):
178169
"""
179170
walks, slices = None, None
180171
# see https://github.com/joshspeagle/dynesty/issues/289
181-
if sample in ['hslice', 'rslice']:
172+
if sample in ['rslice']:
182173
slices = 3 + ndim
183174
elif sample == 'slice':
184175
slices = 3
@@ -188,7 +179,7 @@ def _get_walks_slices(walks0, slices0, sample, ndim):
188179
walks = 20 + ndim
189180
slices = slices0 or slices
190181
walks = walks0 or walks
191-
if sample in ['hslice', 'rslice', 'slice'] and walks0 is not None:
182+
if sample in ['rslice', 'slice'] and walks0 is not None:
192183
warnings.warn('Specifying walks option while using slice sampler'
193184
' does not make sense')
194185
elif sample in ['rwalk'] and slices0 is not None:
@@ -533,10 +524,6 @@ def __new__(cls,
533524
logl_kwargs=None,
534525
ptform_args=None,
535526
ptform_kwargs=None,
536-
gradient=None,
537-
grad_args=None,
538-
grad_kwargs=None,
539-
compute_jac=False,
540527
enlarge=None,
541528
bootstrap=None,
542529
walks=None,
@@ -567,11 +554,11 @@ def __new__(cls,
567554

568555
# Sampling method.
569556
if sample == 'auto':
570-
sample = _get_auto_sample(ndim, gradient)
557+
sample = _get_auto_sample(ndim)
571558

572559
walks, slices = _get_walks_slices(walks, slices, sample, ndim)
573560

574-
if ncdim != ndim and sample in ['slice', 'hslice', 'rslice']:
561+
if ncdim != ndim and sample in ['slice', 'rslice']:
575562
raise ValueError('ncdim unsupported for slice sampling')
576563

577564
# Custom sampling function.
@@ -616,12 +603,6 @@ def __new__(cls,
616603
ptform_args = ptform_args or []
617604
ptform_kwargs = ptform_kwargs or {}
618605

619-
# gradient
620-
if grad_args is None:
621-
grad_args = []
622-
if grad_kwargs is None:
623-
grad_kwargs = {}
624-
625606
# Bounding distribution modifications.
626607
enlarge, bootstrap = get_enlarge_bootstrap(sample, enlarge, bootstrap)
627608
kwargs['enlarge'] = enlarge
@@ -669,15 +650,6 @@ def __new__(cls,
669650
or 'dynesty_logl_history.h5',
670651
pool=pool_logl)
671652

672-
# Add in gradient.
673-
if gradient is not None:
674-
grad = _function_wrapper(gradient,
675-
grad_args,
676-
grad_kwargs,
677-
name='gradient')
678-
kwargs['grad'] = grad
679-
kwargs['compute_jac'] = compute_jac
680-
681653
live_points, logvol_init, init_ncalls = _initialize_live_points(
682654
live_points,
683655
ptform,
@@ -741,10 +713,6 @@ def __init__(self,
741713
logl_kwargs=None,
742714
ptform_args=None,
743715
ptform_kwargs=None,
744-
gradient=None,
745-
grad_args=None,
746-
grad_kwargs=None,
747-
compute_jac=False,
748716
enlarge=None,
749717
bootstrap=None,
750718
walks=None,
@@ -776,11 +744,11 @@ def __init__(self,
776744

777745
# Sampling method.
778746
if sample == 'auto':
779-
sample = _get_auto_sample(ndim, gradient)
747+
sample = _get_auto_sample(ndim)
780748

781749
walks, slices = _get_walks_slices(walks, slices, sample, ndim)
782750

783-
if ncdim != ndim and sample in ['slice', 'hslice', 'rslice']:
751+
if ncdim != ndim and sample in ['slice', 'rslice']:
784752
raise ValueError('ncdim unsupported for slice sampling')
785753

786754
update_interval_ratio = _get_update_interval_ratio(
@@ -823,12 +791,6 @@ def __init__(self,
823791
ptform_args = ptform_args or []
824792
ptform_kwargs = ptform_kwargs or {}
825793

826-
# gradient
827-
if grad_args is None:
828-
grad_args = []
829-
if grad_kwargs is None:
830-
grad_kwargs = {}
831-
832794
# Bounding distribution modifications.
833795
enlarge, bootstrap = get_enlarge_bootstrap(sample, enlarge, bootstrap)
834796
kwargs['enlarge'] = enlarge
@@ -871,15 +833,6 @@ def __init__(self,
871833
save=save_history,
872834
blob=blob)
873835

874-
# Add in gradient.
875-
if gradient is not None:
876-
grad = _function_wrapper(gradient,
877-
grad_args,
878-
grad_kwargs,
879-
name='gradient')
880-
kwargs['grad'] = grad
881-
kwargs['compute_jac'] = compute_jac
882-
883836
# Initialize our nested sampler.
884837
super().__init__(loglike, ptform, ndim, bound, sample,
885838
update_interval_ratio, first_update, rstate,

0 commit comments

Comments
 (0)