# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

+import os
from functools import cached_property
from typing import Any, Dict, Union

import aepsych
import numpy as np
import torch
+from scipy.stats import bernoulli, norm, pearsonr
from aepsych.strategy import SequentialStrategy, Strategy
from aepsych.utils import make_scaled_sobol
-from scipy.stats import bernoulli, norm, pearsonr
+from aepsych.benchmark.test_functions import (
+    modified_hartmann6,
+    discrim_highdim,
+    novel_discrimination_testfun,
+)
+from aepsych.models import GPClassificationModel


class Problem:
@@ -281,3 +288,122 @@ def evaluate(self, strat: Union[Strategy, SequentialStrategy]) -> Dict[str, float]:
        )

        return metrics
+
+"""
+The LSEProblemWithEdgeLogging, DiscrimLowDim, DiscrimHighDim, ContrastSensitivity6d, and Hartmann6Binary classes
+are copied from the bernoulli_lse repository (https://github.com/facebookresearch/bernoulli_lse) by Letham et al. (2022).
+"""
+class LSEProblemWithEdgeLogging(LSEProblem):
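+    """LSE problem that additionally logs how often the strategy samples
+    within eps (as a fraction of each dimension's range) of the domain edges."""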
+    eps = 0.05
+
+    def evaluate(self, strat: Union[Strategy, SequentialStrategy]) -> Dict[str, float]:
+        metrics = super().evaluate(strat)
+
+        # add the number of edge samples to the log
+
+        # get the trials selected by the final strat only
+        n_opt_trials = strat.strat_list[-1].n_trials
+
+        lb, ub = strat.lb, strat.ub
+        r = ub - lb
+        lb2 = lb + self.eps * r
+        ub2 = ub - self.eps * r
+
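+        # a trial counts as near-edge if any coordinate lies within
+        # eps * r of a lower or upper bound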
+        near_edge = (
+            np.logical_or(
+                (strat.x[-n_opt_trials:, :] <= lb2), (strat.x[-n_opt_trials:, :] >= ub2)
+            )
+            .any(axis=-1)
+            .double()
+        )
+
+        metrics["prop_edge_sampling_mean"] = near_edge.mean().item()
+        metrics["prop_edge_sampling_err"] = (
+            2 * near_edge.std() / np.sqrt(len(near_edge))
+        ).item()
+        return metrics
+
+
+class DiscrimLowDim(LSEProblemWithEdgeLogging):
+    name = "discrim_lowdim"
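+    # bounds are written as one [lower, upper] row per input dimension and
+    # transposed into the (2, d) layout aepsych expects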
+    bounds = torch.tensor([[-1, 1], [-1, 1]], dtype=torch.double).T
+    threshold = 0.75
+
+    def f(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.tensor(novel_discrimination_testfun(x), dtype=torch.double)
+
+
+class DiscrimHighDim(LSEProblemWithEdgeLogging):
+    name = "discrim_highdim"
+    threshold = 0.75
+    bounds = torch.tensor(
+        [
+            [-1, 1],
+            [-1, 1],
+            [0.5, 1.5],
+            [0.05, 0.15],
+            [0.05, 0.2],
+            [0, 0.9],
+            [0, 3.14 / 2],
+            [0.5, 2],
+        ],
+        dtype=torch.double,
+    ).T
+
+    def f(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.tensor(discrim_highdim(x), dtype=torch.double)
+
+
+class Hartmann6Binary(LSEProblemWithEdgeLogging):
+    name = "hartmann6_binary"
+    threshold = 0.5
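+    # search domain is the unit hypercube [0, 1]^6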
+    bounds = torch.stack(
+        (
+            torch.zeros(6, dtype=torch.double),
+            torch.ones(6, dtype=torch.double),
+        )
+    )
+
+    def f(self, X: torch.Tensor) -> torch.Tensor:
+        y = torch.tensor([modified_hartmann6(x) for x in X], dtype=torch.double)
+        # affine rescaling of the Hartmann6 values into a latent range,
+        # following the bernoulli_lse reference implementation
+        return 3 * y - 2.0
+
+
+class ContrastSensitivity6d(LSEProblemWithEdgeLogging):
+    """
+    Uses a surrogate model fit to real data from a contrast sensitivity study.
+    """
+
+    name = "contrast_sensitivity_6d"
+    threshold = 0.75
+    bounds = torch.tensor(
+        [[-1.5, 0], [-1.5, 0], [0, 20], [0.5, 7], [1, 10], [0, 10]],
+        dtype=torch.double,
+    ).T
+
+    def __init__(self):
+
+        # Load the data
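+        # (the relative path assumes the benchmark is launched from a working
+        # directory two levels below the one containing dataset/)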
+        self.data = np.loadtxt(
+            os.path.join("..", "..", "dataset", "csf_dataset.csv"),
+            delimiter=",",
+            skiprows=1,
+        )
+        y = torch.LongTensor(self.data[:, 0])
+        x = torch.Tensor(self.data[:, 1:])
+
+        # Fit a model, with a large number of inducing points
+        self.m = GPClassificationModel(
+            lb=self.bounds[0],
+            ub=self.bounds[1],
+            inducing_size=100,
+            inducing_point_method="kmeans++",
+        )
+
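+        # fit once up front; the fitted surrogate's posterior mean then
+        # serves as the ground-truth latent f for this problem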
+        self.m.fit(
+            x,
+            y,
+        )
+
+    def f(self, X: torch.Tensor) -> torch.Tensor:
+        # clamp f to 0 since we expect p(x) to be lower-bounded at 0.5
+        return torch.clamp(self.m.predict(torch.tensor(X))[0], min=0)
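+
+# A minimal usage sketch (not part of the original file): each problem exposes
+# a latent f() that the benchmark harness evaluates (typically via a probit
+# link) and that can also be probed directly, e.g.:
+#
+#     problem = DiscrimLowDim()
+#     x = torch.tensor([[0.0, 0.0]], dtype=torch.double)
+#     print(problem.f(x))  # latent value at the center of the domain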