Skip to content

Commit d9bf3da

Browse files
author
sambit-giri
committed
posterior diagnostic demo
1 parent 5261a23 commit d9bf3da

5 files changed

Lines changed: 838 additions & 29 deletions

File tree

docs/changelog.rst

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,39 @@
22
Changelog
33
=========
44

5+
v2.4
6+
----
7+
* ``DistributionDiagnostic`` class family for comparing and diagnosing probability distributions.
8+
* ``SampledDistribution``: diagnose distributions from sample arrays (MCMC chains, Monte Carlo draws, bootstrap replicates); supports importance weights.
9+
* ``GriddedProbabilities``: same interface for distributions on a regular N-D probability grid.
10+
* Corner plots (``corner`` and ``getdist`` backends) with default 68% and 95% contours; mixed-dimensionality overlays supported.
11+
* Forest plots showing 68% and 95% credible intervals across distributions.
12+
* Calibration metrics: Z-score, PIT, bias, RMSE, Mahalanobis distance, coverage.
13+
* ``fftconvolve`` now handles arrays of unequal shape.
14+
15+
v2.3
16+
----
17+
* GPU-accelerated topology: Euler characteristics via PyTorch, with Apple M-chip (MPS) support.
18+
* Radio telescope sensitivity: SEFD tables, SKA-Low Bessel primary beam, UV mapping in Lagrangian space, uniform weighting, spectral-leakage suppression.
19+
* Astrophysical data: fesc LyA constraints, Qin+2025 MAP reionization model, reionization observational constraints.
20+
* Zarr file format support.
21+
* Noise lightcone fixes (double ``jansky_2_kelvin`` call, decreasing-redshift input).
22+
* ``fftconvolve`` moved to dedicated ``fft_functions.py``.
23+
24+
v2.2
25+
----
26+
* Bispectrum and integrated bispectrum estimators.
27+
* Multiple SKA layouts (AA1, AA2, AA*, AA4) with antenna-wise gain modelling.
28+
* UV track simulation speed-up (×10).
29+
* ViteBetti topology with Cython acceleration.
30+
* py21cmfast interface for dark-matter halo retrieval.
31+
* Landy-Szalay correlation function estimator.
32+
* Migrated build system to ``pyproject.toml``; ``scipy`` version compatibility layer.
33+
534
v2.1
635
----
736
* Modules to analyse 21 cm images added.
8-
* Compatible with python 3 only
37+
* Compatible with python 3 only.
938

1039
v1.1
1140
----

notebooks/posterior_diagnostic_demo.ipynb

Lines changed: 508 additions & 0 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta"
99

1010
[tool.poetry]
1111
name = "tools21cm"
12-
version = "2.3.9"
12+
version = "2.4.0"
1313
description = "A package providing tools to analyse cosmological simulations of reionization"
1414
authors = ["Sambit Giri <sambit.giri@gmail.com>"]
1515
license = "MIT"

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636

3737
setup(
3838
name='tools21cm',
39-
version='2.3.10',
39+
version='2.4.0',
4040
author='Sambit Giri',
4141
author_email='sambit.giri@gmail.com',
4242
packages=find_packages(where="src"),

src/tools21cm/plotting.py

Lines changed: 298 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,25 @@
1+
from abc import ABC, abstractmethod
12
import numpy as np
2-
from . import xfrac_file
3-
from . import density_file
4-
from . import conv
5-
from .helper_functions import get_data_and_type
6-
73
import matplotlib.pyplot as plt
84
import matplotlib.lines as mlines
95
from matplotlib import colors as mcolors
6+
from scipy.linalg import inv
7+
from scipy.stats import entropy
8+
import pandas as pd
9+
10+
# External dependencies for corner and mcmc plots
11+
try:
12+
import corner
13+
except ImportError:
14+
corner = None
15+
16+
try:
17+
from getdist import plots, MCSamples
18+
except ImportError:
19+
plots, MCSamples = None, None
20+
21+
from . import conv
22+
from .helper_functions import get_data_and_type
1023

1124
def plot_slice(data, los_axis = 0, slice_num = 0, logscale = False, **kwargs):
1225
'''
@@ -249,24 +262,283 @@ def plot_triangle(samples_dict, weights_dict=None,
249262
c.set_plot_config(PlotConfig(bins=bins, extents=extents, smooth=smooth))
250263
return c.plotter.plot(**kwargs)
251264

252-
if __name__ == '__main__':
253-
import tools21cm as t2c
254-
import pylab as pl
255-
256-
t2c.set_verbose(True)
257-
258-
pl.figure()
259-
260-
dfilename = '/disk/sn-12/garrelt/Science/Simulations/Reionization/C2Ray_WMAP5/114Mpc_WMAP5/coarser_densities/nc256_halos_removed/6.905n_all.dat'
261-
xfilename = '/disk/sn-12/garrelt/Science/Simulations/Reionization/C2Ray_WMAP5/114Mpc_WMAP5/114Mpc_f2_10S_256/results_ranger/xfrac3d_8.958.bin'
262-
263-
dfile = t2c.DensityFile(dfilename)
264-
# plot_slice(dfile, los_axis=1, logscale=True, cmap=pl.cm.hot)
265-
# ax2 = pl.subplot(1,2,2)
266-
# plot_slice(xfilename)
267-
plot_slice(t2c.XfracFile(xfilename))
268-
pl.show()
269-
270-
271-
272-
265+
class DistributionDiagnostic(ABC):
    """
    Base class for diagnosing and comparing probability distributions.

    Subclasses implement :meth:`add_distribution` to populate
    ``self.distributions``; this base class then provides corner plots,
    forest plots and calibration metrics over every added distribution.

    Attributes:
        backend (str): 'corner' or 'getdist' for multidimensional plots.
        true_values (list): Ground truth values for computing diagnostic metrics.
        param_labels (list): LaTeX labels for the parameters (e.g., [r'\Omega_m']).
        distributions (dict): Dictionary storing distribution data and stats.
    """

    # Human-readable (LaTeX) axis labels for each diagnostic metric.
    _METRIC_LABELS = {
        'Z': r'$Z_p = |\mu_p - \theta_{\mathrm{truth},p}|\,/\,\sigma_p$',
        'PIT': r'$F_p(\theta_{\mathrm{truth},p})$',
        'Bias': r'$\tilde{\theta}_p - \theta_{\mathrm{truth},p}$',
        'CI68': r'$\Delta_{68,p}$',
        'Mahalanobis': r'$D_M$',
        'KL': r'$D_{KL}$ (bits)',
        'RMSE': r'RMSE',
        'Cover_68': 'Cover 68%',
        'Cover_95': 'Cover 95%',
    }
    # Metrics reported once per parameter (the rest are scalar summaries).
    _PER_PARAM = {'Z', 'PIT', 'Bias', 'CI68'}
    # Value each metric takes for an ideally calibrated posterior.
    _IDEAL_VALUES = {
        'Z': 0.0, 'PIT': 0.5, 'Bias': 0.0, 'Mahalanobis': 1.0,
        'KL': 0.0, 'RMSE': 0.0
    }

    def __init__(self, backend='corner', true_values=None, param_labels=None):
        """
        Args:
            backend (str): plotting backend, 'corner' or 'getdist'.
            true_values (list): optional ground-truth parameter values used
                for the calibration metrics (Z, PIT, coverage, ...).
            param_labels (list): optional LaTeX labels (without '$' signs).
        """
        self.backend = backend.lower()
        self.true_values = true_values
        self.param_labels = param_labels  # Can be None; will be generated dynamically if needed
        self.distributions = {}

        # Priority on C0-C9 for clarity, then tab20 for density
        self.fallback_colors = [f'C{i}' for i in range(10)] + \
            [plt.get_cmap('tab20')(i) for i in range(20)]

    def _get_default_param_labels(self, num_params):
        r"""Generates labels like \theta_1, \theta_2... if param_labels is None or too short."""
        if self.param_labels is None:
            return [r"\theta_{%d}" % (i+1) for i in range(num_params)]
        if len(self.param_labels) < num_params:
            # Keep the user-provided labels, pad the tail with defaults.
            extended = list(self.param_labels)
            for i in range(len(self.param_labels), num_params):
                extended.append(r"\theta_{%d}" % (i+1))
            return extended
        return self.param_labels

    def _get_distribution_label(self, label):
        """Returns provided label or generates 'Distribution N'."""
        if label is not None:
            return label
        return "Distribution %d" % (len(self.distributions) + 1)

    @abstractmethod
    def add_distribution(self, data, label=None, color=None):
        """Must be implemented by subclasses."""

    def _calculate_base_metrics(self, points, weights):
        """Common metric calculation for weighted samples.

        Args:
            points (ndarray): shape (n_samples, n_params) sample locations.
            weights (ndarray): one non-negative weight per sample.

        Returns:
            dict with means, sigmas, cov, per-parameter credible intervals
            ('cis') and — when ``true_values`` is set — calibration metrics.
        """
        weights_norm = weights / np.sum(weights)
        num_params = points.shape[1]

        means = np.average(points, axis=0, weights=weights_norm)
        cov = np.cov(points.T, aweights=weights_norm)
        sigmas = np.sqrt(np.diag(cov))

        # Weighted empirical CDF per parameter -> quantiles and PIT values.
        cis = []
        for p in range(num_params):
            data_p = points[:, p]
            idx = np.argsort(data_p)
            sorted_data = data_p[idx]
            sorted_weights = weights_norm[idx]
            cum_weights = np.cumsum(sorted_weights)

            # Defaults bind the current arrays so the closures are loop-safe.
            def quantile(q, _cw=cum_weights, _sd=sorted_data): return float(np.interp(q, _cw, _sd))
            def cdf_at(val, _cw=cum_weights, _sd=sorted_data): return float(np.interp(val, _sd, _cw))

            cis.append({
                'median': quantile(0.500),
                'lo1': quantile(0.160), 'hi1': quantile(0.840),
                'lo2': quantile(0.025), 'hi2': quantile(0.975),
                'cdf_at': cdf_at,
            })

        metrics = {'means': means, 'sigmas': sigmas, 'cov': cov, 'cis': cis}

        if self.true_values is not None:
            # Handle variable parameter counts
            tv_slice = np.asarray(self.true_values)[:num_params]
            metrics['z_scores'] = np.abs(means - tv_slice) / sigmas
            delta = means - tv_slice
            metrics['rmse'] = np.sqrt(np.mean(delta**2))

            # Fix: was a bare `except:` that also swallowed KeyboardInterrupt;
            # only a singular/ill-shaped covariance should yield NaN here.
            try:
                metrics['mahalanobis'] = float(np.sqrt(delta @ inv(cov) @ delta))
            except (np.linalg.LinAlgError, ValueError):
                metrics['mahalanobis'] = np.nan

            metrics['pit'] = np.array([cis[p]['cdf_at'](tv_slice[p]) for p in range(len(tv_slice))])
            metrics['cover_68'] = np.array([cis[p]['lo1'] <= tv_slice[p] <= cis[p]['hi1'] for p in range(len(tv_slice))])
            metrics['cover_95'] = np.array([cis[p]['lo2'] <= tv_slice[p] <= cis[p]['hi2'] for p in range(len(tv_slice))])

        # Info gain proxy (relative to unit volume)
        try:
            metrics['entropy'] = 0.5 * np.log(np.linalg.det(2 * np.pi * np.e * cov))
        except (np.linalg.LinAlgError, ValueError):
            metrics['entropy'] = np.nan

        return metrics

    def plot_corner(self, levels=None, **kwargs):
        """Corner/triangle plot for all added distributions.

        Args:
            levels: Contour levels as probability fractions. Default: [0.68, 0.95].
            **kwargs: Passed to the backend (corner or getdist).

        Raises:
            ValueError: if no distribution was added, or the backend name is
                not 'corner' or 'getdist' (previously this failed silently
                by returning None).
        """
        if not self.distributions:
            raise ValueError("No distributions added.")
        if levels is None:
            levels = [0.68, 0.95]

        if self.backend == 'corner':
            return self._plot_corner_backend(levels=levels, **kwargs)
        elif self.backend == 'getdist':
            return self._plot_getdist_backend(levels=levels, **kwargs)
        raise ValueError(f"Unknown backend '{self.backend}'; use 'corner' or 'getdist'.")

    def _plot_corner_backend(self, levels=None, **kwargs):
        """Render the corner plot with the `corner` package, supporting
        distributions of mixed dimensionality on one figure."""
        if corner is None: raise ImportError("Please install 'corner'.")
        if levels is None:
            levels = [0.68, 0.95]
        fig = None
        handle_map = {}  # orig insertion index -> legend handle

        max_params = max(p['points'].shape[1] for p in self.distributions.values())
        full_labels = [f"${l}$" for l in self._get_default_param_labels(max_params)]

        # Render full-dim distributions first so the base figure exists before transplanting
        ordered = sorted(
            enumerate(self.distributions.items()),
            key=lambda x: -x[1][1]['points'].shape[1]
        )

        for orig_i, (name, dist) in ordered:
            color = dist['color'] or self.fallback_colors[orig_i % len(self.fallback_colors)]
            k = dist['points'].shape[1]

            if k == max_params:
                opts = {
                    'labels': full_labels, 'color': color, 'levels': levels,
                    'fill_contours': True, 'plot_datapoints': False, 'fig': fig
                }
                opts.update(kwargs)
                fig = corner.corner(dist['points'], weights=dist['weights'], **opts)
            else:
                # Render on a temporary figure then transplant artists into the
                # top-left k x k sub-panels of the main figure — no fake data added
                extra_kw = {kk: vv for kk, vv in kwargs.items()
                            if kk not in ('fig', 'labels', 'color', 'levels',
                                          'fill_contours', 'plot_datapoints')}
                temp_fig = corner.corner(
                    dist['points'], weights=dist['weights'],
                    labels=full_labels[:k], color=color, levels=levels,
                    fill_contours=True, plot_datapoints=False, **extra_kw
                )
                temp_axarr = np.array(temp_fig.axes).reshape(k, k)
                main_axarr = np.array(fig.axes).reshape(max_params, max_params)
                for row in range(k):
                    for col in range(k):
                        src = temp_axarr[row, col]
                        dst = main_axarr[row, col]
                        # Move contour collections across figures.
                        for coll in list(src.collections):
                            coll.remove()
                            dst.add_collection(coll)
                            coll.set_transform(dst.transData)
                        # Diagonal panels hold 1-D marginal lines.
                        if row == col:
                            for ln in list(src.lines):
                                ln.remove()
                                dst.add_line(ln)
                plt.close(temp_fig)

            handle_map[orig_i] = mlines.Line2D([], [], color=color, label=name, lw=2)

        # Fix: `if self.true_values:` raises for numpy arrays (ambiguous truth
        # value) — test against None explicitly.
        if self.true_values is not None:
            corner.overplot_lines(fig, self.true_values[:max_params],
                                  color="gray", ls="--", alpha=0.5)

        handles = [handle_map[i] for i in sorted(handle_map)]
        fig.legend(handles=handles, loc='upper right', bbox_to_anchor=(0.95, 0.95))
        return fig

    def _plot_getdist_backend(self, levels=None, **kwargs):
        """Render the triangle plot with the `getdist` package."""
        if plots is None: raise ImportError("Please install 'getdist'.")
        if levels is None:
            levels = [0.68, 0.95]
        samples_list = []

        max_params = max(p['points'].shape[1] for p in self.distributions.values())
        full_labels = self._get_default_param_labels(max_params)

        colors = []
        for i, (name, dist) in enumerate(self.distributions.items()):
            k = dist['points'].shape[1]
            p_names = [f"p{j}" for j in range(k)]
            s = MCSamples(samples=dist['points'], weights=dist['weights'],
                          names=p_names, labels=full_labels[:k], label=name)
            samples_list.append(s)
            colors.append(dist['color'] or self.fallback_colors[i % len(self.fallback_colors)])

        g = plots.get_subplot_plotter()
        g.triangle_plot(samples_list, filled=True, contour_levels=levels,
                        colors=colors, markers=self.true_values, **kwargs)
        return g

    def plot_forest(self):
        """Forest plot: per-parameter medians with 68% (thick) and 95% (thin)
        credible intervals, one row per distribution.

        Returns:
            matplotlib.figure.Figure
        """
        num_dist = len(self.distributions)
        max_params = max(p['points'].shape[1] for p in self.distributions.values())
        labels = self._get_default_param_labels(max_params)

        fig, axes = plt.subplots(1, max_params, figsize=(max_params*4, num_dist * 0.4 + 2), sharey=True)
        if max_params == 1: axes = [axes]

        for p in range(max_params):
            ax = axes[p]
            for i, (name, dist) in enumerate(self.distributions.items()):
                # Lower-dimensional distributions lack the trailing parameters.
                if p >= dist['points'].shape[1]: continue

                ci = dist['stats']['cis'][p]
                color = dist['color'] or self.fallback_colors[i % len(self.fallback_colors)]
                # 95% interval: faint whiskers; 68% interval: bold with marker.
                ax.errorbar(ci['median'], i, xerr=[[ci['median'] - ci['lo2']], [ci['hi2'] - ci['median']]],
                            fmt='none', color=color, lw=1, alpha=0.3)
                ax.errorbar(ci['median'], i, xerr=[[ci['median'] - ci['lo1']], [ci['hi1'] - ci['median']]],
                            fmt='o', color=color, lw=3)

            # Fix: explicit None test (numpy-array truthiness is ambiguous).
            if self.true_values is not None and p < len(self.true_values):
                ax.axvline(self.true_values[p], color='gray', ls='--', alpha=0.6)
            ax.set_xlabel(f'${labels[p]}$')
            if p == 0:
                ax.set_yticks(range(num_dist))
                ax.set_yticklabels(list(self.distributions.keys()))
                ax.invert_yaxis()
        plt.tight_layout()
        return fig
511+
class GriddedProbabilities(DistributionDiagnostic):
    """Diagnose probability distributions tabulated on a regular N-D grid.

    Each grid cell becomes one weighted sample point, so the shared
    machinery of ``DistributionDiagnostic`` applies unchanged.
    """

    def __init__(self, coords_1d=None, **kwargs):
        """
        Args:
            coords_1d: 1-D coordinate array shared by every grid axis;
                defaults to 100 points uniformly spanning [0, 1].
            **kwargs: forwarded to ``DistributionDiagnostic``.
        """
        super().__init__(**kwargs)
        if coords_1d is None:
            coords_1d = np.linspace(0, 1, 100)
        self.coords_1d = coords_1d

    def add_distribution(self, grid, label=None, color=None):
        """Register a probability grid under *label* (auto-named if None)."""
        label = self._get_distribution_label(label)

        # One coordinate mesh per grid axis; 'ij' indexing keeps the
        # flattened cell order aligned with the flattened probabilities.
        meshes = np.meshgrid(*([self.coords_1d] * grid.ndim), indexing='ij')
        points = np.column_stack([axis_mesh.flatten() for axis_mesh in meshes])
        weights = grid.flatten()

        self.distributions[label] = {
            'grid': grid,
            'points': points,
            'weights': weights,
            'color': color,
            'stats': self._calculate_base_metrics(points, weights),
        }
532+
class SampledDistribution(DistributionDiagnostic):
    """Diagnoses distributions represented as samples (e.g. MCMC chains, Monte Carlo draws)."""

    def add_distribution(self, samples, label=None, weights=None, color=None):
        """Register a sample-based distribution.

        Args:
            samples: array-like of shape (n_samples, n_params); a 1-D
                array-like is treated as a single-parameter chain.
            label: name of the distribution (auto-generated if None).
            weights: optional importance weights, one per sample
                (defaults to uniform).
            color: matplotlib color for the plots (class palette if None).
        """
        label = self._get_distribution_label(label)

        # Generalization: accept plain lists and 1-D chains, not only 2-D
        # ndarrays — _calculate_base_metrics needs a (n_samples, n_params)
        # ndarray (it reads points.shape[1]).
        samples = np.asarray(samples)
        if samples.ndim == 1:
            samples = samples[:, None]

        if weights is None:
            weights = np.ones(len(samples))
        else:
            weights = np.asarray(weights)

        self.distributions[label] = {
            'points': samples, 'weights': weights,
            'color': color,
            'stats': self._calculate_base_metrics(samples, weights)
        }

0 commit comments

Comments
 (0)