Skip to content

Commit 95eadf3

Browse files
authored
First version of an MCMC sampler (#829)
* First version of an MCMC sampler * More explicit check for free shared params * Progress bar, example notebook, and first parallelization attempt * Return burn in and changes requested by Thomas * Set random state to None and clean up MCMC notebook * Check for metric and correct if chi2 * Add sampling algorithm argument and simplify code
1 parent e6bff6e commit 95eadf3

File tree

3 files changed

+659
-6
lines changed

3 files changed

+659
-6
lines changed

pisa/analysis/analysis.py

+135
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from pisa.utils.comparisons import recursiveEquality, FTYPE_PREC, ALLCLOSE_KW
3030
from pisa.utils.log import logging, set_verbosity
3131
from pisa.utils.fileio import to_file
32+
from pisa.utils.random_numbers import get_random_state
3233
from pisa.utils.stats import (METRICS_TO_MAXIMIZE, METRICS_TO_MINIMIZE,
3334
LLH_METRICS, CHI2_METRICS, weighted_chi2,
3435
it_got_better, is_metric_to_maximize)
@@ -2681,6 +2682,140 @@ def _minimizer_callback(self, xk, **unused_kwargs): # pylint: disable=unused-arg
26812682
"""
26822683
self._nit += 1
26832684

2685+
def MCMC_sampling(self, data_dist, hypo_maker, metric, nwalkers, burnin, nsteps,
                  return_burn_in=False, random_state=None, sampling_algorithm=None):
    """Performs MCMC sampling. Only supports serial (single CPU) execution at the
    moment. See issue #830.

    Walkers are started at uniformly random points of the unit hypercube of
    rescaled free-parameter values, run for `burnin` steps, and then (after
    the sampler is reset) run for another `nsteps` steps starting from the
    final burn-in positions.

    Parameters
    ----------

    data_dist : Sequence of MapSets or MapSet
        Data distribution to be fit. Can be an actual-, Asimov-, or pseudo-data
        distribution (where the latter two are derived from simulation and so aren't
        technically "data").

    hypo_maker : Detectors or DistributionMaker
        Creates the per-bin expectation values per map based on its param values.
        Free params in the `hypo_maker` are modified by the sampler.

    metric : string or iterable of strings
        Metric by which to evaluate the fit. See documentation of Map.
        Must contain 'llh' or 'chi2'.

    nwalkers : int
        Number of walkers

    burnin : int
        Number of steps in burn in phase

    nsteps : int
        Number of steps after burn in

    return_burn_in : bool
        Also return the steps of the burn in phase. Default is False.

    random_state : None or type accepted by utils.random_numbers.get_random_state
        Random state of the walker starting points. Default is None.

    sampling_algorithm : None or emcee.moves object
        Sampling algorithm used by the emcee sampler. None means to use the default which
        is a Goodman & Weare "stretch move" with parallelization.
        See https://emcee.readthedocs.io/en/stable/user/moves/#moves-user to learn more
        about the emcee sampling algorithms.

    Returns
    -------

    scaled_chain : numpy array
        Array containing all points in the parameter space visited by each walker.
        It is sorted by steps, so all the first steps of all walkers come first.
        To for example get all values of the Nth parameter and the ith walker, use
        scaled_chain[i::nwalkers, N].

    scaled_chain_burnin : numpy array (optional)
        Same as scaled_chain, but for the burn in phase. Only returned (as the
        second element of a tuple) if `return_burn_in` is True.

    Raises
    ------

    ValueError
        If `metric` is neither a llh nor a chi2 metric.

    """
    import emcee

    # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
    # which would leave the llh scale factor below undefined.
    is_llh = 'llh' in metric
    is_chi2 = 'chi2' in metric
    if not (is_llh or is_chi2):
        raise ValueError('Use either a llh or chi2 metric')
    if is_chi2:
        warnings.warn("You are using a chi2 metric for the MCMC sampling. "
                      "The sampler will assume that llh=0.5*chi2.")

    # Factor converting the metric into llh units, and the sign that turns a
    # metric-to-minimize into a quantity the sampler can maximize. Both depend
    # only on `metric`, so they are computed once rather than on every call.
    llh_scale = 1 if is_llh else 0.5
    sign = +1 if metric in METRICS_TO_MAXIMIZE else -1

    ndim = len(hypo_maker.params.free)
    # The sampler works on rescaled parameter values, each confined to [0, 1]
    bounds = np.repeat([[0, 1]], ndim, axis=0)
    rs = get_random_state(random_state)
    p0 = rs.random(ndim * nwalkers).reshape((nwalkers, ndim))

    def func(scaled_param_vals, bounds, data_dist, hypo_maker, metric):
        """Function called by the MCMC sampler. Similar to _minimizer_callable it
        returns the current metric value + prior penalties.

        """
        # Reject any proposal that leaves the allowed range of rescaled values
        if (np.any(scaled_param_vals > bounds[:, 1])
                or np.any(scaled_param_vals < bounds[:, 0])):
            return -np.inf

        hypo_maker._set_rescaled_free_params(scaled_param_vals) # pylint: disable=protected-access
        hypo_asimov_dist = hypo_maker.get_outputs(return_sum=True)
        # NOTE(review): only the data term is scaled by the chi2->llh factor;
        # the prior penalty is added unscaled. Confirm what units
        # `priors_penalty` returns for chi2 metrics, otherwise priors may be
        # over-weighted by a factor of 2.
        metric_val = (
            llh_scale * data_dist.metric_total(expected_values=hypo_asimov_dist, metric=metric)
            + hypo_maker.params.priors_penalty(metric=metric)
        )
        return sign*metric_val

    def chain_to_param_values(flat_samples):
        """Translate a flat chain of rescaled values into parameter magnitudes."""
        values = np.full_like(flat_samples, np.nan, dtype=FTYPE)
        # Separate ParamSet used only as a converter from rescaled values.
        # NOTE(review): presumably does not alias hypo_maker's params - confirm.
        converter = ParamSet(hypo_maker.params.free)
        for s, sample in enumerate(flat_samples):
            for dim, rescaled_val in enumerate(sample):
                param = converter[dim]
                param._rescaled_value = rescaled_val # pylint: disable=protected-access
                values[s, dim] = param.value.m
        return values

    sampler = emcee.EnsembleSampler(
        nwalkers, ndim, func,
        moves=sampling_algorithm,
        args=[bounds, data_dist, hypo_maker, metric]
    )

    if self.pprint:
        sys.stdout.write('Burn in')
        sys.stdout.flush()
    pos, prob, state = sampler.run_mcmc(p0, burnin, progress=self.pprint)

    if return_burn_in:
        # Must be read out before the sampler is reset below
        scaled_chain_burnin = chain_to_param_values(sampler.get_chain(flat=True))

    # Discard the burn-in samples; the main run continues from `pos`
    sampler.reset()
    if self.pprint:
        sys.stdout.write('Main sampling')
        sys.stdout.flush()
    sampler.run_mcmc(pos, nsteps, progress=self.pprint)

    scaled_chain = chain_to_param_values(sampler.get_chain(flat=True))

    if return_burn_in:
        return scaled_chain, scaled_chain_burnin
    return scaled_chain
2817+
2818+
26842819
class Analysis(BasicAnalysis):
26852820
"""Analysis class for "canonical" IceCube/DeepCore/PINGU analyses.
26862821

pisa/core/detectors.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -83,15 +83,19 @@ def __init__(self, pipelines, label=None, set_livetime_from_data=True, profile=F
8383
)
8484

8585
for sp in self.shared_params:
86-
n = 0
86+
N, n = 0, 0
8787
for distribution_maker in self._distribution_makers:
88+
if sp in distribution_maker.params.names:
89+
N += 1
8890
if sp in distribution_maker.params.free.names:
8991
n += 1
90-
if n < 2:
91-
raise NameError('Shared param %s only a free param in less than 2 detectors.' % sp)
92+
if N < 2:
93+
raise NameError(f'Shared param {sp} only exists in {N} detectors.')
94+
if n > 0 and n != N:
95+
raise NameError(f'Shared param {sp} exists in {N} detectors but only a free param in {n} detectors.')
9296

9397
self.init_params()
94-
98+
9599
def __repr__(self):
96100
return self.tabulate(tablefmt="presto")
97101

@@ -225,7 +229,7 @@ def shared_param_ind_list(self):
225229
spi = []
226230
for p_name in free_names:
227231
if p_name in self.shared_params:
228-
spi.append((free_names.index(p_name),self.shared_params.index(p_name)))
232+
spi.append((free_names.index(p_name), self.shared_params.index(p_name)))
229233
shared_param_ind_list.append(spi)
230234
return shared_param_ind_list
231235

@@ -347,7 +351,7 @@ def _set_rescaled_free_params(self, rvalues):
347351
for j in range(len(self._distribution_makers[i].params.free) - len(spi[i])):
348352
rp.append(rvalues.pop(0))
349353
for j in range(len(spi[i])):
350-
rp.insert(spi[i][j][0],sp[spi[i][j][1]])
354+
rp.insert(spi[i][j][0], sp[spi[i][j][1]])
351355
self._distribution_makers[i]._set_rescaled_free_params(rp)
352356

353357

pisa_examples/MCMC_example.ipynb

+514
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)