@@ -252,6 +252,32 @@ def remove_nuisances(dataset: Dataset) -> Dataset:
252252 return dataset
253253
254254
255+ @typecheck
256+ def calculate_derivatives (
257+ derivatives_pm : Float [np .ndarray , "500 5 z R 2 d" ],
258+ alpha : Float [np .ndarray , "p" ],
259+ dparams : Float [np .ndarray , "p" ],
260+ parameter_strings : list [str ],
261+ parameter_derivative_names : list [list [str ]],
262+ * ,
263+ verbose : bool = False
264+ ) -> Float [np .ndarray , "500 5 z R d" ]:
265+
266+ derivatives = derivatives_pm [..., 1 , :] - derivatives_pm [..., 0 , :]
267+
268+ for p in range (alpha .size ):
269+ if verbose :
270+ print (
271+ "Parameter strings / dp / dp_name" ,
272+ parameter_strings [p ], dparams [p ], parameter_derivative_names [p ]
273+ )
274+ derivatives [:, p , ...] = derivatives [:, p , ...] / dparams [p ] # NOTE: OK before or after reducing cumulants
275+
276+ assert derivatives .ndim == 5 , "{}" .format (derivatives .shape )
277+
278+ return derivatives
279+
280+
255281def get_cumulant_data (
256282 config : ConfigDict , * , verbose : bool = False , results_dir : Optional [str ] = None
257283) -> Dataset :
@@ -292,16 +318,14 @@ def get_cumulant_data(
292318 ) = get_raw_data (data_dir , verbose = verbose )
293319
294320 # Euler derivative from plus minus statistics (NOTE: derivatives: Float[np.ndarray, "500 p z R 2 d"])
295- derivatives = derivatives [..., 1 , :] - derivatives [..., 0 , :]
296- for p in range (alpha .size ):
297- if verbose :
298- print (
299- "Parameter strings / alpha / dp / dp_name" ,
300- parameter_strings [p ], alpha [p ], dparams [p ], parameter_derivative_names [p ]
301- )
302- derivatives [:, p , ...] = derivatives [:, p , ...] / dparams [p ] # NOTE: OK before or after reducing cumulants
303-
304- assert derivatives .ndim == 5 , "{}" .format (derivatives .shape )
321+ derivatives = calculate_derivatives (
322+ derivatives ,
323+ alpha ,
324+ dparams ,
325+ parameter_strings ,
326+ parameter_derivative_names ,
327+ verbose = verbose
328+ )
305329
306330 # Grab and stack by redshift and scales
307331 (
@@ -467,7 +491,7 @@ def sample_prior(
467491 lower = lower .astype (jnp .float32 )
468492 upper = upper .astype (jnp .float32 )
469493
470- assert jnp .all (upper - lower > 0. )
494+ assert jnp .all (( upper - lower ) > 0. )
471495
472496 keys_p = jr .split (key , alpha .size )
473497
@@ -648,7 +672,8 @@ def get_linear_compressor(
648672 mu = jnp .mean (dataset .fiducial_data , axis = 0 )
649673 dmu = jnp .mean (dataset .derivatives , axis = 0 )
650674
651- def compressor (d , p ):
675+ @typecheck
676+ def compressor (d : Float [Array , "d" ], p : Float [Array , "p" ]) -> Float [Array , "p" ]:
652677 mu_p = linearised_model (
653678 alpha = dataset .alpha , alpha_ = p , mu = mu , dmu = dmu
654679 )
@@ -703,7 +728,7 @@ def preprocess_fn(x):
703728 return net , preprocess_fn
704729
705730
706- def get_compression_fn (key , config , dataset , results_dir ):
731+ def get_compression_fn (key , config , dataset , * , results_dir ):
707732 # Get linear or neural network compressor
708733 if config .compression == "nn" :
709734
0 commit comments