Skip to content

Commit cbf043a

Browse files
Qazalbash, IamMuhammadZeeshan, and gemini-code-assist[bot]
authored
refactor: move to popsummary format to store probs (#850)
* refactor: move to popsummary format to store probs
* using the chain name instead of index for clarity
* Update src/gwkokab/analysis/utils/marginals.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

---------

Co-authored-by: M. Zeeshan <m.zeeshan5885@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent f424617 commit cbf043a

4 files changed

Lines changed: 151 additions & 116 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,7 @@ dependencies = [
4848
"pandas>=2.2.0",
4949
"papermill>=2.7.0",
5050
"plotly>=6.7.0",
51+
"popsummary>=0.1.0",
5152
"pydantic>=2.12.0",
5253
"quadax>=0.2.5",
5354
"rich>=14.0.0",

src/gwkokab/analysis/core/utils.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -298,6 +298,12 @@ def read_attrs_from_hdf5(
298298
_value = None
299299
elif isinstance(value, bytes):
300300
_value = value.decode("utf-8")
301+
elif isinstance(value, np.integer):
302+
_value = int(value)
303+
elif isinstance(value, np.floating):
304+
_value = float(value)
305+
elif isinstance(value, np.bool_):
306+
_value = bool(value)
301307
elif isinstance(value, str):
302308
try:
303309
_value = json.loads(value)

src/gwkokab/analysis/utils/marginals.py

Lines changed: 55 additions & 68 deletions
Original file line number | Diff line number | Diff line change
@@ -4,12 +4,14 @@
44

55
import functools as ft
66
import inspect
7+
from collections import defaultdict
78
from pathlib import Path
89
from typing import Callable, NamedTuple
910

1011
import h5py
1112
import jax
1213
import numpy as np
14+
import popsummary as ps
1315
from jax import jit, numpy as jnp
1416
from jaxtyping import Array
1517
from matplotlib import pyplot as plt
@@ -218,55 +220,6 @@ def _compute_component_marginals_single_sample(
218220
return jax.lax.map(single_sample_fn, samples_batch, batch_size=batch_size)
219221

220222

221-
def read_domains(
222-
filepath: str | Path,
223-
) -> dict[str, tuple[float, float, int]]:
224-
"""Read domain specifications from an HDF5 file.
225-
226-
Parameters
227-
----------
228-
filepath : str | Path
229-
The path to the HDF5 file containing the domain specifications.
230-
231-
Returns
232-
-------
233-
dict[str, tuple[float, float, int]]
234-
A dictionary mapping parameter names to their corresponding domain specifications.
235-
Each value in the dictionary is a tuple containing the start, stop, and number of
236-
points for the domain of the parameter.
237-
"""
238-
with h5py.File(filepath, "r") as f:
239-
domains_array = f["probs"].attrs["domains"]
240-
return {
241-
param.decode("utf-8"): (float(start), float(stop), int(num_points))
242-
for param, start, stop, num_points in domains_array
243-
}
244-
245-
246-
def write_domains(f: h5py.File, domain_cfg: dict[str, tuple[float, float, int]]):
247-
"""Write domain specifications to an HDF5 file.
248-
249-
Parameters
250-
----------
251-
f : h5py.File
252-
The HDF5 file where the domain specifications will be saved.
253-
domain_cfg : dict[str, tuple[float, float, int]]
254-
A dictionary mapping parameter names to their corresponding domain specifications.
255-
Each value in the dictionary is a tuple containing the start, stop, and number of
256-
points for the domain of the parameter.
257-
"""
258-
string_dt = h5py.string_dtype(encoding="utf-8")
259-
f.attrs["domains"] = np.asarray(
260-
[(str(param), *info) for param, info in domain_cfg.items()],
261-
dtype=np.dtype([
262-
("param", string_dt),
263-
("start", np.float32),
264-
("stop", np.float32),
265-
("num_points", np.uint32),
266-
]),
267-
)
268-
269-
270223
def save_results_to_hdf5(
271224
constants: dict,
272225
variables_index: dict[str, int],
@@ -303,21 +256,45 @@ def save_results_to_hdf5(
303256
filepath : str | Path
304257
The path to the HDF5 file where the results will be saved.
305258
"""
259+
# TODO(Qazalbash): save labels in numpyro sampler case and
260+
# use them instead of following logic
261+
inverted_variables_index = defaultdict(list)
262+
for param, idx in variables_index.items():
263+
inverted_variables_index[idx].append(param)
264+
265+
hyperparameters = [0] * len(inverted_variables_index)
266+
for idx, params in inverted_variables_index.items():
267+
canonical_param = sorted(params)[0]
268+
hyperparameters[idx] = canonical_param
269+
270+
result = ps.PopulationResult(
271+
fname=filepath,
272+
hyperparameters=hyperparameters,
273+
default_h5py_kwargs={"compression": "gzip", "compression_opts": 9},
274+
)
306275
N_components = len(batched_results)
307276

308-
with h5py.File(filepath, "w") as f:
309-
write_to_hdf5(f, dataset_path="constants", attrs=constants)
310-
write_to_hdf5(f, dataset_path="variables_index", attrs=variables_index)
277+
result.set_hyperparameter_samples(samples, overwrite=True)
311278

312-
probs_group = f.create_group("probs")
279+
domains = {p: np.linspace(*info).reshape(1, -1) for p, info in domain_cfg.items()}
313280

314-
write_domains(probs_group, domain_cfg)
315-
write_to_hdf5(probs_group, "samples", samples)
281+
for i in range(N_components):
282+
for idx, param in enumerate(parameters):
283+
param = str(param)
284+
rate_scaled_pdf = np.array(batched_results[i][idx])
285+
result.set_rates_on_grids(
286+
f"component_{i}_{param}",
287+
grid_params=param,
288+
positions=domains[param],
289+
rates=rate_scaled_pdf,
290+
overwrite=True,
291+
)
316292

317-
for i in range(N_components):
318-
comp_i_group = probs_group.create_group(f"component_{i}")
319-
for idx, param in enumerate(parameters):
320-
write_to_hdf5(comp_i_group, param, np.array(batched_results[i][idx]))
293+
write_to_hdf5(
294+
filepath,
295+
dataset_path="/posterior/hyperparameter_samples",
296+
attrs={"constants": constants, "variables_index": variables_index},
297+
)
321298

322299

323300
def remove_comoving_volume_factor(
@@ -391,7 +368,9 @@ def generate_marginal_probs(
391368

392369
with h5py.File(input_file_path, "r") as f:
393370
constants = read_attrs_from_hdf5(f, "constants")
394-
variables_index = read_attrs_from_hdf5(f, "variables_index")
371+
variables_index = {
372+
p: int(idx) for p, idx in read_attrs_from_hdf5(f, "variables_index").items()
373+
}
395374
samples_arr = read_from_hdf5(f, "samples")
396375

397376
if max_samples is not None:
@@ -486,14 +465,17 @@ def plot_marginal_with_intervals(
486465
normalize : bool, optional
487466
Whether to normalize the marginal densities, by default False
488467
"""
489-
domains = read_domains(filename)
490-
domain = np.linspace(*domains[parameter])
468+
result = ps.PopulationResult(filename)
469+
470+
datasets = [f"component_{i}_{parameter}" for i in component_idxs]
471+
472+
samples = result.get_hyperparameter_samples()
491473

492-
datasets = [f"/probs/component_{i}/{parameter}" for i in component_idxs]
474+
cv_dict = read_attrs_from_hdf5(filename, "/posterior/hyperparameter_samples")
475+
476+
constants = cv_dict["constants"]
477+
variables_index = cv_dict["variables_index"]
493478

494-
samples = read_from_hdf5(filename, "probs/samples")
495-
constants = read_attrs_from_hdf5(filename, "constants")
496-
variables_index = read_attrs_from_hdf5(filename, "variables_index")
497479
params = {p: samples[:, m][:, np.newaxis] for p, m in variables_index.items()}
498480
params.update(constants)
499481

@@ -508,8 +490,13 @@ def plot_marginal_with_intervals(
508490
w = weights[i]
509491
weight_values.append(w)
510492

511-
with h5py.File(filename, "r") as f:
512-
data = [np.asarray(f[dataset][:]) for dataset in datasets]
493+
pos_and_rates: list[tuple[np.ndarray, np.ndarray]] = [
494+
result.get_rates_on_grids(dataset) for dataset in datasets
495+
]
496+
data = [rate for _, rate in pos_and_rates]
497+
498+
# assume all components share the same domain for the parameter of interest
499+
domain = np.squeeze(pos_and_rates[0][0], axis=0)
513500

514501
weighted_data = np.sum([w * d for w, d in zip(weight_values, data)], axis=0)
515502

0 commit comments

Comments
 (0)