@@ -9,7 +9,8 @@

 import equinox as eqx
 import optax
-from beartype import beartype as typechecker
+from beartype import beartype as typechecker
+from beartype.door import is_bearable
 import numpy as np
 from scipy.stats import qmc
 from ml_collections import ConfigDict
@@ -23,7 +24,6 @@
 from nn import fit_nn
 from pca import PCA
 from sbiax.utils import marker
-from sbiax.compression.linear import mle

 typecheck = jaxtyped(typechecker=typechecker)

@@ -45,11 +45,11 @@ class Dataset:


 def convert_dataset_to_jax(dataset: Dataset) -> Dataset:
-    def convert_to_jax_array(a):
-        if isinstance(a, np.ndarray):
-            a = jnp.asarray(a)
-        return a
-    return jax.tree.map(convert_to_jax_array, dataset)
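+    # Convert every NumPy leaf of the Dataset pytree to a JAX array;
+    # `is_leaf` marks np.ndarray nodes explicitly as leaves.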
+    return jax.tree.map(
+        lambda a: jnp.asarray(a),
+        dataset,
+        is_leaf=lambda a: isinstance(a, np.ndarray)
+    )


 @typecheck
@@ -232,7 +232,10 @@ def _maybe_reduce(
     )

     if verbose:
-        print("Processed data shapes:", [_.shape for _ in [fiducial_pdfs_z_R, latin_pdfs_z_R, derivatives]])
+        print(
+            "Processed data shapes:",
+            [_.shape for _ in [fiducial_pdfs_z_R, latin_pdfs_z_R, derivatives]]
+        )

     return fiducial_pdfs_z_R, latin_pdfs_z_R, derivatives_z_R

@@ -251,36 +254,35 @@ def remove_nuisances(dataset: Dataset) -> Dataset:
     return dataset


-@typecheck
-def calculate_derivatives(
-    derivatives_pm: Float[np.ndarray, "500 5 z R 2 d"],
-    alpha: Float[np.ndarray, "p"],
-    dparams: Float[np.ndarray, "p"],
-    parameter_strings: list[str],
-    parameter_derivative_names: list[list[str]],
-    *,
-    verbose: bool = False
-) -> Float[np.ndarray, "500 5 z R d"]:
-
-    derivatives = derivatives_pm[..., 1, :] - derivatives_pm[..., 0, :]
-
-    for p in range(alpha.size):
-        if verbose:
-            print(
-                "Parameter strings / dp / dp_name",
-                parameter_strings[p], dparams[p], parameter_derivative_names[p]
-            )
-        derivatives[:, p, ...] = derivatives[:, p, ...] / dparams[p]  # NOTE: OK before or after reducing cumulants
-
-    assert derivatives.ndim == 5, "{}".format(derivatives.shape)
-
-    return derivatives
-
-
 def get_cumulant_data(
     config: ConfigDict, *, verbose: bool = False, results_dir: Optional[str] = None
 ) -> Dataset:

+    @typecheck
+    def calculate_derivatives(
+        derivatives_pm: Float[np.ndarray, "500 5 z R 2 d"],
+        alpha: Float[np.ndarray, "p"],
+        dparams: Float[np.ndarray, "p"],
+        parameter_strings: list[str],
+        parameter_derivative_names: list[list[str]],
+        *,
+        verbose: bool = False
+    ) -> Float[np.ndarray, "500 5 z R d"]:
+
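+        # Difference of the "plus" and "minus" runs; divided by `dparams`
+        # below, this gives the finite-difference (Euler) derivative.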
+        derivatives = derivatives_pm[..., 1, :] - derivatives_pm[..., 0, :]
+
+        for p in range(alpha.size):
+            if verbose:
+                print(
+                    "Parameter strings / dp / dp_name",
+                    parameter_strings[p], dparams[p], parameter_derivative_names[p]
+                )
+            derivatives[:, p, ...] = derivatives[:, p, ...] / dparams[p]  # NOTE: OK before or after reducing cumulants
+
+        assert derivatives.ndim == 5, "{}".format(derivatives.shape)
+
+        return derivatives
+
     data_dir, *_ = get_save_and_load_dirs()

     (
@@ -313,12 +315,12 @@ def get_cumulant_data(
         fiducial_moments,
         latin_moments,
         latin_moments_parameters,
-        derivatives
+        derivatives_pm
     ) = get_raw_data(data_dir, verbose=verbose)

     # Euler derivative from plus-minus statistics (NOTE: derivatives_pm: Float[np.ndarray, "500 p z R 2 d"])
     derivatives = calculate_derivatives(
-        derivatives,
+        derivatives_pm,
         alpha,
         dparams,
         parameter_strings=parameter_strings,
@@ -371,6 +373,20 @@ def get_cumulant_data(
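+    # Fisher matrix under a Gaussian likelihood with parameter-independent
+    # covariance: F = dmu @ Cinv @ dmu.T, where dmu has shape (p, d).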
     F = np.linalg.multi_dot([dmu, Cinv, dmu.T])
     Finv = np.linalg.inv(F)

+    # dataset = Dataset(
+    #     alpha=alpha,
+    #     lower=lower,
+    #     upper=upper,
+    #     parameter_strings=parameter_strings,
+    #     Finv=Finv,
+    #     Cinv=Cinv,
+    #     C=C,
+    #     fiducial_data=fiducial_moments_z_R,
+    #     data=latin_moments_z_R,
+    #     parameters=latin_moments_parameters,
+    #     derivatives=derivatives
+    # )
+
     dataset = Dataset(
         alpha=jnp.asarray(alpha),
         lower=jnp.asarray(lower),
@@ -385,6 +401,9 @@ def get_cumulant_data(
         derivatives=jnp.asarray(derivatives)
     )

+    # dataset = convert_dataset_to_jax(dataset)
+    # assert is_bearable(dataset, Dataset)
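+    # (`is_bearable` from `beartype.door` runtime-checks a value against a
+    # type hint, here the converted pytree against `Dataset`.)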
+
     if verbose:
         corr_matrix = np.corrcoef(fiducial_moments_z_R, rowvar=False) + 1e-6  # Log colouring

@@ -446,9 +465,7 @@ def get_cumulant_data(
     return dataset


-def get_prior(config: ConfigDict) -> tfd.Distribution:
-
-    dataset: Dataset = get_data(config)
+def get_prior(config: ConfigDict, dataset: Dataset) -> tfd.Distribution:

     lower = jnp.asarray(dataset.lower)
     upper = jnp.asarray(dataset.upper)
@@ -519,7 +536,7 @@ def sample_prior(

 @typecheck
 def get_linearised_data(
-    config: ConfigDict
+    config: ConfigDict, dataset: Dataset
 ) -> Tuple[Float[Array, "n d"], Float[Array, "n p"]]:
     """
     Generate linearised PDF simulations and compute their MLEs
@@ -537,8 +554,6 @@ def get_linearised_data(

     key_parameters, key_simulations = jr.split(key)
-
-    dataset: Dataset = get_cumulant_data(config)
-
     if config.n_linear_sims is not None:
         Y = sample_prior(
             key_parameters,
@@ -640,7 +655,7 @@ def get_data(config: ConfigDict, *, verbose: bool = False, results_dir: Optional
     if hasattr(config, "linearised"):
         if config.linearised:
             print("Using linearised model, Gaussian noise.")
-            D, Y = get_linearised_data(config)
+            D, Y = get_linearised_data(config, dataset)

             dataset = replace(dataset, data=D, parameters=Y)

@@ -657,17 +672,44 @@ def get_data(config: ConfigDict, *, verbose: bool = False, results_dir: Optional
     return dataset


+@typecheck
+def mle(
+    d: Float[Array, "d"],
+    pi: Float[Array, "p"],
+    Finv: Float[Array, "p p"],
+    mu: Float[Array, "d"],
+    dmu: Float[Array, "p d"],
+    precision: Float[Array, "d d"]
+) -> Float[Array, "p"]:
684+ """
685+ Calculates a maximum likelihood estimator (MLE) from a datavector by
686+ assuming a linear model `mu` in parameters `pi` and using
687+
688+ Args:
689+ d (`Array`): The datavector to compress.
690+ p (`Array`): The estimated parameters of the datavector (e.g. a fiducial set).
691+ Finv (`Array`): The Fisher matrix. Calculated with a precision matrix (e.g. `precision`) and
692+ theory derivatives.
693+ mu (`Array`): The model evaluated at the estimated set of parameters `pi`.
694+ dmu (`Array`): The first-order theory derivatives (for the implicitly assumed linear model,
695+ these are parameter independent!)
696+ precision (`Array`): The precision matrix - defined as the inverse of the data covariance matrix.
697+
698+ Returns:
699+ `Array`: the MLE.
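+
+    Example (names illustrative):
+        `x = mle(d, pi, Finv, mu, dmu, precision)`, giving `x` of shape `(p,)`.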
700+ """
701+ return pi + jnp .linalg .multi_dot ([Finv , dmu , precision , d - mu ])
702+
703+
 @typecheck
 def get_linear_compressor(
-    config: ConfigDict
+    config: ConfigDict, dataset: Dataset
 ) -> Callable[[Float[Array, "d"], Float[Array, "p"]], Float[Array, "p"]]:
     """
     Get the Chi^2-minimising compression function, mapping a datavector
     at estimated parameters to a parameter-space summary.
     """

-    dataset: Dataset = get_data(config)
-
     mu = jnp.mean(dataset.fiducial_data, axis=0)
     dmu = jnp.mean(dataset.derivatives, axis=0)
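+    # Means over realisations: `mu` is the fiducial model vector and `dmu`
+    # its (p, d) parameter derivatives, as consumed by `mle` above.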
@@ -738,7 +780,7 @@ def get_compression_fn(key, config, dataset, *, results_dir):
         compressor = lambda d, p: net(preprocess_fn(d))  # Ignore parameter kwarg!

     if config.compression == "linear":
-        compressor = get_linear_compressor(config)
+        compressor = get_linear_compressor(config, dataset)

     # Fit PCA transform to simulated data and apply after compressing
     if config.use_pca:
@@ -763,13 +805,10 @@ def get_compression_fn(key, config, dataset, *, results_dir):

 @typecheck
 def get_datavector(
-    key: PRNGKeyArray, config: ConfigDict, n: int = 1
+    key: PRNGKeyArray, config: ConfigDict, dataset: Dataset, n: int = 1
 ) -> Float[Array, "... d"]:
     """ Measurement: either Gaussian linear model or not """

-    # NOTE: must be working with fiducial parameters!
-    dataset: Dataset = get_data(config)
-
     # Choose a linearised model datavector or simply one of the Quijote realisations
     # which corresponds to a non-linearised datavector with Gaussian noise
     if not config.use_expectation:
@@ -788,4 +827,71 @@ def get_datavector(
     if not (n > 1):
         datavector = jnp.squeeze(datavector, axis=0)

-    return datavector # Remove batch axis by default
+    return datavector  # Remove batch axis by default
+
+
+@dataclass
+class CumulantsDataset:
+    """ Dataset for Simulation-Based Inference with cumulants of the matter PDF """
+    config: ConfigDict
+    data: Dataset
+    prior: tfd.Distribution
+    compression_fn: Callable
+    results_dir: Optional[str]
+
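+    # NOTE: defining __init__ explicitly means @dataclass will not generate
+    # one; fields are filled from `config` rather than passed directly.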
+    def __init__(
+        self,
+        config: ConfigDict,
+        *,
+        verbose: bool = False,
+        results_dir: Optional[str] = None
+    ):
+        self.config = config
+        self.data = get_data(
+            config, verbose=verbose, results_dir=results_dir
+        )
+        self.prior = get_prior(config, self.data)  # Possibly not equal to Quijote prior
+        self.results_dir = results_dir
+
+        key = jr.key(config.seed)
+        self.compression_fn = get_compression_fn(
+            key, self.config, self.data, results_dir=self.results_dir
+        )
+
+    def get_parameter_strings(self):
+        return get_parameter_strings()
+
+    def sample_prior(self, key: PRNGKeyArray, n: int, *, hypercube: bool = True) -> Float[Array, "n p"]:
+        # Sample the Quijote prior, which may not be the same as the inference prior
+        P = sample_prior(
+            key,
+            n,
+            self.data.alpha,
+            self.data.lower,
+            self.data.upper,
+            hypercube=hypercube
+        )
+        return P
+
+    def get_compression_fn(self):
+        return self.compression_fn
+
+    def get_datavector(self, key: PRNGKeyArray, n: int = 1) -> Float[Array, "... d"]:
+        d = get_datavector(key, self.config, self.data, n)
+        return d
+
+    def get_linearised_datavector(self, key: PRNGKeyArray, n: int = 1) -> Float[Array, "... d"]:
+        # Sample a datavector from the linearised Gaussian model
+        mu = jnp.mean(self.data.fiducial_data, axis=0)
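+        # Draw d ~ N(mu, C): the Gaussian model at the fiducial parameters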
+        d = jr.multivariate_normal(key, mu, self.data.C, (n,))
+        if not (n > 1):
+            d = jnp.squeeze(d, axis=0)
+        return d
+
+    def get_linearised_data(self):
+        # Get linearised data (e.g. for pre-training); config only sets how many simulations
+        return get_linearised_data(self.config, self.data)
+
+    def get_preprocess_fn(self):
+        # Get (X, P) preprocessor?
+        ...
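+
+
+# Usage sketch (assumes a populated `config`; names illustrative):
+#     cumulants = CumulantsDataset(config, verbose=True)
+#     key = jr.key(0)
+#     d = cumulants.get_datavector(key)                      # datavector, shape (d,)
+#     x = cumulants.compression_fn(d, cumulants.data.alpha)  # summary, shape (p,)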