44
55import functools as ft
66import inspect
7+ from collections import defaultdict
78from pathlib import Path
89from typing import Callable , NamedTuple
910
1011import h5py
1112import jax
1213import numpy as np
14+ import popsummary as ps
1315from jax import jit , numpy as jnp
1416from jaxtyping import Array
1517from matplotlib import pyplot as plt
@@ -218,55 +220,6 @@ def _compute_component_marginals_single_sample(
218220 return jax .lax .map (single_sample_fn , samples_batch , batch_size = batch_size )
219221
220222
221- def read_domains (
222- filepath : str | Path ,
223- ) -> dict [str , tuple [float , float , int ]]:
224- """Read domain specifications from an HDF5 file.
225-
226- Parameters
227- ----------
228- filepath : str | Path
229- The path to the HDF5 file containing the domain specifications.
230-
231- Returns
232- -------
233- dict[str, tuple[float, float, int]]
234- A dictionary mapping parameter names to their corresponding domain specifications.
235- Each value in the dictionary is a tuple containing the start, stop, and number of
236- points for the domain of the parameter.
237- """
238- with h5py .File (filepath , "r" ) as f :
239- domains_array = f ["probs" ].attrs ["domains" ]
240- return {
241- param .decode ("utf-8" ): (float (start ), float (stop ), int (num_points ))
242- for param , start , stop , num_points in domains_array
243- }
244-
245-
246- def write_domains (f : h5py .File , domain_cfg : dict [str , tuple [float , float , int ]]):
247- """Write domain specifications to an HDF5 file.
248-
249- Parameters
250- ----------
251- f : h5py.File
252- The HDF5 file where the domain specifications will be saved.
253- domain_cfg : dict[str, tuple[float, float, int]]
254- A dictionary mapping parameter names to their corresponding domain specifications.
255- Each value in the dictionary is a tuple containing the start, stop, and number of
256- points for the domain of the parameter.
257- """
258- string_dt = h5py .string_dtype (encoding = "utf-8" )
259- f .attrs ["domains" ] = np .asarray (
260- [(str (param ), * info ) for param , info in domain_cfg .items ()],
261- dtype = np .dtype ([
262- ("param" , string_dt ),
263- ("start" , np .float32 ),
264- ("stop" , np .float32 ),
265- ("num_points" , np .uint32 ),
266- ]),
267- )
268-
269-
270223def save_results_to_hdf5 (
271224 constants : dict ,
272225 variables_index : dict [str , int ],
@@ -303,21 +256,45 @@ def save_results_to_hdf5(
303256 filepath : str | Path
304257 The path to the HDF5 file where the results will be saved.
305258 """
259+ # TODO(Qazalbash): save labels in numpyro sampler case and
260+ # use them instead of following logic
261+ inverted_variables_index = defaultdict (list )
262+ for param , idx in variables_index .items ():
263+ inverted_variables_index [idx ].append (param )
264+
265+ hyperparameters = [0 ] * len (inverted_variables_index )
266+ for idx , params in inverted_variables_index .items ():
267+ canonical_param = sorted (params )[0 ]
268+ hyperparameters [idx ] = canonical_param
269+
270+ result = ps .PopulationResult (
271+ fname = filepath ,
272+ hyperparameters = hyperparameters ,
273+ default_h5py_kwargs = {"compression" : "gzip" , "compression_opts" : 9 },
274+ )
306275 N_components = len (batched_results )
307276
308- with h5py .File (filepath , "w" ) as f :
309- write_to_hdf5 (f , dataset_path = "constants" , attrs = constants )
310- write_to_hdf5 (f , dataset_path = "variables_index" , attrs = variables_index )
277+ result .set_hyperparameter_samples (samples , overwrite = True )
311278
312- probs_group = f . create_group ( "probs" )
279+ domains = { p : np . linspace ( * info ). reshape ( 1 , - 1 ) for p , info in domain_cfg . items ()}
313280
314- write_domains (probs_group , domain_cfg )
315- write_to_hdf5 (probs_group , "samples" , samples )
281+ for i in range (N_components ):
282+ for idx , param in enumerate (parameters ):
283+ param = str (param )
284+ rate_scaled_pdf = np .array (batched_results [i ][idx ])
285+ result .set_rates_on_grids (
286+ f"component_{ i } _{ param } " ,
287+ grid_params = param ,
288+ positions = domains [param ],
289+ rates = rate_scaled_pdf ,
290+ overwrite = True ,
291+ )
316292
317- for i in range (N_components ):
318- comp_i_group = probs_group .create_group (f"component_{ i } " )
319- for idx , param in enumerate (parameters ):
320- write_to_hdf5 (comp_i_group , param , np .array (batched_results [i ][idx ]))
293+ write_to_hdf5 (
294+ filepath ,
295+ dataset_path = "/posterior/hyperparameter_samples" ,
296+ attrs = {"constants" : constants , "variables_index" : variables_index },
297+ )
321298
322299
323300def remove_comoving_volume_factor (
@@ -391,7 +368,9 @@ def generate_marginal_probs(
391368
392369 with h5py .File (input_file_path , "r" ) as f :
393370 constants = read_attrs_from_hdf5 (f , "constants" )
394- variables_index = read_attrs_from_hdf5 (f , "variables_index" )
371+ variables_index = {
372+ p : int (idx ) for p , idx in read_attrs_from_hdf5 (f , "variables_index" ).items ()
373+ }
395374 samples_arr = read_from_hdf5 (f , "samples" )
396375
397376 if max_samples is not None :
@@ -486,14 +465,17 @@ def plot_marginal_with_intervals(
486465 normalize : bool, optional
487466 Whether to normalize the marginal densities, by default False
488467 """
489- domains = read_domains (filename )
490- domain = np .linspace (* domains [parameter ])
468+ result = ps .PopulationResult (filename )
469+
470+ datasets = [f"component_{ i } _{ parameter } " for i in component_idxs ]
471+
472+ samples = result .get_hyperparameter_samples ()
491473
492- datasets = [f"/probs/component_{ i } /{ parameter } " for i in component_idxs ]
474+ cv_dict = read_attrs_from_hdf5 (filename , "/posterior/hyperparameter_samples" )
475+
476+ constants = cv_dict ["constants" ]
477+ variables_index = cv_dict ["variables_index" ]
493478
494- samples = read_from_hdf5 (filename , "probs/samples" )
495- constants = read_attrs_from_hdf5 (filename , "constants" )
496- variables_index = read_attrs_from_hdf5 (filename , "variables_index" )
497479 params = {p : samples [:, m ][:, np .newaxis ] for p , m in variables_index .items ()}
498480 params .update (constants )
499481
@@ -508,8 +490,13 @@ def plot_marginal_with_intervals(
508490 w = weights [i ]
509491 weight_values .append (w )
510492
511- with h5py .File (filename , "r" ) as f :
512- data = [np .asarray (f [dataset ][:]) for dataset in datasets ]
493+ pos_and_rates : list [tuple [np .ndarray , np .ndarray ]] = [
494+ result .get_rates_on_grids (dataset ) for dataset in datasets
495+ ]
496+ data = [rate for _ , rate in pos_and_rates ]
497+
498+ # assume all components share the same domain for the parameter of interest
499+ domain = np .squeeze (pos_and_rates [0 ][0 ], axis = 0 )
513500
514501 weighted_data = np .sum ([w * d for w , d in zip (weight_values , data )], axis = 0 )
515502
0 commit comments