Skip to content

Commit 3159e78

Browse files
committed
updates, clearup
1 parent 421813b commit 3159e78

22 files changed

+315
-4405
lines changed

cumulants/bash/run_cumulants_sbi_one_training_linear_copy.sh

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ RESULTS_DIR="results/${1:-"results/"}"
2020
SINGLE_RUN="${2:-false}" # Run all experiments once for a single figure one
2121
ONLY_RUN_FIGURES="${3:-false}" # Only run figure_one.py jobs
2222

23-
RUN_LINEARISED=false
23+
RUN_LINEARISED=true
2424
RUN_FROZEN=false
2525
RUN_NONLINEAR=true
2626

@@ -46,11 +46,11 @@ if [[ "$SINGLE_RUN" == "true" ]]; then
4646
RUN_FROZEN=false
4747
else
4848
echo "MULTIPLE SEEDS RUN."
49-
N_SEEDS=50
49+
N_SEEDS=20
5050
START_SEED=0
5151
N_SEEDS_GLOBAL=10 # Number of repeated trainings for SBI
5252
END_SEED=$(( $START_SEED + $N_SEEDS ))
53-
N_PARALLEL=50
53+
N_PARALLEL=100
5454
fi
5555

5656
N_GB=8
@@ -182,8 +182,8 @@ $FREEZE_FLAG"
182182
cat <<END
183183
#!/bin/bash
184184
#SBATCH --job-name=sbi_${global_seed}_${bt_flag}_${l_flag}_${f_flag}_z${z}
185-
#SBATCH --output=$OUT_DIR/sbi_${bt_flag}_${l_flag}_${f_flag}_z${z}_fixed_%j.out
186-
#SBATCH --error=$OUT_DIR/sbi_${bt_flag}_${l_flag}_${f_flag}_z${z}_fixed_%j.err
185+
#SBATCH --output=$OUT_DIR/${bt_flag}/${l_flag}/${f_flag}/z${z}/sbi_fixed_%j.out
186+
#SBATCH --error=$OUT_DIR/${bt_flag}/${l_flag}/${f_flag}/z${z}/sbi_fixed_%j.err
187187
#SBATCH --partition=cluster
188188
#SBATCH --time=$JOB_TIME
189189
#SBATCH --mem=${N_GB}GB
@@ -262,8 +262,8 @@ $FREEZE_FLAG"
262262
cat <<END
263263
#!/bin/bash
264264
#SBATCH --job-name=m_z_${global_seed}_${bt_flag}_${l_flag}_${f_flag}
265-
#SBATCH --output=$OUT_DIR/m_z_${global_seed}_${bt_flag}_${l_flag}_${f_flag}_%a_%j.out
266-
#SBATCH --error=$OUT_DIR/m_z_${global_seed}_${bt_flag}_${l_flag}_${f_flag}_%a_%j.err
265+
#SBATCH --output=$OUT_DIR/${global_seed}/${bt_flag}/${l_flag}/${f_flag}/m_z_%a_%j.out
266+
#SBATCH --error=$OUT_DIR/${global_seed}/${bt_flag}/${l_flag}/${f_flag}/m_z_%a_%j.err
267267
#SBATCH --array=$JOB_ARRAY_STR
268268
#SBATCH --partition=cluster
269269
#SBATCH --time=$JOB_TIME
@@ -338,8 +338,8 @@ $FREEZE_FLAG"
338338
cat <<END
339339
#!/bin/bash
340340
#SBATCH --job-name=figure_one
341-
#SBATCH --output=$OUT_DIR/figure_one_${global_seed}_${l_flag}_${f_flag}_%a_%j.out
342-
#SBATCH --error=$OUT_DIR/figure_one_${global_seed}_${l_flag}_${f_flag}_%a_%j.err
341+
#SBATCH --output=$OUT_DIR/${global_seed}/${l_flag}/${f_flag}/figure_one_%a_%j.out
342+
#SBATCH --error=$OUT_DIR/${global_seed}/${l_flag}/${f_flag}/figure_one_%a_%j.err
343343
#SBATCH --array=$JOB_ARRAY_STR
344344
#SBATCH --partition=cluster
345345
#SBATCH --time=$JOB_TIME
@@ -442,8 +442,8 @@ for order_idx_args in "${order_idxs[@]}"; do
442442
sbatch <<END
443443
#!/bin/bash
444444
#SBATCH --job-name=figure_two
445-
#SBATCH --output=$OUT_DIR/figure_two_${l_flag}_${f_flag}_%j.out
446-
#SBATCH --error=$OUT_DIR/figure_two_${l_flag}_${f_flag}_%j.err
445+
#SBATCH --output=$OUT_DIR/${l_flag}/${f_flag}/figure_two_%j.out
446+
#SBATCH --error=$OUT_DIR/${l_flag}/${f_flag}/figure_two_%j.err
447447
#SBATCH --partition=cluster
448448
#SBATCH --time=$JOB_TIME
449449
#SBATCH --mem=${N_GB}GB
@@ -474,8 +474,8 @@ python figure_two3.py \
474474
--n_datavectors $N_DATAVECTORS \
475475
--compression linear \
476476
--n_linear_sims $N_LINEAR_SIMS \
477-
--order_idx "$order_idx_args" \
478-
--scales "$scale_args" \
477+
--order_idx $order_idx_args \
478+
--scales $scale_args \
479479
$LINEARISED_FLAG \
480480
$PRETRAIN_FLAG \
481481
$USE_PLANCK_FLAG \

cumulants/configs/configs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,8 @@ def get_ndes_from_config(
222222
ndes = []
223223
for nde, key in zip(config.ndes, keys):
224224

225+
logger.info("Using NDE of type '{}'".format(nde.model_type))
226+
225227
assert nde.model_type in ["maf", "cnf"], (
226228
"Invalid NDE model type (={})".format(nde.model_type)
227229
)

cumulants/configs/cumulants_configs.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def default(v, d):
7676
start_step = 0,
7777
n_epochs = 10_000,
7878
n_batch = 80, #100,
79-
patience = 70, #200,
79+
patience = 300, #70,
8080
lr = 0.000489390761268084, #1e-3,
8181
opt = "adam",
8282
opt_kwargs = {}
@@ -97,7 +97,7 @@ def default(v, d):
9797
DEFAULT_OPT_MAF = HP_OPT_OPT_MAF
9898

9999
DEFAULT_CNF_ARCH = DEFAULT_CNF_ARCH # HP_OPT_CNF_ARCH
100-
DEFAULT_OPT_CNF = DEFAULT_OPT # HP_OPT_OPT_CNF
100+
DEFAULT_OPT_CNF = HP_OPT_OPT_CNF # HP_OPT_OPT_CNF
101101

102102
# Number of density estimators in the ensemble
103103
N_NDES = default(int(DEFAULT_N_NDES), 1)
@@ -108,6 +108,11 @@ def get_default_nde(cnf, maf):
108108
return default(_default, maf)
109109

110110

111+
def get_default_nde_opt():
112+
_default = {"CNF": DEFAULT_OPT_CNF, "MAF": DEFAULT_OPT_MAF}[DEFAULT_NDE_TYPE] if DEFAULT_NDE_TYPE is not None else DEFAULT_OPT
113+
return _default
114+
115+
111116
def get_config_ndes(config):
112117
# Set the NDE architecture and training parameters for a config
113118

@@ -149,13 +154,14 @@ def get_config_ndes(config):
149154
pretrain.opt_kwargs = {}
150155

151156
# Optimisation hyperparameters (same for all NDEs...)
157+
_CONFIG_DEFAULT_OPT = get_default_nde_opt()
152158
config.train = train = ConfigDict()
153159
train.start_step = 0
154160
train.n_epochs = 10_000
155-
train.n_batch = DEFAULT_OPT["n_batch"] # 100
156-
train.patience = DEFAULT_OPT["patience"] # 200
157-
train.lr = DEFAULT_OPT["lr"] # 1e-3
158-
train.opt = DEFAULT_OPT["opt"] # "adam"
161+
train.n_batch = _CONFIG_DEFAULT_OPT["n_batch"]
162+
train.patience = _CONFIG_DEFAULT_OPT["patience"]
163+
train.lr = _CONFIG_DEFAULT_OPT["lr"]
164+
train.opt = _CONFIG_DEFAULT_OPT["opt"]
159165
train.opt_kwargs = {}
160166

161167
return config
@@ -172,11 +178,11 @@ def default_posterior_sampling(config, no_config=False):
172178

173179
# Posterior sampling
174180
if linearised:
175-
config.n_steps = 100
176-
config.n_walkers = 2000
181+
config.n_steps = 500
182+
config.n_walkers = 4000
177183
else:
178-
config.n_steps = 100
179-
config.n_walkers = 2000
184+
config.n_steps = 500
185+
config.n_walkers = 4000
180186
config.burn = int(0.1 * config.n_steps)
181187

182188
return config

cumulants/configs/log.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
LOG_DIR = os.getenv("LOG_DIR", "logs/")
55

66

7-
def get_log_level(default="INFO"):
7+
def get_log_level(default="DEBUG"):
88

99
level_str = os.getenv("LOG_LEVEL", default).upper()
1010

@@ -20,6 +20,8 @@ def setup_module_logger(module_name: str, level=logging.INFO, log_dir=LOG_DIR):
2020

2121
log_path = os.path.join(log_dir, f"{module_name}.log")
2222

23+
print("LOG PATH:\n\t{}".format(log_path))
24+
2325
try:
2426
if os.path.exists(log_path):
2527
os.remove(log_path)

0 commit comments

Comments
 (0)