|
29 | 29 | from pisa.utils.comparisons import recursiveEquality, FTYPE_PREC, ALLCLOSE_KW
|
30 | 30 | from pisa.utils.log import logging, set_verbosity
|
31 | 31 | from pisa.utils.fileio import to_file
|
| 32 | +from pisa.utils.random_numbers import get_random_state |
32 | 33 | from pisa.utils.stats import (METRICS_TO_MAXIMIZE, METRICS_TO_MINIMIZE,
|
33 | 34 | LLH_METRICS, CHI2_METRICS, weighted_chi2,
|
34 | 35 | it_got_better, is_metric_to_maximize)
|
@@ -2681,6 +2682,140 @@ def _minimizer_callback(self, xk, **unused_kwargs): # pylint: disable=unused-arg
|
2681 | 2682 | """
|
2682 | 2683 | self._nit += 1
|
2683 | 2684 |
|
def MCMC_sampling(self, data_dist, hypo_maker, metric, nwalkers, burnin, nsteps,
                  return_burn_in=False, random_state=None, sampling_algorithm=None):
    """Performs MCMC sampling. Only supports serial (single CPU) execution at the
    moment. See issue #830.

    Parameters
    ----------

    data_dist : Sequence of MapSets or MapSet
        Data distribution to be fit. Can be an actual-, Asimov-, or pseudo-data
        distribution (where the latter two are derived from simulation and so aren't
        technically "data").

    hypo_maker : Detectors or DistributionMaker
        Creates the per-bin expectation values per map based on its param values.
        Its free params are overwritten by the sampler at every evaluation of the
        log-probability, so it is left in the state of the last evaluated point.

    metric : string or iterable of strings
        Metric by which to evaluate the fit. Must contain 'llh' or 'chi2'.
        See documentation of Map.

    nwalkers : int
        Number of walkers

    burnin : int
        Number of steps in burn in phase

    nsteps : int
        Number of steps after burn in

    return_burn_in : bool
        Also return the steps of the burn in phase. Default is False.

    random_state : None or type accepted by utils.random_numbers.get_random_state
        Random state of the walker starting points. Default is None.

    sampling_algorithm : None or emcee.moves object
        Sampling algorithm used by the emcee sampler. None means to use the default which
        is a Goodman & Weare "stretch move" with parallelization.
        See https://emcee.readthedocs.io/en/stable/user/moves/#moves-user to learn more
        about the emcee sampling algorithms.

    Returns
    -------

    scaled_chain : numpy array
        Array containing all points in the parameter space visited by each walker.
        It is sorted by steps, so all the first steps of all walkers come first.
        To for example get all values of the Nth parameter and the ith walker, use
        scaled_chain[i::nwalkers, N].

    scaled_chain_burnin : numpy array (optional)
        Same as scaled_chain, but for the burn in phase. Only returned if
        `return_burn_in` is True.

    Raises
    ------

    AssertionError
        If `metric` contains neither 'llh' nor 'chi2'.

    """
    # Validate the request before importing the optional emcee dependency, so a
    # bad metric is reported even when emcee is slow to import or unavailable.
    assert 'llh' in metric or 'chi2' in metric, 'Use either a llh or chi2 metric'
    if 'chi2' in metric:
        warnings.warn("You are using a chi2 metric for the MCMC sampling. "
                      "The sampler will assume that llh=0.5*chi2.")

    import emcee

    ndim = len(hypo_maker.params.free)

    # The sampler works on rescaled parameter values, restricted to [0, 1] in
    # every dimension.
    bounds = np.repeat([[0, 1]], ndim, axis=0)
    lower, upper = bounds[:, 0], bounds[:, 1]

    # Walker starting points are drawn uniformly inside the rescaled unit cube.
    rs = get_random_state(random_state)
    p0 = rs.random(ndim * nwalkers).reshape((nwalkers, ndim))

    # These depend only on `metric`, so compute them once instead of on every
    # log-probability evaluation.
    # NOTE(review): `metric in METRICS_TO_MAXIMIZE` is a membership test on the
    # whole `metric` object; if an iterable of strings is passed this is
    # presumably always False - confirm the intended handling of list metrics.
    sign = +1 if metric in METRICS_TO_MAXIMIZE else -1
    # llh metrics are used as-is; chi2 metrics are converted via llh = 0.5*chi2.
    scale = 1 if 'llh' in metric else 0.5

    def log_prob(scaled_param_vals):
        """Function called by the MCMC sampler. Similar to _minimizer_callable it
        returns the current metric value + prior penalties, with the sign chosen
        such that larger return values are better.

        """
        # Points outside the allowed range get zero probability.
        if np.any(scaled_param_vals > upper) or np.any(scaled_param_vals < lower):
            return -np.inf

        hypo_maker._set_rescaled_free_params(scaled_param_vals) # pylint: disable=protected-access
        hypo_asimov_dist = hypo_maker.get_outputs(return_sum=True)
        # NOTE(review): the prior penalty is not multiplied by `scale`. If
        # priors_penalty returns the penalty in chi2 units for chi2 metrics,
        # it presumably needs the same 0.5 factor - TODO confirm.
        metric_val = (
            scale * data_dist.metric_total(expected_values=hypo_asimov_dist, metric=metric)
            + hypo_maker.params.priors_penalty(metric=metric)
        )
        return sign * metric_val

    def to_physical(flat_chain):
        """Convert a flat chain of rescaled values into the params' value
        magnitudes (`param.value.m`), one row per sample.

        """
        physical = np.full_like(flat_chain, np.nan, dtype=FTYPE)
        # Scratch ParamSet used only to translate rescaled values.
        scratch_params = ParamSet(hypo_maker.params.free)
        for s, sample in enumerate(flat_chain):
            for dim, rescaled_val in enumerate(sample):
                param = scratch_params[dim]
                param._rescaled_value = rescaled_val # pylint: disable=protected-access
                physical[s, dim] = param.value.m
        return physical

    sampler = emcee.EnsembleSampler(
        nwalkers, ndim, log_prob,
        moves=sampling_algorithm,
    )

    if self.pprint:
        sys.stdout.write('Burn in')
        sys.stdout.flush()
    pos, _, _ = sampler.run_mcmc(p0, burnin, progress=self.pprint)

    if return_burn_in:
        # Must be extracted before the sampler is reset below.
        scaled_chain_burnin = to_physical(sampler.get_chain(flat=True))

    # Drop the burn-in samples; the main run starts from the final burn-in
    # walker positions.
    sampler.reset()
    if self.pprint:
        sys.stdout.write('Main sampling')
        sys.stdout.flush()
    sampler.run_mcmc(pos, nsteps, progress=self.pprint)

    scaled_chain = to_physical(sampler.get_chain(flat=True))

    if return_burn_in:
        return scaled_chain, scaled_chain_burnin
    return scaled_chain
| 2817 | + |
| 2818 | + |
2684 | 2819 | class Analysis(BasicAnalysis):
|
2685 | 2820 | """Analysis class for "canonical" IceCube/DeepCore/PINGU analyses.
|
2686 | 2821 |
|
|
0 commit comments