Commit fee3ccf

scaler for linearised, scaler for nonlinear, replace function for scalers
1 parent fb7a2d0 commit fee3ccf

File tree

7 files changed: +596 additions, −222 deletions


cumulants/configs.py

Lines changed: 4 additions & 4 deletions

@@ -334,7 +334,7 @@ def get_ndes_from_config(
             event_dim=event_dim,
             context_dim=context_dim if (context_dim is not None) else event_dim,
             key=key,
-            scaler=scaler if use_scalers else None,
+            scaler=scaler if (nde.use_scaling and use_scalers) else None,
             **dict(nde)
         )
         nde_dict.pop("model_type")

@@ -458,7 +458,7 @@ def cumulants_config(
     pretrain.start_step = 0
     pretrain.n_epochs = 10_000
     pretrain.n_batch = 100
-    pretrain.patience = 10
+    pretrain.patience = 100
     pretrain.lr = 1e-3
     pretrain.opt = "adam"
     pretrain.opt_kwargs = {}

@@ -467,8 +467,8 @@ def cumulants_config(
     train.start_step = 0
     train.n_epochs = 10_000
     train.n_batch = 100
-    train.patience = 10
-    train.lr = 1e-3
+    train.patience = 100
+    train.lr = 1e-4
     train.opt = "adam"
     train.opt_kwargs = {}
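
Note: the gating above only passes a scaler to an NDE when both the per-model `nde.use_scaling` flag and the global `use_scalers` switch are set. For context, a generic standardising scaler of the kind such an argument usually carries might look like the sketch below; the repository's actual scaler class is not shown in this diff, so the class name and methods are assumptions.

import jax.numpy as jnp

class StandardScaler:
    """Illustrative only: standardise features to zero mean and unit variance."""

    def __init__(self, x):
        # Fit per-feature statistics from a reference set of simulations
        self.mu = jnp.mean(x, axis=0)
        self.std = jnp.std(x, axis=0)

    def forward(self, x):
        return (x - self.mu) / self.std

    def inverse(self, x):
        return x * self.std + self.mu

x = jnp.array([[1., 2.], [3., 4.], [5., 6.]])
scaler = StandardScaler(x)
print(scaler.forward(x))  # columns now have zero mean and unit variance

An NDE configured with nde.use_scaling=False (or with use_scalers=False globally) receives scaler=None and trains on the raw summaries instead.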

cumulants/cumulants.py

Lines changed: 38 additions & 13 deletions

@@ -252,6 +252,32 @@ def remove_nuisances(dataset: Dataset) -> Dataset:
     return dataset


+@typecheck
+def calculate_derivatives(
+    derivatives_pm: Float[np.ndarray, "500 5 z R 2 d"],
+    alpha: Float[np.ndarray, "p"],
+    dparams: Float[np.ndarray, "p"],
+    parameter_strings: list[str],
+    parameter_derivative_names: list[list[str]],
+    *,
+    verbose: bool = False
+) -> Float[np.ndarray, "500 5 z R d"]:
+
+    derivatives = derivatives_pm[..., 1, :] - derivatives_pm[..., 0, :]
+
+    for p in range(alpha.size):
+        if verbose:
+            print(
+                "Parameter strings / dp / dp_name",
+                parameter_strings[p], dparams[p], parameter_derivative_names[p]
+            )
+        derivatives[:, p, ...] = derivatives[:, p, ...] / dparams[p]  # NOTE: OK before or after reducing cumulants
+
+    assert derivatives.ndim == 5, "{}".format(derivatives.shape)
+
+    return derivatives
+
+
 def get_cumulant_data(
     config: ConfigDict, *, verbose: bool = False, results_dir: Optional[str] = None
 ) -> Dataset:

@@ -292,16 +318,14 @@ def get_cumulant_data(
     ) = get_raw_data(data_dir, verbose=verbose)

     # Euler derivative from plus minus statistics (NOTE: derivatives: Float[np.ndarray, "500 p z R 2 d"])
-    derivatives = derivatives[..., 1, :] - derivatives[..., 0, :]
-    for p in range(alpha.size):
-        if verbose:
-            print(
-                "Parameter strings / alpha / dp / dp_name",
-                parameter_strings[p], alpha[p], dparams[p], parameter_derivative_names[p]
-            )
-        derivatives[:, p, ...] = derivatives[:, p, ...] / dparams[p]  # NOTE: OK before or after reducing cumulants
-
-    assert derivatives.ndim == 5, "{}".format(derivatives.shape)
+    derivatives = calculate_derivatives(
+        derivatives,
+        alpha,
+        dparams,
+        parameter_strings,
+        parameter_derivative_names,
+        verbose=verbose
+    )

     # Grab and stack by redshift and scales
     (
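
The two hunks above move the Euler (plus/minus finite-difference) derivative calculation into the new `calculate_derivatives` helper and call it from `get_cumulant_data`. A rough usage sketch, assuming the helper is importable from `cumulants.cumulants` and that `typecheck` enforces the jaxtyping shape annotations; the array contents, the 0.05 step size, and the z/R/d sizes below are illustrative only.

import numpy as np

from cumulants.cumulants import calculate_derivatives  # assumed import path

n_sims, n_params, n_z, n_R, n_d = 500, 5, 2, 3, 7  # 500 and 5 are fixed by the annotation
derivatives_pm = np.random.randn(n_sims, n_params, n_z, n_R, 2, n_d)  # [..., 0, :] minus run, [..., 1, :] plus run
alpha = np.zeros(n_params)           # fiducial parameter values (only alpha.size is used)
dparams = np.full(n_params, 0.05)    # parameter step between the plus/minus simulations

derivatives = calculate_derivatives(
    derivatives_pm,
    alpha,
    dparams,
    parameter_strings=[f"p{i}" for i in range(n_params)],
    parameter_derivative_names=[[f"p{i}_m", f"p{i}_p"] for i in range(n_params)],
)
assert derivatives.shape == (n_sims, n_params, n_z, n_R, n_d)

# The same finite difference written out directly:
expected = (derivatives_pm[..., 1, :] - derivatives_pm[..., 0, :]) / dparams[None, :, None, None, None]
assert np.allclose(derivatives, expected)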

@@ -467,7 +491,7 @@ def sample_prior(
     lower = lower.astype(jnp.float32)
     upper = upper.astype(jnp.float32)

-    assert jnp.all(upper - lower > 0.)
+    assert jnp.all((upper - lower) > 0.)

     keys_p = jr.split(key, alpha.size)

@@ -648,7 +672,8 @@ def get_linear_compressor(
     mu = jnp.mean(dataset.fiducial_data, axis=0)
     dmu = jnp.mean(dataset.derivatives, axis=0)

-    def compressor(d, p):
+    @typecheck
+    def compressor(d: Float[Array, "d"], p: Float[Array, "p"]) -> Float[Array, "p"]:
         mu_p = linearised_model(
             alpha=dataset.alpha, alpha_=p, mu=mu, dmu=dmu
         )
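
The hunk above adds runtime shape checking to the linear compressor. A minimal sketch of how a `typecheck` decorator of this kind can be assembled from jaxtyping and beartype; this is an assumption about its definition (which is not shown in the diff), and the compressor body here is a placeholder, not the repository's implementation.

import jax.numpy as jnp
from beartype import beartype
from jaxtyping import Array, Float, jaxtyped

typecheck = jaxtyped(typechecker=beartype)  # assumed definition

@typecheck
def compressor(d: Float[Array, "d"], p: Float[Array, "p"]) -> Float[Array, "p"]:
    # Placeholder body; the real compressor evaluates the linearised model at p
    # and scores the data d against it (see the hunk above).
    return p

compressor(jnp.ones(10), jnp.zeros(5))         # shapes satisfy the annotations
# compressor(jnp.ones((10, 2)), jnp.zeros(5))  # raises: "d" expects a rank-1 array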

@@ -703,7 +728,7 @@ def preprocess_fn(x):
     return net, preprocess_fn


-def get_compression_fn(key, config, dataset, results_dir):
+def get_compression_fn(key, config, dataset, *, results_dir):
     # Get linear or neural network compressor
     if config.compression == "nn":

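Finally, the last hunk makes `results_dir` keyword-only in `get_compression_fn`, so positional call sites need updating. A toy illustration with placeholder arguments and body, not the repository's objects:

def get_compression_fn(key, config, dataset, *, results_dir):
    # Placeholder body, only to demonstrate the keyword-only signature
    return f"compressor selected; results written under {results_dir}"

print(get_compression_fn(None, None, None, results_dir="out/"))  # OK
# get_compression_fn(None, None, None, "out/")  # TypeError: takes 3 positional arguments but 4 were given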