Skip to content

Commit 9a7371c

Browse files
wenx-guofacebook-github-bot
authored and committed
add additional test functions and psychophysics task and dataset from Letham et al. 2022 (#350)
Summary: Additional high-dimensional test functions and a real psychophysics task are added to problem.py for benchmarking the performance of acquisition functions or GP models. The code and dataset are obtained from https://github.com/facebookresearch/bernoulli_lse/blob/main/problems.py. Differential Revision: D57885175
1 parent e16927e commit 9a7371c

File tree

2 files changed

+1130
-2
lines changed

2 files changed

+1130
-2
lines changed

aepsych/benchmark/problem.py

+128-2
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,23 @@
44

55
# This source code is licensed under the license found in the
66
# LICENSE file in the root directory of this source tree.
7-
7+
import os
88
from functools import cached_property
99
from typing import Any, Dict, Union
1010

1111
import aepsych
1212
import numpy as np
1313
import torch
14+
from scipy.stats import bernoulli, norm, pearsonr
1415
from aepsych.strategy import SequentialStrategy, Strategy
1516
from aepsych.utils import make_scaled_sobol
16-
from scipy.stats import bernoulli, norm, pearsonr
17+
from aepsych.benchmark.test_functions import (
18+
modified_hartmann6,
19+
discrim_highdim,
20+
novel_discrimination_testfun,
21+
)
22+
from aepsych.models import GPClassificationModel
23+
1724

1825

1926
class Problem:
@@ -281,3 +288,122 @@ def evaluate(self, strat: Union[Strategy, SequentialStrategy]) -> Dict[str, floa
281288
)
282289

283290
return metrics
291+
292+
"""
293+
The LSEProblemWithEdgeLogging, DiscrimLowDim, DiscrimHighDim, ContrastSensitivity6d, and Hartmann6Binary classes
294+
are copied from bernoulli_lse repository (https://github.com/facebookresearch/bernoulli_lse) by Letham et al. 2022.
295+
"""
296+
class LSEProblemWithEdgeLogging(LSEProblem):
    """LSE problem that additionally logs how often the strategy samples
    near the edges of the search domain.

    Copied from the bernoulli_lse repository
    (https://github.com/facebookresearch/bernoulli_lse) by Letham et al. 2022.
    """

    # Fraction of each dimension's range that counts as the "edge" band.
    eps = 0.05

    def evaluate(self, strat):
        """Evaluate the strategy, then append edge-sampling statistics.

        Args:
            strat: a SequentialStrategy; only the trials chosen by its final
                sub-strategy (``strat.strat_list[-1]``) are inspected.

        Returns:
            The metrics dict from ``LSEProblem.evaluate``, extended with
            ``prop_edge_sampling_mean`` (fraction of final-stage trials with
            any coordinate in the edge band) and ``prop_edge_sampling_err``
            (its 2-sigma standard error).
        """
        metrics = super().evaluate(strat)

        # add number of edge samples to the log

        # get the trials selected by the final strat only
        n_opt_trials = strat.strat_list[-1].n_trials

        # Shrink the box by eps * range on each side; any point outside the
        # shrunken box is considered "near the edge".
        lb, ub = strat.lb, strat.ub
        r = ub - lb
        lb2 = lb + self.eps * r
        ub2 = ub - self.eps * r

        # NOTE(review): assumes strat.x is a torch tensor — np.logical_or on
        # tensors returns a tensor, so the .any(axis=-1).double() chain works.
        # Verify against Strategy.x if this ever changes.
        near_edge = (
            np.logical_or(
                (strat.x[-n_opt_trials:, :] <= lb2), (strat.x[-n_opt_trials:, :] >= ub2)
            )
            .any(axis=-1)
            .double()
        )

        metrics["prop_edge_sampling_mean"] = near_edge.mean().item()
        # 2 * std / sqrt(n): two-sigma standard error of the proportion.
        metrics["prop_edge_sampling_err"] = (
            2 * near_edge.std() / np.sqrt(len(near_edge))
        ).item()
        return metrics
327+
class DiscrimLowDim(LSEProblemWithEdgeLogging):
    """2-d psychophysical discrimination problem from Letham et al. 2022."""

    name = "discrim_lowdim"
    bounds = torch.tensor([[-1, 1], [-1, 1]], dtype=torch.double).T
    threshold = 0.75

    def f(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the novel-discrimination test function at ``x``."""
        values = novel_discrimination_testfun(x)
        return torch.tensor(values, dtype=torch.double)
336+
class DiscrimHighDim(LSEProblemWithEdgeLogging):
    """8-d psychophysical discrimination problem from Letham et al. 2022."""

    name = "discrim_highdim"
    threshold = 0.75
    # Per-dimension [lower, upper] pairs, transposed to shape (2, 8).
    bounds = torch.tensor(
        [
            [-1, 1],
            [-1, 1],
            [0.5, 1.5],
            [0.05, 0.15],
            [0.05, 0.2],
            [0, 0.9],
            [0, 3.14 / 2],
            [0.5, 2],
        ],
        dtype=torch.double,
    ).T

    def f(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the high-dimensional discrimination test function at ``x``."""
        result = discrim_highdim(x)
        return torch.tensor(result, dtype=torch.double)
357+
class Hartmann6Binary(LSEProblemWithEdgeLogging):
    """Binarized 6-d Hartmann problem from Letham et al. 2022.

    The latent function is an affine transform of the modified Hartmann6
    test function evaluated row-wise over the input batch.
    """

    name = "hartmann6_binary"
    threshold = 0.5
    # Unit hypercube bounds, stacked to shape (2, 6).
    bounds = torch.stack(
        (
            torch.zeros(6, dtype=torch.double),
            torch.ones(6, dtype=torch.double),
        )
    )

    def f(self, X: torch.Tensor) -> torch.Tensor:
        """Return the latent function value for each row of ``X``.

        Args:
            X: (n, 6) tensor of points in the unit hypercube.

        Returns:
            (n,) double tensor of latent values.
        """
        y = torch.tensor([modified_hartmann6(x) for x in X], dtype=torch.double)
        # BUG FIX: the original computed y but fell off the end of the
        # function, returning None. Restore the affine transform and return
        # from the upstream bernoulli_lse implementation of this problem.
        f = 3 * y - 2.0
        return f
371+
class ContrastSensitivity6d(LSEProblemWithEdgeLogging):
    """
    Uses a surrogate model fit to real data from a contrast sensitivity study.
    """

    name = "contrast_sensitivity_6d"
    threshold = 0.75
    # Per-dimension [lower, upper] pairs, transposed to shape (2, 6).
    bounds = torch.tensor(
        [[-1.5, 0], [-1.5, 0], [0, 20], [0.5, 7], [1, 10], [0, 10]],
        dtype=torch.double,
    ).T

    def __init__(self) -> None:
        # Load the data
        # NOTE(review): this path is relative to the current working
        # directory, not to this file — instantiating from any other CWD
        # will raise on load. Consider anchoring on __file__; verify where
        # csf_dataset.csv lives relative to the package.
        self.data = np.loadtxt(
            os.path.join("..", "..", "dataset", "csf_dataset.csv"),
            delimiter=",",
            skiprows=1,
        )
        # First column is the binary response; remaining columns are inputs.
        y = torch.LongTensor(self.data[:, 0])
        x = torch.Tensor(self.data[:, 1:])

        # Fit a model, with a large number of inducing points
        self.m = GPClassificationModel(
            lb=self.bounds[0],
            ub=self.bounds[1],
            inducing_size=100,
            inducing_point_method="kmeans++",
        )

        self.m.fit(
            x,
            y,
        )

    def f(self, X: torch.Tensor) -> torch.Tensor:
        """Return the surrogate model's prediction at ``X``, clamped at 0."""
        # clamp f to 0 since we expect p(x) to be lower-bounded at 0.5
        return torch.clamp(self.m.predict(torch.tensor(X))[0], min=0)

0 commit comments

Comments
 (0)