Skip to content

Commit 5e3c274

Browse files
wenx-guofacebook-github-bot
authored andcommitted
add script for benchmarking baseline acquisition functions for full function estimate
Summary: We conducted benchmarking experiments on full function estimate for Sobol sampling, BALV, and BALD acquisition functions. These tests were performed on three distinct functions and a real psychophysics task data set obtained from Letham et al. (2022). Differential Revision: D57916855
1 parent 8bc289c commit 5e3c274

File tree

1 file changed

+174
-0
lines changed

1 file changed

+174
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved.
2+
3+
import os
4+
import logging
5+
import argparse
6+
7+
# run each job single-threaded, paralellize using pathos
8+
os.environ["OMP_NUM_THREADS"] = "1"
9+
os.environ["MKL_NUM_THREADS"] = "1"
10+
os.environ["NUMEXPR_NUM_THREADS"] = "1"
11+
12+
# multi-socket friendly args
13+
os.environ["KMP_AFFINITY"] = "granularity=fine,compact,1,0"
14+
os.environ["KMP_BLOCKTIME"] = "1"
15+
import torch
16+
17+
# force torch to 1 thread too just in case
18+
torch.set_num_interop_threads(1)
19+
torch.set_num_threads(1)
20+
21+
import time
22+
from copy import deepcopy
23+
from pathlib import Path
24+
25+
from aepsych.benchmark import run_benchmarks_with_checkpoints
26+
import aepsych.utils_logging as utils_logging
27+
logger=utils_logging.getLogger(logging.ERROR)
28+
29+
from aepsych.benchmark.problem import (
30+
DiscrimLowDim,
31+
DiscrimHighDim,
32+
Hartmann6Binary,
33+
ContrastSensitivity6d, # This takes a few minutes to instantiate due to fitting the model
34+
)
35+
36+
problem_map = {
37+
"discrim_lowdim": DiscrimLowDim,
38+
"discrim_highdim": DiscrimHighDim,
39+
"hartmann6_binary": Hartmann6Binary,
40+
"contrast_sensitivity_6d": ContrastSensitivity6d,
41+
}
42+
43+
44+
def make_argparser():
45+
parser = argparse.ArgumentParser(description="Lookahead LSE Benchmarks")
46+
parser.add_argument("--nproc", type=int, default=30)
47+
parser.add_argument("--reps_per_chunk", type=int, default=20)
48+
parser.add_argument("--acqf_start_idx", type=int, default=0)
49+
parser.add_argument("--sobol_start_idx", type=int, default=0)
50+
parser.add_argument("--chunks", type=int, default=15)
51+
parser.add_argument("--opt_size", type=int, default=740) # 490
52+
parser.add_argument("--init_size", type=int, default=10)
53+
parser.add_argument("--global_seed", type=int, default=1000)
54+
parser.add_argument("--log_frequency", type=int, default=10)
55+
parser.add_argument("--output_path", type=Path, default=Path("data/benchmark"))
56+
parser.add_argument("--bench_name", type=str, default="exploration_baseline")
57+
parser.add_argument(
58+
"--problem",
59+
type=str,
60+
choices=[
61+
"discrim_highdim",
62+
"discrim_lowdim",
63+
"hartmann6_binary",
64+
"contrast_sensitivity_6d",
65+
"all",
66+
],
67+
default="all",
68+
)
69+
return parser
70+
71+
72+
if __name__ == "__main__":
73+
74+
parser = make_argparser()
75+
args = parser.parse_args()
76+
chunks = args.chunks # The number of chunks to break the results into. Each chunk will contain at least 1 run of every
77+
# combination of problem and config.
78+
acqf_start_idx = args.acqf_start_idx # The index of the first chunk to run for different acquisition functions
79+
sobol_start_idx = args.sobol_start_idx # The index of the first chunk to run for sobol sampling
80+
reps_per_chunk = args.reps_per_chunk # Number of repetitions to run each problem/config in each chunk.
81+
82+
nproc = args.nproc # how many processes to use
83+
global_seed = args.global_seed # random seed for reproducibility
84+
log_every = args.log_frequency # log to csv every this many trials
85+
checkpoint_every = 120 # save intermediate results every this many seconds
86+
serial_debug = False # whether to run simulations serially for debugging
87+
bench_name=args.bench_name
88+
89+
out_fname_base = args.output_path
90+
out_fname_base.mkdir(
91+
parents=True, exist_ok=True
92+
) # make an output folder if not exist
93+
if args.problem == "all":
94+
problems = [
95+
DiscrimLowDim(),
96+
DiscrimHighDim(),
97+
Hartmann6Binary(),
98+
ContrastSensitivity6d(),
99+
]
100+
else:
101+
problems = [problem_map[args.problem]()]
102+
103+
bench_config = {
104+
"common": {
105+
"stimuli_per_trial": 1,
106+
"outcome_types": "binary",
107+
"strategy_names": "[init_strat, opt_strat]",
108+
},
109+
"init_strat": {"n_trials": args.init_size, "generator": "SobolGenerator"},
110+
"opt_strat": {
111+
"model": "GPClassificationModel",
112+
"generator": "OptimizeAcqfGenerator",
113+
"n_trials": args.opt_size,
114+
"refit_every": args.log_frequency,
115+
},
116+
"GPClassificationModel": {
117+
"inducing_size": 100,
118+
"mean_covar_factory": "default_mean_covar_factory",
119+
"inducing_point_method": "auto",
120+
},
121+
"default_mean_covar_factory": {
122+
"fixed_mean": False,
123+
"lengthscale_priout_fname_baseor": "gamma",
124+
"outputscale_prior": "gamma",
125+
"kernel": "RBFKernel",
126+
},
127+
"OptimizeAcqfGenerator": {
128+
"acqf": [
129+
"MCPosteriorVariance", # BALV
130+
"BernoulliMCMutualInformation", # BALD
131+
],
132+
"restarts": 2,
133+
"samps": 100,
134+
},
135+
# Add the probit transform for non-probit-specific acqfs
136+
"BernoulliMCMutualInformation": {"objective": "ProbitObjective"},
137+
"MCPosteriorVariance": {"objective": "ProbitObjective"},
138+
}
139+
140+
# benchmaking with baseline acquisition functions
141+
run_benchmarks_with_checkpoints(
142+
out_fname_base,
143+
bench_name,
144+
problems,
145+
bench_config,
146+
global_seed,
147+
acqf_start_idx,
148+
chunks,
149+
reps_per_chunk,
150+
log_every,
151+
checkpoint_every,
152+
nproc,
153+
serial_debug,
154+
)
155+
156+
# benchmaking with sobol sampling
157+
sobol_config=deepcopy(bench_config)
158+
sobol_config["opt_strat"]['generator']='SobolGenerator'
159+
del sobol_config["OptimizeAcqfGenerator"]
160+
sobol_bench_name=bench_name+"_sobol"
161+
run_benchmarks_with_checkpoints(
162+
out_fname_base,
163+
sobol_bench_name,
164+
problems,
165+
sobol_config,
166+
global_seed,
167+
sobol_start_idx,
168+
chunks,
169+
reps_per_chunk,
170+
log_every,
171+
checkpoint_every,
172+
nproc,
173+
serial_debug,
174+
)

0 commit comments

Comments
 (0)