
Commit 73a532d

cumulants dataset dataclass to stop repeated loading of data
1 parent 29d2707

5 files changed: +832 −200 lines

cumulants/constants.py

Lines changed: 2 additions & 1 deletion
@@ -56,7 +56,8 @@ def get_base_posteriors_dir():
     ["s8_m", "s8_p"]
 ]
 
-# dOm, dOb, dh, dn_s, ds8
+# Derivative: dS/dp = (S(p + dp) - S(p - dp)) / 2dp
+# > below are 2dp values for dOm, dOb, dh, dn_s, ds8
 DPARAMS = np.array(
     [
         0.3275 - 0.3075,
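Restating the new comment with the numbers below it: each `DPARAMS` entry stores the full spacing between the plus and minus Quijote cosmologies, not the half-step. The first entry evaluates to 0.3275 − 0.3075 = 0.02 = 2δΩm, so a central difference divides by it directly:

$$
\frac{\partial S}{\partial p} \approx \frac{S(p + \delta p) - S(p - \delta p)}{2\,\delta p},
\qquad \mathrm{DPARAMS}_i = 2\,\delta p_i .
$$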

cumulants/cumulants.py

Lines changed: 158 additions & 52 deletions
@@ -9,7 +9,8 @@
 
 import equinox as eqx
 import optax
-from beartype import beartype as typechecker
+from beartype import beartype as typechecker
+from beartype.door import is_bearable
 import numpy as np
 from scipy.stats import qmc
 from ml_collections import ConfigDict
@@ -23,7 +24,6 @@
 from nn import fit_nn
 from pca import PCA
 from sbiax.utils import marker
-from sbiax.compression.linear import mle
 
 typecheck = jaxtyped(typechecker=typechecker)
 
@@ -45,11 +45,11 @@ class Dataset:
 
 
 def convert_dataset_to_jax(dataset: Dataset) -> Dataset:
-    def convert_to_jax_array(a):
-        if isinstance(a, np.ndarray):
-            a = jnp.asarray(a)
-        return a
-    return jax.tree.map(convert_to_jax_array, dataset)
+    return jax.tree.map(
+        lambda a: jnp.asarray(a),
+        dataset,
+        is_leaf=lambda a: isinstance(a, np.ndarray)
+    )
 
 
 @typecheck
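The rewritten `convert_dataset_to_jax` leans on the `is_leaf` argument of `jax.tree.map`: traversal stops at any node the predicate accepts, so each `np.ndarray` in the `Dataset` pytree is handed to `jnp.asarray` whole. A minimal sketch of the same pattern on a toy pytree (not the real `Dataset`):

```python
import jax
import jax.numpy as jnp
import numpy as np

tree = {"fiducial": np.ones((4, 3)), "latin": np.zeros((2, 3))}

# is_leaf stops traversal at NumPy arrays, so each one is converted whole
jax_tree = jax.tree.map(
    lambda a: jnp.asarray(a),
    tree,
    is_leaf=lambda a: isinstance(a, np.ndarray),
)

print({k: type(v).__name__ for k, v in jax_tree.items()})  # all JAX arrays
```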
@@ -232,7 +232,10 @@ def _maybe_reduce(
     )
 
     if verbose:
-        print("Processed data shapes:", [_.shape for _ in [fiducial_pdfs_z_R, latin_pdfs_z_R, derivatives]])
+        print(
+            "Processed data shapes:",
+            [_.shape for _ in [fiducial_pdfs_z_R, latin_pdfs_z_R, derivatives]]
+        )
 
     return fiducial_pdfs_z_R, latin_pdfs_z_R, derivatives_z_R
 
@@ -251,36 +254,35 @@ def remove_nuisances(dataset: Dataset) -> Dataset:
     return dataset
 
 
-@typecheck
-def calculate_derivatives(
-    derivatives_pm: Float[np.ndarray, "500 5 z R 2 d"],
-    alpha: Float[np.ndarray, "p"],
-    dparams: Float[np.ndarray, "p"],
-    parameter_strings: list[str],
-    parameter_derivative_names: list[list[str]],
-    *,
-    verbose: bool = False
-) -> Float[np.ndarray, "500 5 z R d"]:
-
-    derivatives = derivatives_pm[..., 1, :] - derivatives_pm[..., 0, :]
-
-    for p in range(alpha.size):
-        if verbose:
-            print(
-                "Parameter strings / dp / dp_name",
-                parameter_strings[p], dparams[p], parameter_derivative_names[p]
-            )
-        derivatives[:, p, ...] = derivatives[:, p, ...] / dparams[p] # NOTE: OK before or after reducing cumulants
-
-    assert derivatives.ndim == 5, "{}".format(derivatives.shape)
-
-    return derivatives
-
-
 def get_cumulant_data(
     config: ConfigDict, *, verbose: bool = False, results_dir: Optional[str] = None
 ) -> Dataset:
 
+    @typecheck
+    def calculate_derivatives(
+        derivatives_pm: Float[np.ndarray, "500 5 z R 2 d"],
+        alpha: Float[np.ndarray, "p"],
+        dparams: Float[np.ndarray, "p"],
+        parameter_strings: list[str],
+        parameter_derivative_names: list[list[str]],
+        *,
+        verbose: bool = False
+    ) -> Float[np.ndarray, "500 5 z R d"]:
+
+        derivatives = derivatives_pm[..., 1, :] - derivatives_pm[..., 0, :]
+
+        for p in range(alpha.size):
+            if verbose:
+                print(
+                    "Parameter strings / dp / dp_name",
+                    parameter_strings[p], dparams[p], parameter_derivative_names[p]
+                )
+            derivatives[:, p, ...] = derivatives[:, p, ...] / dparams[p] # NOTE: OK before or after reducing cumulants
+
+        assert derivatives.ndim == 5, "{}".format(derivatives.shape)
+
+        return derivatives
+
     data_dir, *_ = get_save_and_load_dirs()
 
     (
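For reference, the moved function's central difference acts on the second-to-last axis of `derivatives_pm`, which holds the (minus, plus) pair; index 1 minus index 0 forms the numerator, and each parameter slice is divided by its `2dp` from `DPARAMS`. A toy check of the indexing, with shapes shrunk and values hypothetical:

```python
import numpy as np

# Toy derivatives_pm with axes (sims, params, z, R, pm, d); pm holds (minus, plus)
derivatives_pm = np.zeros((2, 1, 1, 1, 2, 3))
derivatives_pm[..., 0, :] = 0.98  # S(p - dp)
derivatives_pm[..., 1, :] = 1.02  # S(p + dp)

dparams = np.array([0.02])  # DPARAMS-style entry: 2 * dp

derivatives = derivatives_pm[..., 1, :] - derivatives_pm[..., 0, :]
for p in range(dparams.size):
    derivatives[:, p, ...] /= dparams[p]

assert derivatives.shape == (2, 1, 1, 1, 3)        # pm axis consumed
assert np.allclose(derivatives, (1.02 - 0.98) / 0.02)
```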
@@ -313,12 +315,12 @@ def get_cumulant_data(
         fiducial_moments,
         latin_moments,
         latin_moments_parameters,
-        derivatives
+        derivatives_pm
     ) = get_raw_data(data_dir, verbose=verbose)
 
     # Euler derivative from plus minus statistics (NOTE: derivatives: Float[np.ndarray, "500 p z R 2 d"])
     derivatives = calculate_derivatives(
-        derivatives,
+        derivatives_pm,
         alpha,
         dparams,
         parameter_strings=parameter_strings,
371373
F = np.linalg.multi_dot([dmu, Cinv, dmu.T])
372374
Finv = np.linalg.inv(F)
373375

376+
# dataset = Dataset(
377+
# alpha=alpha,
378+
# lower=lower,
379+
# upper=upper,
380+
# parameter_strings=parameter_strings,
381+
# Finv=Finv,
382+
# Cinv=Cinv,
383+
# C=C,
384+
# fiducial_data=fiducial_moments_z_R,
385+
# data=latin_moments_z_R,
386+
# parameters=latin_moments_parameters,
387+
# derivatives=derivatives
388+
# )
389+
374390
dataset = Dataset(
375391
alpha=jnp.asarray(alpha),
376392
lower=jnp.asarray(lower),
@@ -385,6 +401,9 @@ def get_cumulant_data(
         derivatives=jnp.asarray(derivatives)
     )
 
+    # dataset = convert_dataset_to_jax(dataset)
+    # assert is_bearable(dataset, Dataset)
+
     if verbose:
         corr_matrix = np.corrcoef(fiducial_moments_z_R, rowvar=False) + 1e-6 # Log colouring
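The commented-out alternative would build the `Dataset` from NumPy arrays, convert it in one call, then verify it against the `Dataset` type at runtime via beartype's `is_bearable` (newly imported above). A small illustration of `is_bearable` on jaxtyping hints, independent of the real `Dataset`:

```python
import jax.numpy as jnp
from beartype.door import is_bearable
from jaxtyping import Array, Float

x = jnp.ones((3, 2))

# Runtime check of a value against a type hint, returning a bool
assert is_bearable(x, Float[Array, "n p"])
assert not is_bearable(x, Float[Array, "n p q"])  # wrong rank
```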

@@ -446,9 +465,7 @@ def get_cumulant_data(
     return dataset
 
 
-def get_prior(config: ConfigDict) -> tfd.Distribution:
-
-    dataset: Dataset = get_data(config)
+def get_prior(config: ConfigDict, dataset: Dataset) -> tfd.Distribution:
 
     lower = jnp.asarray(dataset.lower)
     upper = jnp.asarray(dataset.upper)
@@ -519,7 +536,7 @@ def sample_prior(
 
 @typecheck
 def get_linearised_data(
-    config: ConfigDict
+    config: ConfigDict, dataset: Dataset
 ) -> Tuple[Float[Array, "n d"], Float[Array, "n p"]]:
     """
     Get linearised PDFs and get their MLEs
@@ -537,8 +554,6 @@ def get_linearised_data(
 
     key_parameters, key_simulations = jr.split(key)
 
-    dataset: Dataset = get_cumulant_data(config)
-
     if config.n_linear_sims is not None:
         Y = sample_prior(
             key_parameters,
@@ -640,7 +655,7 @@ def get_data(config: ConfigDict, *, verbose: bool = False, results_dir: Optional
     if hasattr(config, "linearised"):
         if config.linearised:
             print("Using linearised model, Gaussian noise.")
-            D, Y = get_linearised_data(config)
+            D, Y = get_linearised_data(config, dataset)
 
             dataset = replace(dataset, data=D, parameters=Y)
 
@@ -657,17 +672,44 @@ def get_data(config: ConfigDict, *, verbose: bool = False, results_dir: Optional
     return dataset
 
 
+@typecheck
+def mle(
+    d: Float[Array, "d"],
+    pi: Float[Array, "p"],
+    Finv: Float[Array, "p p"],
+    mu: Float[Array, "d"],
+    dmu: Float[Array, "p d"],
+    precision: Float[Array, "d d"]
+) -> Float[Array, "p"]:
+    """
+    Calculates a maximum likelihood estimator (MLE) from a datavector by
+    assuming a linear model `mu` in parameters `pi` and using the Fisher
+    matrix and precision matrix to weight the residual `d - mu`.
+
+    Args:
+        d (`Array`): The datavector to compress.
+        pi (`Array`): The estimated parameters of the datavector (e.g. a fiducial set).
+        Finv (`Array`): The inverse Fisher matrix. Calculated with a precision matrix (e.g. `precision`) and
+            theory derivatives.
+        mu (`Array`): The model evaluated at the estimated set of parameters `pi`.
+        dmu (`Array`): The first-order theory derivatives (for the implicitly assumed linear model,
+            these are parameter independent!)
+        precision (`Array`): The precision matrix - defined as the inverse of the data covariance matrix.
+
+    Returns:
+        `Array`: the MLE.
+    """
+    return pi + jnp.linalg.multi_dot([Finv, dmu, precision, d - mu])
+
+
 @typecheck
 def get_linear_compressor(
-    config: ConfigDict
+    config: ConfigDict, dataset: Dataset
 ) -> Callable[[Float[Array, "d"], Float[Array, "p"]], Float[Array, "p"]]:
     """
     Get Chi^2 minimisation function; compressing datavector
     at estimated parameters to summary
     """
 
-    dataset: Dataset = get_data(config)
-
     mu = jnp.mean(dataset.fiducial_data, axis=0)
     dmu = jnp.mean(dataset.derivatives, axis=0)
 
@@ -738,7 +780,7 @@ def get_compression_fn(key, config, dataset, *, results_dir):
         compressor = lambda d, p: net(preprocess_fn(d)) # Ignore parameter kwarg!
 
     if config.compression == "linear":
-        compressor = get_linear_compressor(config)
+        compressor = get_linear_compressor(config, dataset)
 
     # Fit PCA transform to simulated data and apply after compressing
     if config.use_pca:
@@ -763,13 +805,10 @@ def get_compression_fn(key, config, dataset, *, results_dir):
 
 @typecheck
 def get_datavector(
-    key: PRNGKeyArray, config: ConfigDict, n: int = 1
+    key: PRNGKeyArray, config: ConfigDict, dataset: Dataset, n: int = 1
 ) -> Float[Array, "... d"]:
     """ Measurement: either Gaussian linear model or not """
 
-    # NOTE: must be working with fiducial parameters!
-    dataset: Dataset = get_data(config)
-
     # Choose a linearised model datavector or simply one of the Quijote realisations
     # which corresponds to a non-linearised datavector with Gaussian noise
     if not config.use_expectation:
@@ -788,4 +827,71 @@ def get_datavector(
     if not (n > 1):
         datavector = jnp.squeeze(datavector, axis=0)
 
-    return datavector # Remove batch axis by default
+    return datavector # Remove batch axis by default
+
+
+@dataclass
+class CumulantsDataset:
+    """ Dataset for Simulation-Based Inference with cumulants of the matter PDF """
+    config: ConfigDict
+    data: Dataset
+    prior: tfd.Distribution
+    compression_fn: Callable
+    results_dir: str
+
+    def __init__(
+        self,
+        config: ConfigDict,
+        *,
+        verbose: bool = False,
+        results_dir: Optional[str] = None
+    ):
+        self.config = config
+        self.data = get_data(
+            config, verbose=verbose, results_dir=results_dir
+        )
+        self.prior = get_prior(config, self.data) # Possibly not equal to Quijote prior
+        self.results_dir = results_dir
+
+        key = jr.key(config.seed)
+        self.compression_fn = get_compression_fn(
+            key, self.config, self.data, results_dir=self.results_dir
+        )
+
+    def get_parameter_strings(self):
+        return get_parameter_strings()
+
+    def sample_prior(self, key: PRNGKeyArray, n: int, *, hypercube: bool = True) -> Float[Array, "n p"]:
+        # Sample Quijote prior which may not be the same as inference prior
+        P = sample_prior(
+            key,
+            n,
+            self.data.alpha,
+            self.data.lower,
+            self.data.upper,
+            hypercube=hypercube
+        )
+        return P
+
+    def get_compression_fn(self):
+        return self.compression_fn
+
+    def get_datavector(self, key: PRNGKeyArray, n: int = 1) -> Float[Array, "... d"]:
+        d = get_datavector(key, self.config, self.data, n)
+        return d
+
+    def get_linearised_datavector(self, key: PRNGKeyArray, n: int = 1) -> Float[Array, "... d"]:
+        # Sample datavector from linearised Gaussian model
+        mu = jnp.mean(self.data.fiducial_data, axis=0)
+        d = jr.multivariate_normal(key, mu, self.data.C, (n,))
+        if not (n > 1):
+            d = jnp.squeeze(d, axis=0)
+        return d
+
+    def get_linearised_data(self):
+        # Get linearised data (e.g. pre-training), where config only sets how many simulations
+        return get_linearised_data(self.config, self.data)
+
+    def get_preprocess_fn(self):
+        # Get (X, P) preprocessor?
+        ...
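This class is the point of the commit: `get_data`, `get_prior`, and `get_compression_fn` each run once in `__init__`, and everything downstream reuses the cached `Dataset` instead of reloading it per call (the old behaviour of `get_prior`, `get_linearised_data`, `get_linear_compressor`, and `get_datavector`). A hedged usage sketch; `make_config()` is a hypothetical stand-in for however the `ConfigDict` (with at least a `seed` field) is built:

```python
import jax.random as jr

config = make_config()  # hypothetical ConfigDict factory; needs config.seed etc.

# One-time load: data, prior and compression function are all built here
cumulants = CumulantsDataset(config, verbose=True, results_dir="results/")

key = jr.key(0)
key_obs, key_prior = jr.split(key)

d = cumulants.get_datavector(key_obs)                  # no re-loading inside
x = cumulants.compression_fn(d, cumulants.data.alpha)  # compress to summaries
P = cumulants.sample_prior(key_prior, n=1000)          # Quijote-style prior draws
```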
