Skip to content
Merged
10 changes: 9 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies = [
"RIFT>=0.0.15.11",
"arviz>=0.20.0",
"chex>=0.1.87",
"corner>=2.2.3",
"equinox>=0.11.3",
"flowMC==0.4.5",
"glasbey>=0.3.0",
Expand All @@ -39,10 +40,13 @@ dependencies = [
"loguru>=0.7.0",
"matplotlib>=3.9.0",
"mplcursors>=0.6",
"nbconvert>=7.17.1",
"numpy<3",
"numpyro>=0.19.0",
"optax<0.2.7",
"pandas>=2.2.0",
"papermill>=2.7.0",
"plotly>=6.7.0",
"pydantic>=2.12.0",
"quadax>=0.2.5",
"rich>=14.0.0",
Expand Down Expand Up @@ -84,13 +88,14 @@ gwk_ess_evolution_plot = "gwkokab_scripts.ess_evolution_plot:main"
gwk_ess_plot = "gwkokab_scripts.ess_plot:main"
gwk_flowMC_cfg_template = "gwkokab.analysis.core.inference_io._sampler:_dump_flowMC_cfg"
gwk_flowMC_info = "gwkokab_scripts.flowMC_info:main"
gwk_h5repack = "gwkokab_scripts.h5repack:main"
gwk_hist_overplot = "gwkokab_scripts.hist_overplot:main"
gwk_joint_plot = "gwkokab_scripts.joint_plot:main"
gwk_numpyro_cfg_template = "gwkokab.analysis.core.inference_io._sampler:_dump_numpyro_cfg"
gwk_param_lens = "gwkokab_scripts.param_lens:main"
gwk_pe_diagnostics = "gwkokab_scripts.pe_diagnostics:main"
gwk_ppd_plot = "gwkokab_scripts.ppd_plot:main"
gwk_r_hat_plot = "gwkokab_scripts.r_hat_plot:main"
gwk_report = "gwkokab.analysis.report.generate_report:generate_report"
gwk_scatter2d = "gwkokab_scripts.scatter2d:main"
gwk_scatter3d = "gwkokab_scripts.scatter3d:main"
gwk_trace_plot = "gwkokab_scripts.trace_plot:main"
Expand Down Expand Up @@ -168,3 +173,6 @@ find = { where = ["src/"] }

# Resolve the package version at build time from the `__version__` attribute
# of the `gwkokab.version` module instead of hard-coding it in this file.
[tool.setuptools.dynamic]
version = { attr = "gwkokab.version.__version__" }

# Include the notebook template as package data of `gwkokab.analysis.report`
# so that it is shipped in built distributions alongside the Python modules.
[tool.setuptools.package-data]
"gwkokab.analysis.report" = ["template_report.ipynb"]
236 changes: 91 additions & 145 deletions src/gwkokab/analysis/core/flowMC_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
# SPDX-License-Identifier: Apache-2.0


import os
import warnings
from typing import Any, Callable, Dict, List, Literal, Optional

import equinox as eqx
import h5py
import jax
import numpy as np
import tqdm
Expand All @@ -31,15 +30,14 @@

from gwkokab.analysis.core.guru import Guru, guru_arg_parser
from gwkokab.analysis.core.inference_io import FlowMCGlobalConfig
from gwkokab.analysis.core.utils import write_to_hdf5
from gwkokab.analysis.utils.literals import (
INFERENCE_DIRECTORY,
POSTERIOR_SAMPLES_FILENAME,
CHAIN_GROUP_FORMAT,
INFERENCE_OUTPUT_FILENAME,
SAMPLES_GROUP_NAME,
)
from gwkokab.models.utils import JointDistribution
from gwkokab.utils.exceptions import LoggedUserWarning, LoggedValueError


_INFERENCE_DIRECTORY = "flowMC_" + INFERENCE_DIRECTORY
from gwkokab.utils.exceptions import LoggedValueError


# WARNING: do not change anything in this class
Expand Down Expand Up @@ -594,8 +592,6 @@ def sample(
n_global_steps = self.strategy_order.count("model_trainer")
n_local_steps = (n_total_strategy - 6 * n_global_steps - 2) // 4

os.makedirs(_INFERENCE_DIRECTORY, exist_ok=True)

with tqdm.tqdm(range(n_global_steps), total=n_global_steps) as pbar:
pbar.set_description("Global Tuning")
for i in pbar:
Expand Down Expand Up @@ -655,174 +651,89 @@ def deserialize(self):
raise NotImplementedError


def _same_length_arrays(length: int, *arrays: np.ndarray) -> tuple[np.ndarray, ...]:
    """Pad arrays with NaN so that every returned array has ``length`` entries.

    Each input is copied into the front of a new float array of shape
    ``(length,)``; the remaining trailing entries are NaN. The inputs are not
    modified.

    Parameters
    ----------
    length : int
        The length of the returned arrays.
    arrays : np.ndarray
        The arrays to pad. Each is indexed with ``shape[0]``, so it must have
        at least one dimension, and it must not be longer than ``length``.

    Returns
    -------
    tuple[np.ndarray, ...]
        The padded arrays, in the same order as the inputs.

    Raises
    ------
    ValueError
        If an input array has more than ``length`` entries (NumPy cannot
        broadcast it into the shorter buffer).
    """
    padded_arrays = []
    for array in arrays:
        # Start from an all-NaN buffer so the padding value is explicit rather
        # than relying on NumPy casting ``None`` to NaN on assignment.
        padded_array = np.full((length,), np.nan)
        padded_array[: array.shape[0]] = array
        padded_arrays.append(padded_array)
    return tuple(padded_arrays)


def _save_acceptances(resources: dict) -> None:
    """Overwrites global and local acceptance rates in the HDF5 file.

    For each sampler type (``global`` and ``local``) and each phase (training
    and production), the buffer stored in ``resources`` is averaged along its
    first axis and written under ``acceptances/<type>/<train|prod>``. Missing
    or empty buffers are skipped.

    Parameters
    ----------
    resources : dict
        dictionary of resources; the entries read here must expose their
        values through a ``.data`` attribute
    """
    # (dataset name, resource-key suffix) for the two sampling phases.
    phases = (("train", "training"), ("prod", "production"))

    with h5py.File(INFERENCE_OUTPUT_FILENAME, "a") as f:
        for acc_type in ["global", "local"]:
            for phase, key_suffix in phases:
                key = f"{acc_type}_accs_{key_suffix}"

                # Skip phases that never ran or whose buffer is empty, so the
                # mean below is never taken over zero entries.
                if key not in resources or len(resources[key].data) == 0:
                    continue

                # NOTE(review): axis 0 is presumably the chain axis - confirm.
                mean_acceptance = np.array(resources[key].data).mean(0)
                write_to_hdf5(f, f"acceptances/{acc_type}/{phase}", mean_acceptance)

def _save_chains(resources: dict, labels: list[str], *, is_training: bool) -> None:
    """Overwrites the chains and log probabilities in the HDF5 file.

    Nothing is written when either the positions or the log-probability
    buffer of the requested phase is absent from ``resources``.

    Parameters
    ----------
    resources : dict
        dictionary of resources; the entries read here must expose their
        values through a ``.data`` attribute
    labels : list[str]
        list of parameter labels, stored as the ``labels`` attribute of the
        file
    is_training : bool
        whether the phase is training or production
    """
    phase, key_suffix = ("train", "training") if is_training else ("prod", "production")
    pos_key = f"positions_{key_suffix}"
    lp_key = f"log_prob_{key_suffix}"

    if pos_key not in resources or lp_key not in resources:
        return

    positions = np.array(resources[pos_key].data)  # Shape: (n_chains, n_steps, n_dims)
    log_probs = np.array(resources[lp_key].data)  # Shape: (n_chains, n_steps)

    n_chains = positions.shape[0]

    with h5py.File(INFERENCE_OUTPUT_FILENAME, "a") as f:
        # Store parameter labels as metadata attributes
        f.attrs["labels"] = labels

        # Write each chain under its own group: chains/<phase>/<chain>/...
        for n in range(n_chains):
            chain_prefix = (
                f"chains/{phase}/" + CHAIN_GROUP_FORMAT.format(chain_id=n) + "/"
            )
            write_to_hdf5(f, chain_prefix + "positions", positions[n])
            write_to_hdf5(f, chain_prefix + "log_probs", log_probs[n])


def _save_samples(
    resources: dict,
    labels: list[str],
    n_local_steps_per_loop: int,
    n_global_steps_per_loop: int,
) -> None:
    """Overwrites the posterior samples dataset in the HDF5 file.

    Only the production steps produced by the local sampler are kept: within
    every loop of ``n_local_steps_per_loop + n_global_steps_per_loop`` steps,
    the first ``n_local_steps_per_loop`` are selected. The selected steps of
    all chains are flattened into one ``(n_samples, n_dims)`` array, and rows
    containing ``-inf`` are dropped before writing.

    Parameters
    ----------
    resources : dict
        dictionary of resources; ``positions_production`` must expose a 3-D
        buffer through a ``.data`` attribute
    labels : list[str]
        list of parameter labels (not used by this function; kept so that
        existing callers continue to work)
    n_local_steps_per_loop : int
        number of local steps per loop
    n_global_steps_per_loop : int
        number of global steps per loop
    """
    if "positions_production" not in resources:
        return

    positions = np.array(resources["positions_production"].data)

    _, n_production_steps, n_dims = positions.shape

    # Hoisted out of the comprehension: the loop length does not depend on idx.
    steps_per_loop = n_local_steps_per_loop + n_global_steps_per_loop
    selected_indices = [
        idx
        for idx in range(n_production_steps)
        if (idx % steps_per_loop) < n_local_steps_per_loop
    ]

    local_sampler_positions = positions[:, selected_indices, :].reshape(-1, n_dims)
    # Drop any sample that has -inf in at least one dimension.
    local_sampler_positions = local_sampler_positions[
        ~np.isneginf(local_sampler_positions).any(axis=1)
    ]

    write_to_hdf5(
        INFERENCE_OUTPUT_FILENAME, SAMPLES_GROUP_NAME, local_sampler_positions
    )


def _save_loss(resources: dict) -> None:
    """Overwrites the training loss dataset in the HDF5 file.

    The loss buffer is flattened to one dimension and written to the ``loss``
    dataset. Nothing is written when ``loss_buffer`` is absent from
    ``resources``.

    Parameters
    ----------
    resources : dict
        dictionary of resources; ``loss_buffer`` must expose its values
        through a ``.data`` attribute
    """
    if "loss_buffer" not in resources:
        return

    train_loss_vals = np.array(resources["loss_buffer"].data).reshape(-1)
    write_to_hdf5(INFERENCE_OUTPUT_FILENAME, "loss", train_loss_vals)


class FlowMCBased(Guru):
output_directory: str = _INFERENCE_DIRECTORY

def driver(
self,
*,
Expand Down Expand Up @@ -883,7 +794,42 @@ def driver(
n_NFproposal_batch_size=sampler_cfg.n_NFproposal_batch_size,
verbose=sampler_cfg.verbose,
)
logger.info("Local_Global_Sampler_Bundle created successfully.")
logger.success("Local_Global_Sampler_Bundle created.")

logger.info("Saving sampler configuration to HDF5.")
write_to_hdf5(
INFERENCE_OUTPUT_FILENAME,
dataset_path="sampler_cfg",
mode="a",
attrs={
"sampler_name": "flowMC",
"n_chains": n_chains,
"n_dims": n_dims,
"n_local_steps": sampler_cfg.n_local_steps,
"n_global_steps": sampler_cfg.n_global_steps,
"n_training_loops": sampler_cfg.n_training_loops,
"n_production_loops": sampler_cfg.n_production_loops,
"n_epochs": sampler_cfg.n_epochs,
"local_sampler_name": sampler_cfg.local_sampler_name,
"step_size": sampler_cfg.step_size,
"mass_matrix": mass_matrix,
"n_leapfrog": sampler_cfg.n_leapfrog,
"chain_batch_size": sampler_cfg.chain_batch_size,
"rq_spline_hidden_units": sampler_cfg.rq_spline_hidden_units,
"rq_spline_n_bins": sampler_cfg.rq_spline_n_bins,
"rq_spline_n_layers": sampler_cfg.rq_spline_n_layers,
"rq_spline_range": sampler_cfg.rq_spline_range,
"learning_rate": sampler_cfg.learning_rate,
"batch_size": sampler_cfg.batch_size,
"n_max_examples": sampler_cfg.n_max_examples,
"history_window": sampler_cfg.history_window,
"local_thinning": sampler_cfg.local_thinning,
"global_thinning": sampler_cfg.global_thinning,
"n_NFproposal_batch_size": sampler_cfg.n_NFproposal_batch_size,
"verbose": sampler_cfg.verbose,
},
)
logger.success("Sampler configuration saved.")

sampler = Sampler(
n_dims,
Expand Down
Loading
Loading