From 1d2ad1fdea279d2a7a3ed6051aba08a679e0e053 Mon Sep 17 00:00:00 2001 From: Sam Daulton Date: Sat, 4 Feb 2023 08:24:56 -0800 Subject: [PATCH] acquisition function wrapper Summary: Add a wrapper for modifying inputs/outputs. This is useful for not only probabilistic reparameterization, but will also simplify other integrated AFs (e.g. MCMC) as well as fixed feature AFs and things like prior-guided AFs Differential Revision: D41629186 fbshipit-source-id: cef679d1519c2d2ef5d9fb759c992827e87f2796 --- botorch/acquisition/wrapper.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 botorch/acquisition/wrapper.py diff --git a/botorch/acquisition/wrapper.py b/botorch/acquisition/wrapper.py new file mode 100644 index 0000000000..d2f100be8e --- /dev/null +++ b/botorch/acquisition/wrapper.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r""" +A wrapper classes around AcquisitionFunctions to modify inputs and outputs. +""" + +from __future__ import annotations + +from botorch.acquisition.acquisition import AcquisitionFunction +from torch.nn import Module + + +class AcquisitionFunctionWrapper(AcquisitionFunction): + r"""Abstract acquisition wrapper.""" + + def __init__(self, acq_function: AcquisitionFunction) -> None: + Module.__init__(self) + self.__class__ = type( + acq_function.__class__.__name__, + (self.__class__, acq_function.__class__), + {}, + ) + self.__dict__ = acq_function.__dict__ + self.acq_function = acq_function