|
| 1 | +from abc import ABC, abstractmethod |
1 | 2 | import numpy as np |
2 | | -from . import xfrac_file |
3 | | -from . import density_file |
4 | | -from . import conv |
5 | | -from .helper_functions import get_data_and_type |
6 | | - |
7 | 3 | import matplotlib.pyplot as plt |
8 | 4 | import matplotlib.lines as mlines |
9 | 5 | from matplotlib import colors as mcolors |
| 6 | +from scipy.linalg import inv |
| 7 | +from scipy.stats import entropy |
| 8 | +import pandas as pd |
| 9 | + |
| 10 | +# External dependencies for corner and mcmc plots |
| 11 | +try: |
| 12 | + import corner |
| 13 | +except ImportError: |
| 14 | + corner = None |
| 15 | + |
| 16 | +try: |
| 17 | + from getdist import plots, MCSamples |
| 18 | +except ImportError: |
| 19 | + plots, MCSamples = None, None |
| 20 | + |
| 21 | +from . import conv |
| 22 | +from .helper_functions import get_data_and_type |
10 | 23 |
|
11 | 24 | def plot_slice(data, los_axis = 0, slice_num = 0, logscale = False, **kwargs): |
12 | 25 | ''' |
@@ -249,24 +262,283 @@ def plot_triangle(samples_dict, weights_dict=None, |
249 | 262 | c.set_plot_config(PlotConfig(bins=bins, extents=extents, smooth=smooth)) |
250 | 263 | return c.plotter.plot(**kwargs) |
251 | 264 |
|
252 | | -if __name__ == '__main__': |
253 | | - import tools21cm as t2c |
254 | | - import pylab as pl |
255 | | - |
256 | | - t2c.set_verbose(True) |
257 | | - |
258 | | - pl.figure() |
259 | | - |
260 | | - dfilename = '/disk/sn-12/garrelt/Science/Simulations/Reionization/C2Ray_WMAP5/114Mpc_WMAP5/coarser_densities/nc256_halos_removed/6.905n_all.dat' |
261 | | - xfilename = '/disk/sn-12/garrelt/Science/Simulations/Reionization/C2Ray_WMAP5/114Mpc_WMAP5/114Mpc_f2_10S_256/results_ranger/xfrac3d_8.958.bin' |
262 | | - |
263 | | - dfile = t2c.DensityFile(dfilename) |
264 | | -# plot_slice(dfile, los_axis=1, logscale=True, cmap=pl.cm.hot) |
265 | | -# ax2 = pl.subplot(1,2,2) |
266 | | -# plot_slice(xfilename) |
267 | | - plot_slice(t2c.XfracFile(xfilename)) |
268 | | - pl.show() |
269 | | - |
270 | | - |
271 | | - |
272 | | - |
class DistributionDiagnostic(ABC):
    r"""
    Base class for diagnosing and comparing probability distributions.

    Subclasses implement ``add_distribution`` and store one entry per
    distribution in ``self.distributions``. The plotting methods read the
    keys ``'points'`` (2-D array, one column per parameter), ``'weights'``,
    ``'color'`` and ``'stats'`` (the output of ``_calculate_base_metrics``)
    from each entry.

    Attributes:
        backend (str): 'corner' or 'getdist' for multidimensional plots.
        true_values (list): Ground truth values for computing diagnostic metrics.
        param_labels (list): LaTeX labels for the parameters (e.g., [r'\Omega_m']).
        distributions (dict): Dictionary storing distribution data and stats.
    """
    _METRIC_LABELS = {
        'Z': r'$Z_p = |\mu_p - \theta_{\mathrm{truth},p}|\,/\,\sigma_p$',
        'PIT': r'$F_p(\theta_{\mathrm{truth},p})$',
        'Bias': r'$\tilde{\theta}_p - \theta_{\mathrm{truth},p}$',
        'CI68': r'$\Delta_{68,p}$',
        'Mahalanobis': r'$D_M$',
        'KL': r'$D_{KL}$ (bits)',
        'RMSE': r'RMSE',
        'Cover_68': 'Cover 68%',
        'Cover_95': 'Cover 95%',
    }
    _PER_PARAM = {'Z', 'PIT', 'Bias', 'CI68'}
    _IDEAL_VALUES = {
        'Z': 0.0, 'PIT': 0.5, 'Bias': 0.0, 'Mahalanobis': 1.0,
        'KL': 0.0, 'RMSE': 0.0
    }

    def __init__(self, backend='corner', true_values=None, param_labels=None):
        """
        Args:
            backend: Plotting backend name; stored lower-cased.
            true_values: Optional sequence of ground-truth parameter values.
            param_labels: Optional LaTeX labels (without ``$``); missing
                labels are generated on demand.
        """
        self.backend = backend.lower()
        self.true_values = true_values
        self.param_labels = param_labels  # Can be None; will be generated dynamically if needed
        self.distributions = {}

        # Priority on C0-C9 for clarity, then tab20 for density
        self.fallback_colors = [f'C{i}' for i in range(10)] + \
            [plt.get_cmap('tab20')(i) for i in range(20)]

    def _get_default_param_labels(self, num_params):
        r"""Return exactly ``num_params`` labels.

        User-supplied labels are used first (truncated if there are too
        many); any missing ones are filled with ``\theta_{i}`` (1-based).
        """
        if self.param_labels is None:
            labels = []
        else:
            labels = list(self.param_labels)[:num_params]
        labels.extend(r"\theta_{%d}" % (i + 1)
                      for i in range(len(labels), num_params))
        return labels

    def _get_distribution_label(self, label):
        """Return the provided label, or generate an unused 'Distribution N'."""
        if label is not None:
            return label
        # Start at the historical default and skip names already taken so an
        # auto-generated label never overwrites an existing entry.
        index = len(self.distributions) + 1
        while "Distribution %d" % index in self.distributions:
            index += 1
        return "Distribution %d" % index

    @abstractmethod
    def add_distribution(self, data, label=None, color=None):
        """Must be implemented by subclasses."""

    def _calculate_base_metrics(self, points, weights):
        """Common metric calculation for weighted samples.

        Args:
            points: Array of shape (n_samples, n_params).
            weights: One non-normalised weight per sample.

        Returns:
            dict with ``'means'``, ``'sigmas'``, ``'cov'`` (always 2-D),
            ``'cis'`` (per-parameter dict holding the weighted median, the
            16/84 % and 2.5/97.5 % quantiles and a ``'cdf_at'`` callable) and
            ``'entropy'`` (Gaussian differential entropy of ``cov``).
            When ``true_values`` is set, also ``'z_scores'``, ``'rmse'``,
            ``'mahalanobis'``, ``'pit'``, ``'cover_68'`` and ``'cover_95'``,
            computed over the parameters for which a true value exists.
        """
        points = np.asarray(points)
        weights = np.asarray(weights)
        weights_norm = weights / np.sum(weights)
        num_params = points.shape[1]

        means = np.average(points, axis=0, weights=weights_norm)
        # np.cov collapses to a 0-d array for a single parameter; keep it 2-D
        # so np.diag / inv / slogdet below work for any parameter count.
        cov = np.atleast_2d(np.cov(points.T, aweights=weights_norm))
        sigmas = np.sqrt(np.diag(cov))

        cis = []
        for p in range(num_params):
            data_p = points[:, p]
            idx = np.argsort(data_p)
            sorted_data = data_p[idx]
            cum_weights = np.cumsum(weights_norm[idx])

            # Defaults bind this iteration's arrays (avoids late binding).
            def quantile(q, _cw=cum_weights, _sd=sorted_data):
                return float(np.interp(q, _cw, _sd))

            def cdf_at(val, _cw=cum_weights, _sd=sorted_data):
                return float(np.interp(val, _sd, _cw))

            cis.append({
                'median': quantile(0.500),
                'lo1': quantile(0.160), 'hi1': quantile(0.840),
                'lo2': quantile(0.025), 'hi2': quantile(0.975),
                'cdf_at': cdf_at,
            })

        metrics = {'means': means, 'sigmas': sigmas, 'cov': cov, 'cis': cis}

        if self.true_values is not None:
            # Handle variable parameter counts: compare only the parameters
            # that have both a sample column and a true value.
            tv_slice = np.atleast_1d(self.true_values)[:num_params]
            n_truth = len(tv_slice)
            if n_truth:
                delta = means[:n_truth] - tv_slice
                metrics['z_scores'] = np.abs(delta) / sigmas[:n_truth]
                metrics['rmse'] = np.sqrt(np.mean(delta**2))

                try:
                    # Marginal covariance of the compared parameters.
                    sub_cov = cov[:n_truth, :n_truth]
                    metrics['mahalanobis'] = float(np.sqrt(delta @ inv(sub_cov) @ delta))
                except (np.linalg.LinAlgError, ValueError):
                    # Singular or non-finite covariance.
                    metrics['mahalanobis'] = np.nan

                metrics['pit'] = np.array(
                    [cis[p]['cdf_at'](tv_slice[p]) for p in range(n_truth)])
                metrics['cover_68'] = np.array(
                    [cis[p]['lo1'] <= tv_slice[p] <= cis[p]['hi1'] for p in range(n_truth)])
                metrics['cover_95'] = np.array(
                    [cis[p]['lo2'] <= tv_slice[p] <= cis[p]['hi2'] for p in range(n_truth)])

        # Info gain proxy (relative to unit volume): 0.5 * log det(2*pi*e*cov).
        # slogdet avoids overflow/underflow of the determinant itself.
        sign, logdet = np.linalg.slogdet(2 * np.pi * np.e * cov)
        metrics['entropy'] = 0.5 * logdet if sign >= 0 else np.nan

        return metrics

    def plot_corner(self, levels=None, **kwargs):
        """Corner/triangle plot for all added distributions.

        Args:
            levels: Contour levels as probability fractions. Default: [0.68, 0.95].
            **kwargs: Passed to the backend (corner or getdist).

        Raises:
            ValueError: If no distribution was added or the backend is unknown.
        """
        if not self.distributions:
            raise ValueError("No distributions added.")
        if levels is None:
            levels = [0.68, 0.95]

        if self.backend == 'corner':
            return self._plot_corner_backend(levels=levels, **kwargs)
        if self.backend == 'getdist':
            return self._plot_getdist_backend(levels=levels, **kwargs)
        raise ValueError(
            f"Unknown backend {self.backend!r}; expected 'corner' or 'getdist'.")

    def _plot_corner_backend(self, levels=None, **kwargs):
        """Draw all distributions with ``corner`` and return the figure."""
        if corner is None:
            raise ImportError("Please install 'corner'.")
        if levels is None:
            levels = [0.68, 0.95]
        fig = None
        handle_map = {}  # orig insertion index -> legend handle

        max_params = max(d['points'].shape[1] for d in self.distributions.values())
        full_labels = [f"${lab}$" for lab in self._get_default_param_labels(max_params)]

        # Render full-dim distributions first so the base figure exists before transplanting
        ordered = sorted(
            enumerate(self.distributions.items()),
            key=lambda x: -x[1][1]['points'].shape[1]
        )

        for orig_i, (name, dist) in ordered:
            color = dist['color'] or self.fallback_colors[orig_i % len(self.fallback_colors)]
            k = dist['points'].shape[1]

            if k == max_params:
                opts = {
                    'labels': full_labels, 'color': color, 'levels': levels,
                    'fill_contours': True, 'plot_datapoints': False, 'fig': fig
                }
                opts.update(kwargs)
                fig = corner.corner(dist['points'], weights=dist['weights'], **opts)
            else:
                # Render on a temporary figure then transplant artists into the
                # top-left k x k sub-panels of the main figure - no fake data added
                extra_kw = {kk: vv for kk, vv in kwargs.items()
                            if kk not in ('fig', 'labels', 'color', 'levels',
                                          'fill_contours', 'plot_datapoints')}
                temp_fig = corner.corner(
                    dist['points'], weights=dist['weights'],
                    labels=full_labels[:k], color=color, levels=levels,
                    fill_contours=True, plot_datapoints=False, **extra_kw
                )
                temp_axarr = np.array(temp_fig.axes).reshape(k, k)
                main_axarr = np.array(fig.axes).reshape(max_params, max_params)
                for row in range(k):
                    for col in range(k):
                        src = temp_axarr[row, col]
                        dst = main_axarr[row, col]
                        for coll in list(src.collections):
                            coll.remove()
                            dst.add_collection(coll)
                            coll.set_transform(dst.transData)
                        if row == col:
                            # Diagonal panels carry the 1-D histogram lines.
                            for ln in list(src.lines):
                                ln.remove()
                                dst.add_line(ln)
                plt.close(temp_fig)

            handle_map[orig_i] = mlines.Line2D([], [], color=color, label=name, lw=2)

        # Explicit None/length check: truthiness of a NumPy array is ambiguous.
        if self.true_values is not None and len(self.true_values) > 0:
            truths = list(self.true_values)[:max_params]
            # overplot_lines needs one entry per dimension; None skips a line.
            truths += [None] * (max_params - len(truths))
            corner.overplot_lines(fig, truths, color="gray", ls="--", alpha=0.5)

        handles = [handle_map[i] for i in sorted(handle_map)]
        fig.legend(handles=handles, loc='upper right', bbox_to_anchor=(0.95, 0.95))
        return fig

    def _plot_getdist_backend(self, levels=None, **kwargs):
        """Draw all distributions with ``getdist`` and return the plotter."""
        if plots is None:
            raise ImportError("Please install 'getdist'.")
        if levels is None:
            levels = [0.68, 0.95]

        max_params = max(d['points'].shape[1] for d in self.distributions.values())
        full_labels = self._get_default_param_labels(max_params)

        samples_list = []
        colors = []
        for i, (name, dist) in enumerate(self.distributions.items()):
            k = dist['points'].shape[1]
            p_names = [f"p{j}" for j in range(k)]
            samples_list.append(
                MCSamples(samples=dist['points'], weights=dist['weights'],
                          names=p_names, labels=full_labels[:k], label=name))
            colors.append(dist['color'] or self.fallback_colors[i % len(self.fallback_colors)])

        g = plots.get_subplot_plotter()
        g.triangle_plot(samples_list, filled=True, contour_levels=levels,
                        colors=colors, markers=self.true_values, **kwargs)
        return g

    def plot_forest(self):
        """Forest plot: one panel per parameter, one row per distribution.

        Each row shows the weighted median with a thick 16-84 % bar and a
        thin, faint 2.5-97.5 % bar. True values, when available, are drawn
        as dashed vertical lines.

        Returns:
            The matplotlib figure.

        Raises:
            ValueError: If no distribution was added.
        """
        if not self.distributions:
            raise ValueError("No distributions added.")

        num_dist = len(self.distributions)
        max_params = max(d['points'].shape[1] for d in self.distributions.values())
        labels = self._get_default_param_labels(max_params)
        # Explicit None check: truthiness of a NumPy array is ambiguous.
        has_truth = self.true_values is not None

        fig, axes = plt.subplots(1, max_params, figsize=(max_params*4, num_dist * 0.4 + 2), sharey=True)
        if max_params == 1:
            axes = [axes]

        for p in range(max_params):
            ax = axes[p]
            for i, (name, dist) in enumerate(self.distributions.items()):
                if p >= dist['points'].shape[1]:
                    continue

                ci = dist['stats']['cis'][p]
                color = dist['color'] or self.fallback_colors[i % len(self.fallback_colors)]
                ax.errorbar(ci['median'], i, xerr=[[ci['median'] - ci['lo2']], [ci['hi2'] - ci['median']]],
                            fmt='none', color=color, lw=1, alpha=0.3)
                ax.errorbar(ci['median'], i, xerr=[[ci['median'] - ci['lo1']], [ci['hi1'] - ci['median']]],
                            fmt='o', color=color, lw=3)

            if has_truth and p < len(self.true_values):
                ax.axvline(self.true_values[p], color='gray', ls='--', alpha=0.6)
            ax.set_xlabel(f'${labels[p]}$')

        # The y-axis is shared, so ticks and inversion are applied once;
        # inverting per panel would toggle the shared axis back and forth.
        axes[0].set_yticks(range(num_dist))
        axes[0].set_yticklabels(list(self.distributions.keys()))
        axes[0].invert_yaxis()
        plt.tight_layout()
        return fig
| 509 | + |
| 510 | + |
class GriddedProbabilities(DistributionDiagnostic):
    """Diagnoses distributions defined on a regular N-D probability grid.

    Every grid axis shares the same 1-D coordinate vector ``coords_1d``
    (default: 100 evenly spaced points from 0 to 1).
    """
    def __init__(self, coords_1d=None, **kwargs):
        """
        Args:
            coords_1d: Coordinates used along every grid axis.
            **kwargs: Forwarded to ``DistributionDiagnostic.__init__``.
        """
        super().__init__(**kwargs)
        self.coords_1d = coords_1d if coords_1d is not None else np.linspace(0, 1, 100)

    def add_distribution(self, grid, label=None, color=None):
        """Register a gridded distribution and compute its statistics.

        The grid is expanded into one weighted point per cell: the point is
        the cell's coordinates and the weight is the cell's grid value.

        Args:
            grid: N-D array-like of (unnormalised) probabilities; every axis
                must have ``len(coords_1d)`` entries.
            label: Name of the distribution; auto-generated when None.
            color: Plot colour; a fallback colour is used when None.

        Raises:
            ValueError: If the grid shape does not match ``coords_1d``.
        """
        grid = np.asarray(grid)
        n_coords = len(self.coords_1d)
        # Fail early with a clear message instead of a length mismatch
        # between points and weights deep inside the metric computation.
        if grid.ndim == 0 or any(size != n_coords for size in grid.shape):
            raise ValueError(
                "grid must have %d entries along every axis to match coords_1d; "
                "got shape %s." % (n_coords, grid.shape))

        label = self._get_distribution_label(label)
        mesh = np.meshgrid(*([self.coords_1d] * grid.ndim), indexing='ij')
        # 'ij' indexing + C-order flattening keeps points aligned with weights.
        points = np.vstack([m.flatten() for m in mesh]).T
        weights = grid.flatten()

        self.distributions[label] = {
            'grid': grid, 'points': points, 'weights': weights,
            'color': color,
            'stats': self._calculate_base_metrics(points, weights)
        }
| 530 | + |
| 531 | + |
class SampledDistribution(DistributionDiagnostic):
    """Diagnoses distributions represented as samples (e.g. MCMC chains, Monte Carlo draws)."""
    def add_distribution(self, samples, label=None, weights=None, color=None):
        """Register a sampled distribution and compute its statistics.

        Args:
            samples: Array-like of shape (n_samples, n_params).
            label: Name of the distribution; auto-generated when None.
            weights: Optional per-sample weights; defaults to equal weights.
            color: Plot colour; a fallback colour is used when None.

        Raises:
            ValueError: If ``samples`` is not 2-D or ``weights`` does not
                provide exactly one value per sample.
        """
        # Store real arrays: the plotting code reads ``points.shape[1]``.
        samples = np.asarray(samples)
        if samples.ndim != 2:
            raise ValueError(
                "samples must have shape (n_samples, n_params); got shape %s."
                % (samples.shape,))

        if weights is None:
            weights = np.ones(len(samples))
        else:
            weights = np.asarray(weights)
            if weights.shape != (len(samples),):
                raise ValueError(
                    "weights must provide one value per sample (%d); got shape %s."
                    % (len(samples), weights.shape))

        label = self._get_distribution_label(label)
        self.distributions[label] = {
            'points': samples, 'weights': weights,
            'color': color,
            'stats': self._calculate_base_metrics(samples, weights)
        }
| 544 | + |
0 commit comments