
Commit 7819b41

SamuelGabriel authored and facebook-github-bot committed
PFN Integration (pytorch#2784)
Summary: An initial working implementation of PFNs, with passing tests, that is compatible with botorch's `optimize_acqf`.
Reviewed By: Balandat
Differential Revision: D61689567
1 parent da28b43 commit 7819b41

File tree

8 files changed: +881 -0 lines changed

@@ -0,0 +1,146 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""Abstract base module for all botorch acquisition functions."""

from __future__ import annotations

from abc import ABC, abstractmethod

import torch

from botorch.acquisition import AcquisitionFunction
from botorch.models.model import Model
from botorch.utils.transforms import t_batch_mode_transform
from torch import Tensor


class DiscretizedAcquistionFunction(AcquisitionFunction, ABC):
    r"""DiscretizedAcquistionFunction is an abstract base class for acquisition
    functions that are defined on discrete distributions. It wraps a model and
    implements a forward method that computes the acquisition function value at
    a given set of points.

    This class can be subclassed to define acquisition functions for
    Riemann-distributed posteriors.

    The acquisition function must have the form
    $$acq(x) = \int p(y|x) ag(y) dy$$,
    where $$ag$$ is defined differently for each acquisition function.

    The `ag_integrate` method, which computes the integral of `ag` between two
    points, must be implemented by subclasses to define the specific
    acquisition function.
    """

    def __init__(self, model: Model) -> None:
        super().__init__(model=model)

    @t_batch_mode_transform(expected_q=1)
    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate the acquisition function on the candidate set X.

        Args:
            X: A `(b) x q x d`-dim Tensor of `(b)` t-batches with `q` `d`-dim
                design points each.

        Returns:
            A `(b)`-dim Tensor of the acquisition function at the given
            design points `X`.
        """
        self.to(device=X.device)

        discrete_posterior = self.model.posterior(X)
        result = discrete_posterior.integrate(self.ag_integrate)
        # remove q dimension
        return result.squeeze(-1)

    @abstractmethod
    def ag_integrate(self, lower_bound: Tensor, upper_bound: Tensor) -> Tensor:
        r"""
        Calculate the integral of `ag` (the acquisition function without the
        posterior factor) from `lower_bound` to `upper_bound`.

        That is, the acquisition function is assumed to have the form
        \int ag(y) * p(y) dy,
        and this method calculates \int_{lower_bound}^{upper_bound} ag(y) dy.
        The `integrate` method of the posterior (`BoundedRiemannPosterior`)
        then computes the final acquisition value.

        Args:
            lower_bound: lower bound of the integral
            upper_bound: upper bound of the integral

        Returns:
            A Tensor of integrated `ag` values between the given bounds, with
            the same shape as the bounds.
        """
        pass  # pragma: no cover


r"""DiscretizedExpectedImprovement is an acquisition function that computes
the expected improvement over the current best observed value for a Riemann
distribution."""


class DiscretizedExpectedImprovement(DiscretizedAcquistionFunction):
    r"""DiscretizedExpectedImprovement is an acquisition function that
    computes the expected improvement over the current best observed value
    for a Riemann distribution.
    """

    def __init__(self, model: Model, best_f: Tensor) -> None:
        super().__init__(model)
        self.register_buffer("best_f", torch.as_tensor(best_f))

    def ag_integrate(self, lower_bound: Tensor, upper_bound: Tensor) -> Tensor:
        r"""
        Expected improvement has ag(y) = (y - best_f).clamp(min=0) and is
        defined as \int ag(y) * p(y) dy. To integrate ag over a bucket, we
        evaluate it at the (clamped) lower bound and at the upper bound and
        take the average of the two, since ag has slope 0 or 1.

        Args:
            lower_bound: lower bound of the integral
            upper_bound: upper bound of the integral

        Returns:
            A Tensor of integrated `ag` values between the given bounds, with
            the same shape as the bounds.
        """
        max_lower_bound_and_f = torch.max(self.best_f, lower_bound)
        bucket_average = (upper_bound + max_lower_bound_and_f) / 2
        improvement = bucket_average - self.best_f

        return improvement.clamp_min(0)


class DiscretizedProbabilityOfImprovement(DiscretizedAcquistionFunction):
    r"""DiscretizedProbabilityOfImprovement is an acquisition function that
    computes the probability of improvement over the current best observed
    value for a Riemann distribution.
    """

    def __init__(self, model: Model, best_f: Tensor) -> None:
        super().__init__(model)
        self.register_buffer("best_f", torch.as_tensor(best_f))

    def ag_integrate(self, lower_bound: Tensor, upper_bound: Tensor) -> Tensor:
        r"""
        PI is defined as \int ag(y) * p(y) dy, where ag(y) = I(y > best_f)
        and I is the indicator function.

        So all we need to do is calculate the portion of the bucket that lies
        above best_f. We do this by comparing how much higher the upper bound
        is than best_f, relative to the size of the bucket, clamping at one if
        best_f is below lower_bound and at zero if best_f is above upper_bound.

        Args:
            lower_bound: lower bound of the integral
            upper_bound: upper bound of the integral

        Returns:
            A Tensor of integrated `ag` values between the given bounds, with
            the same shape as the bounds.
        """
        proportion = (upper_bound - self.best_f) / (upper_bound - lower_bound)
        return proportion.clamp(0, 1)
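For intuition, here is a small worked sketch (not part of the diff) of how a bucketed posterior could combine its per-bucket probabilities with the `ag_integrate` of `DiscretizedExpectedImprovement`. The `borders` and `probabilities` values are made up, and the assumption that the posterior's `integrate` method computes `sum_i p_i * ag_integrate(border_i, border_{i+1})` is mine; the actual `BoundedRiemannPosterior.integrate` lives in a file not shown above.

import torch

borders = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0])   # c + 1 bucket borders (hypothetical)
probabilities = torch.tensor([0.1, 0.2, 0.3, 0.4])  # probability mass per bucket (hypothetical)
best_f = torch.tensor(0.8)

lower, upper = borders[:-1], borders[1:]

# DiscretizedExpectedImprovement.ag_integrate, inlined for illustration:
bucket_average = (upper + torch.max(best_f, lower)) / 2
ei_per_bucket = (bucket_average - best_f).clamp_min(0)

# Assumed combination with the bucket probabilities:
ei = (probabilities * ei_per_bucket).sum()
print(ei)  # ~ tensor(0.5350)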
@@ -0,0 +1,96 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
This module defines the botorch model for PFNs (`PFNModel`) and provides
helpers to download pretrained, public PFNs with `download_model`, whose
paths are listed in `ModelPaths`. For the latter to work, `pfns4bo` must
be installed.
"""

from __future__ import annotations

from typing import Optional, Union

import torch.nn as nn

from botorch.acquisition.objective import PosteriorTransform
from botorch.exceptions.errors import UnsupportedError
from botorch.models.model import Model

from botorch_community.posteriors.riemann import BoundedRiemannPosterior
from torch import Tensor


class PFNModel(Model):
    """Prior-data Fitted Network"""

    def __init__(
        self,
        train_X: Tensor,
        train_Y: Tensor,
        model: nn.Module,
    ) -> None:
        """Initialize a PFNModel.

        Args:
            train_X: A `batch_shape x n x d` tensor of training features.
            train_Y: A `batch_shape x n x m` tensor of training observations.
            model: A PFN whose `forward(train_X, train_Y, X)` returns logit
                predictions of shape `n x b x c`, where `c` is the number of
                discrete buckets, and whose `criterion.borders` is a
                `c + 1`-dim tensor of bucket borders.
        """
        super().__init__()
        self.train_X = train_X
        self.train_Y = train_Y
        self.pfn = model.to(train_X)

    def posterior(
        self,
        X: Tensor,
        output_indices: Optional[list[int]] = None,
        observation_noise: Union[bool, Tensor] = False,
        posterior_transform: Optional[PosteriorTransform] = None,
    ) -> BoundedRiemannPosterior:
        r"""Computes the posterior over model outputs at the provided points.

        Note: The input transforms should be applied here using
        `self.transform_inputs(X)` after the `self.eval()` call and before
        any `model.forward` or `model.likelihood` calls.

        Args:
            X: A `b x q x d`-dim Tensor, where `d` is the dimension of the
                feature space, `q` is the number of points considered jointly,
                and `b` is the batch dimension. **Only `q = 1` is currently
                supported for PFNModel.**
            output_indices: **Currently not supported for PFNModel.**
            observation_noise: **Currently not supported for PFNModel.**
            posterior_transform: **Currently not supported for PFNModel.**

        Returns:
            A `BoundedRiemannPosterior` object, representing a batch of `b`
            joint distributions over `q` points and `m` outputs each.
        """
        self.pfn.eval()
        if output_indices is not None:
            raise RuntimeError(
                "output_indices is not None. PFNModel should not be a multi-output model."
            )
        if observation_noise:
            raise UnsupportedError("observation_noise is not supported for PFNModel.")
        if posterior_transform is not None:
            raise UnsupportedError("posterior_transform is not supported for PFNModel.")

        if len(X.shape) > 2 and X.shape[-2] > 1:
            raise NotImplementedError("q must be 1 for PFNModel.")  # add support later

        # flatten batch dimensions for PFN input
        logits = self.pfn(self.train_X, self.train_Y, X)

        probabilities = logits.softmax(dim=-1)

        return BoundedRiemannPosterior(self.pfn.criterion.borders, probabilities)
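A hedged end-to-end sketch of the `optimize_acqf` compatibility mentioned in the commit summary. The import paths are assumptions (the diff above does not show file names), and `DummyPFN` is a hypothetical stand-in that only mimics the interface `PFNModel` expects: a `forward(train_X, train_Y, X)` returning bucket logits and a `criterion.borders` tensor.

import torch
import torch.nn as nn

from botorch.optim import optimize_acqf
from botorch_community.acquisition.discretized import DiscretizedExpectedImprovement  # assumed path
from botorch_community.models.prior_fitted_network import PFNModel  # assumed path


class DummyPFN(nn.Module):
    """Hypothetical stand-in for a pretrained PFN (e.g. one from `download_model`)."""

    def __init__(self, n_buckets: int = 100) -> None:
        super().__init__()
        self.linear = nn.Linear(1, n_buckets)
        # `PFNModel.posterior` reads the bucket borders from `criterion.borders`.
        self.criterion = nn.Module()
        self.criterion.borders = torch.linspace(0, 1, n_buckets + 1)

    def forward(self, train_X, train_Y, X):
        # A real PFN conditions on (train_X, train_Y); this stub ignores them
        # and simply maps the first feature of X to bucket logits.
        return self.linear(X[..., :1])


train_X = torch.rand(8, 2)
train_Y = torch.rand(8, 1)

model = PFNModel(train_X, train_Y, DummyPFN())
acqf = DiscretizedExpectedImprovement(model, best_f=train_Y.max())

candidate, value = optimize_acqf(
    acqf,
    bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]]),
    q=1,  # PFNModel currently only supports q = 1
    num_restarts=4,
    raw_samples=32,
)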
@@ -0,0 +1,75 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import gzip
import io
import os
from enum import Enum
from typing import Optional

try:
    import requests
except ImportError:
    raise ImportError(
        "The `requests` library is required to run `download_model`. "
        "You can install it using pip: `pip install requests`"
    )


import torch
import torch.nn as nn


class ModelPaths(Enum):
    """Enum of download URLs for public PFN models."""

    pfns4bo_hebo = "https://github.com/automl/PFNs4BO/raw/refs/heads/main/pfns4bo/final_models/hebo_morebudget_9_unused_features_3_userpriorperdim2_8.pt.gz"
    pfns4bo_bnn = "https://github.com/automl/PFNs4BO/raw/refs/heads/main/pfns4bo/final_models/model_sampled_warp_simple_mlp_for_hpob_46.pt.gz"
    pfns4bo_hebo_userprior = "https://github.com/automl/PFNs4BO/raw/refs/heads/main/pfns4bo/final_models/hebo_morebudget_9_unused_features_3_userpriorperdim2_8.pt.gz"


def download_model(
    model_path: str | ModelPaths,
    proxies: Optional[dict[str, str]] = None,
    cache_dir: Optional[str] = None,
) -> nn.Module:
    """Download and load PFN model weights from a URL.

    Args:
        model_path: A string URL of the model to load, or a `ModelPaths` enum member.
        proxies: An optional dictionary mapping network protocols, e.g. ``http``,
            to proxy addresses.
        cache_dir: The cache directory to use; defaults to ``/tmp/botorch_pfn_models``.

    Returns:
        A PFN model.
    """
    if isinstance(model_path, ModelPaths):
        model_path = model_path.value

    if cache_dir is None:
        cache_dir = "/tmp/botorch_pfn_models"

    os.makedirs(cache_dir, exist_ok=True)
    cache_path = os.path.join(cache_dir, model_path.split("/")[-1])

    if not os.path.exists(cache_path):
        # Download the model weights
        response = requests.get(model_path, proxies=proxies or None)
        response.raise_for_status()

        # Decompress the gzipped model weights
        with gzip.GzipFile(fileobj=io.BytesIO(response.content)) as gz:
            model = torch.load(gz, map_location=torch.device("cpu"))

        # Save the model to cache
        torch.save(model, cache_path)
        print("saved at: ", cache_path)
    else:
        # Load the model from cache
        model = torch.load(cache_path, map_location=torch.device("cpu"))
        print("loaded from cache: ", cache_path)

    return model
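A short usage sketch for the download helper. The import path is an assumption (file names are not shown in the diff), and running it requires network access plus the `pfns4bo` package so that `torch.load` can unpickle the weights.

from botorch_community.models.utils.prior_fitted_network import (  # assumed path
    ModelPaths,
    download_model,
)

# The first call downloads the gzipped weights and caches them under
# /tmp/botorch_pfn_models (or a custom `cache_dir`).
pfn = download_model(ModelPaths.pfns4bo_hebo)

# Later calls with the same URL load straight from the cache.
pfn = download_model(ModelPaths.pfns4bo_hebo, cache_dir="/tmp/botorch_pfn_models")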
