Skip to content
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies = [
"RIFT>=0.0.15.11",
"arviz>=0.20.0",
"chex>=0.1.87",
"corner>=2.2.3",
"equinox>=0.11.3",
"flowMC==0.4.5",
"glasbey>=0.3.0",
Expand All @@ -39,10 +40,13 @@ dependencies = [
"loguru>=0.7.0",
"matplotlib>=3.9.0",
"mplcursors>=0.6",
"nbconvert>=7.17.1",
Comment thread
Qazalbash marked this conversation as resolved.
"numpy<3",
"numpyro>=0.19.0",
"optax<0.2.7",
"pandas>=2.2.0",
"papermill>=2.7.0",
"plotly>=6.7.0",
"pydantic>=2.12.0",
"quadax>=0.2.5",
"rich>=14.0.0",
Expand Down Expand Up @@ -91,6 +95,7 @@ gwk_param_lens = "gwkokab_scripts.param_lens:main"
gwk_pe_diagnostics = "gwkokab_scripts.pe_diagnostics:main"
gwk_ppd_plot = "gwkokab_scripts.ppd_plot:main"
gwk_r_hat_plot = "gwkokab_scripts.r_hat_plot:main"
gwk_report = "gwkokab.analysis.report.generate_report:generate_report"
gwk_scatter2d = "gwkokab_scripts.scatter2d:main"
gwk_scatter3d = "gwkokab_scripts.scatter3d:main"
gwk_trace_plot = "gwkokab_scripts.trace_plot:main"
Expand Down Expand Up @@ -168,3 +173,6 @@ find = { where = ["src/"] }

[tool.setuptools.dynamic]
version = { attr = "gwkokab.version.__version__" }

[tool.setuptools.package-data]
"gwkokab.analysis.report" = ["template_report.ipynb"]
236 changes: 91 additions & 145 deletions src/gwkokab/analysis/core/flowMC_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
# SPDX-License-Identifier: Apache-2.0


import os
import warnings
from typing import Any, Callable, Dict, List, Literal, Optional

import equinox as eqx
import h5py
import jax
import numpy as np
import tqdm
Expand All @@ -31,15 +30,14 @@

from gwkokab.analysis.core.guru import Guru, guru_arg_parser
from gwkokab.analysis.core.inference_io import FlowMCGlobalConfig
from gwkokab.analysis.core.utils import write_to_hdf5
from gwkokab.analysis.utils.literals import (
INFERENCE_DIRECTORY,
POSTERIOR_SAMPLES_FILENAME,
CHAIN_GROUP_FORMAT,
INFERENCE_OUTPUT_FILENAME,
SAMPLES_GROUP_NAME,
)
from gwkokab.models.utils import JointDistribution
from gwkokab.utils.exceptions import LoggedUserWarning, LoggedValueError


_INFERENCE_DIRECTORY = "flowMC_" + INFERENCE_DIRECTORY
from gwkokab.utils.exceptions import LoggedValueError


# WARNING: do not change anything in this class
Expand Down Expand Up @@ -594,8 +592,6 @@ def sample(
n_global_steps = self.strategy_order.count("model_trainer")
n_local_steps = (n_total_strategy - 6 * n_global_steps - 2) // 4

os.makedirs(_INFERENCE_DIRECTORY, exist_ok=True)

with tqdm.tqdm(range(n_global_steps), total=n_global_steps) as pbar:
pbar.set_description("Global Tuning")
for i in pbar:
Expand Down Expand Up @@ -655,174 +651,89 @@ def deserialize(self):
raise NotImplementedError


def _same_length_arrays(length: int, *arrays: np.ndarray) -> tuple[np.ndarray, ...]:
"""This function pads the arrays with None to make them the same length.

Parameters
----------
length : int
The length of the arrays.
arrays : np.ndarray
The arrays to pad.

Returns
-------
tuple[np.ndarray, ...]
The padded arrays.
"""
padded_arrays = []
for array in arrays:
padded_array = np.empty((length,))
padded_array[..., : array.shape[0]] = array
padded_array[..., array.shape[0] :] = None
padded_arrays.append(padded_array)
return tuple(padded_arrays)


def _save_acceptances(resources: dict) -> None:
"""Saves global and local acceptance rates to disk."""
# Mean acceptances
for acc_type in ["global", "local"]:
train_key = f"{acc_type}_accs_training"
prod_key = f"{acc_type}_accs_production"

# Check if training data was cleared
train_data = (
np.array(resources[train_key].data).mean(0)
if train_key in resources
else np.array([])
)
prod_data = (
np.array(resources[prod_key].data).mean(0)
if prod_key in resources
else np.array([])
)
"""Overwrites global and local acceptance rates in the HDF5 file."""
with h5py.File(INFERENCE_OUTPUT_FILENAME, "a") as f:
for acc_type in ["global", "local"]:
train_key = f"{acc_type}_accs_training"
prod_key = f"{acc_type}_accs_production"

if (max_len := max(len(train_data), len(prod_data))) == 0:
warnings.warn(
f"No data found for {acc_type} acceptance rates in both phases.",
LoggedUserWarning,
)
if train_key in resources and len(resources[train_key].data) > 0:
train_data = np.array(resources[train_key].data).mean(0)
write_to_hdf5(f, f"acceptances/{acc_type}/train", train_data)

np.savetxt(
f"{_INFERENCE_DIRECTORY}/{acc_type}_accs.dat",
np.column_stack(_same_length_arrays(max_len, train_data, prod_data)),
header="train prod",
comments="#",
)
if prod_key in resources and len(resources[prod_key].data) > 0:
prod_data = np.array(resources[prod_key].data).mean(0)
write_to_hdf5(f, f"acceptances/{acc_type}/prod", prod_data)

Comment thread
Qazalbash marked this conversation as resolved.

def _save_chains(resources: Dict, labels: List[str], *, is_training: bool) -> None:
"""Saves the chains to disk.
def _save_chains(resources: dict, labels: list[str], *, is_training: bool) -> None:
"""Overwrites the chains and log probabilities in the HDF5 file."""
phase = "train" if is_training else "prod"
pos_key = f"positions_{'training' if is_training else 'production'}"
lp_key = f"log_prob_{'training' if is_training else 'production'}"

Parameters
----------
resources : Dict
dictionary of resources
labels : List[str]
list of parameter labels
is_training : bool
whether the phase is training or production
"""
header = " ".join(labels)

if is_training:
phase = "train"
pos_key = "positions_training"
lp_key = "log_prob_training"
else:
phase = "prod"
pos_key = "positions_production"
lp_key = "log_prob_production"

if pos_key not in resources:
logger.warning(f"Key {pos_key} not found in resources. Skipping save.")
if pos_key not in resources or lp_key not in resources:
return

positions = np.array(resources[pos_key].data)
log_probs = np.array(resources[lp_key].data)
positions = np.array(resources[pos_key].data) # Shape: (n_chains, n_steps, n_dims)
log_probs = np.array(resources[lp_key].data) # Shape: (n_chains, n_steps)

n_chains = positions.shape[0]
width = len(str(n_chains - 1))

for n in range(n_chains):
tag = str(n).zfill(width)
np.savetxt(
f"{_INFERENCE_DIRECTORY}/{phase}_chains_{tag}.dat",
positions[n],
header=header,
)
np.savetxt(
f"{_INFERENCE_DIRECTORY}/log_prob_{phase}_{tag}.dat",
log_probs[n],
header=phase,
)

# Store parameter labels as metadata attributes
with h5py.File(INFERENCE_OUTPUT_FILENAME, "a") as f:
f.attrs["labels"] = labels

# Overwrite each chain dataset individually with the complete updated sequence
for n in range(n_chains):
dataset_suffix = (
f"chains/{phase}/" + CHAIN_GROUP_FORMAT.format(chain_id=n) + "/"
)
write_to_hdf5(f, dataset_suffix + "positions", positions[n])
write_to_hdf5(f, dataset_suffix + "log_probs", log_probs[n])

Comment thread
Qazalbash marked this conversation as resolved.

def _save_samples(
resources: Dict,
labels: List[str],
resources: dict,
labels: list[str],
n_local_steps_per_loop: int,
n_global_steps_per_loop: int,
) -> None:
"""Saves the posterior samples from the local sampler to disk.

Parameters
----------
resources : Dict
dictionary of resources
labels : List[str]
list of parameter labels
n_local_steps_per_loop : int
number of local steps per loop
n_global_steps_per_loop : int
number of global steps per loop
"""
header = " ".join(labels)
"""Overwrites the posterior samples dataset in the HDF5 file."""
if "positions_production" not in resources:
return

positions = np.array(resources["positions_production"].data)

_, n_production_steps, n_dims = positions.shape

selected_indices = list(
filter(
lambda idx: (
(idx % (n_local_steps_per_loop + n_global_steps_per_loop))
< n_local_steps_per_loop
),
range(n_production_steps),
)
)
selected_indices = [
idx
for idx in range(n_production_steps)
if (idx % (n_local_steps_per_loop + n_global_steps_per_loop))
< n_local_steps_per_loop
]

local_sampler_positions = positions[:, selected_indices, :].reshape(-1, n_dims)
local_sampler_positions = local_sampler_positions[
~np.isneginf(local_sampler_positions).any(axis=1)
]

np.savetxt(
rf"{_INFERENCE_DIRECTORY}/{POSTERIOR_SAMPLES_FILENAME}",
local_sampler_positions,
header=header,
write_to_hdf5(
INFERENCE_OUTPUT_FILENAME, SAMPLES_GROUP_NAME, local_sampler_positions
)


def _save_loss(resources: dict) -> None:
"""Saves the training loss to disk.
"""Overwrites the training loss dataset in the HDF5 file."""
if "loss_buffer" not in resources:
return

Parameters
----------
resources : Dict
dictionary of resources
"""
train_loss_vals = np.array(resources["loss_buffer"].data)
np.savetxt(
rf"{_INFERENCE_DIRECTORY}/loss.dat", train_loss_vals.reshape(-1), header="loss"
)
train_loss_vals = np.array(resources["loss_buffer"].data).reshape(-1)
write_to_hdf5(INFERENCE_OUTPUT_FILENAME, "loss", train_loss_vals)


class FlowMCBased(Guru):
output_directory: str = _INFERENCE_DIRECTORY

def driver(
self,
*,
Expand Down Expand Up @@ -883,7 +794,42 @@ def driver(
n_NFproposal_batch_size=sampler_cfg.n_NFproposal_batch_size,
verbose=sampler_cfg.verbose,
)
logger.info("Local_Global_Sampler_Bundle created successfully.")
logger.success("Local_Global_Sampler_Bundle created.")

logger.info("Saving sampler configuration to HDF5.")
write_to_hdf5(
INFERENCE_OUTPUT_FILENAME,
dataset_path="sampler_cfg",
mode="a",
attrs={
"sampler_name": "flowMC",
"n_chains": n_chains,
"n_dims": n_dims,
"n_local_steps": sampler_cfg.n_local_steps,
"n_global_steps": sampler_cfg.n_global_steps,
"n_training_loops": sampler_cfg.n_training_loops,
"n_production_loops": sampler_cfg.n_production_loops,
"n_epochs": sampler_cfg.n_epochs,
"local_sampler_name": sampler_cfg.local_sampler_name,
"step_size": sampler_cfg.step_size,
"mass_matrix": mass_matrix,
"n_leapfrog": sampler_cfg.n_leapfrog,
"chain_batch_size": sampler_cfg.chain_batch_size,
"rq_spline_hidden_units": sampler_cfg.rq_spline_hidden_units,
"rq_spline_n_bins": sampler_cfg.rq_spline_n_bins,
"rq_spline_n_layers": sampler_cfg.rq_spline_n_layers,
"rq_spline_range": sampler_cfg.rq_spline_range,
"learning_rate": sampler_cfg.learning_rate,
"batch_size": sampler_cfg.batch_size,
"n_max_examples": sampler_cfg.n_max_examples,
"history_window": sampler_cfg.history_window,
"local_thinning": sampler_cfg.local_thinning,
"global_thinning": sampler_cfg.global_thinning,
"n_NFproposal_batch_size": sampler_cfg.n_NFproposal_batch_size,
"verbose": sampler_cfg.verbose,
},
)
logger.success("Sampler configuration saved.")

sampler = Sampler(
n_dims,
Expand Down
18 changes: 12 additions & 6 deletions src/gwkokab/analysis/core/guru.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@
from collections.abc import Callable
from typing import Any, Dict, List, Set, Tuple, Union

import h5py
import jax
from jax import lax
from jaxtyping import Array
from loguru import logger
from numpyro.distributions.distribution import Distribution

from gwkokab.analysis.core.utils import PRNGKeyMixin
from gwkokab.analysis.utils.common import read_json, write_json
from gwkokab.analysis.core.utils import PRNGKeyMixin, write_to_hdf5
from gwkokab.analysis.utils.common import read_json
from gwkokab.analysis.utils.literals import INFERENCE_OUTPUT_FILENAME
from gwkokab.analysis.utils.priors import get_processed_priors
from gwkokab.models.utils import JointDistribution, LazyJointDistribution
from gwkokab.utils.exceptions import LoggedValueError
Expand Down Expand Up @@ -198,8 +200,6 @@ class Guru(PRNGKeyMixin):
Guru classes.
"""

output_directory: str

def __init__(
self,
*,
Expand Down Expand Up @@ -275,8 +275,14 @@ def classify_model_parameters(
for value in group_variables.values(): # type: ignore
logger.debug("Recovering variable: {variable}", variable=", ".join(value))

write_json("constants.json", constants)
write_json("nf_samples_mapping.json", variables_index)
with h5py.File(INFERENCE_OUTPUT_FILENAME, "a") as f:
logger.info("Saving constants to HDF5.")
write_to_hdf5(f, dataset_path="constants", attrs=constants)
logger.success("Constants saved.")

logger.info("Saving variables index to HDF5.")
write_to_hdf5(f, dataset_path="variables_index", attrs=variables_index)
logger.success("Variables index saved.")

sorted_variables = sorted(variables.keys())
if len(lazy_order) == 0:
Expand Down
Loading
Loading