Skip to content

Commit 44bf697

Browse files
SamuelGabrielfacebook-github-bot
authored andcommitted
PFN Integration (pytorch#2784)
Summary: Pull Request resolved: pytorch#2784 An initial working and tests passing implementation of PFNs that is compatible with botorch's `optimize_acqf`. Reviewed By: Balandat Differential Revision: D61689567
1 parent da28b43 commit 44bf697

File tree

8 files changed

+974
-1
lines changed

8 files changed

+974
-1
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
r"""Abstract base module for all botorch acquisition functions."""
8+
9+
from __future__ import annotations
10+
11+
from abc import ABC, abstractmethod
12+
13+
import torch
14+
15+
from botorch.acquisition import AcquisitionFunction
16+
from botorch.models.model import Model
17+
from botorch.utils.transforms import t_batch_mode_transform
18+
from torch import Tensor
19+
20+
21+
class DiscretizedAcquistionFunction(AcquisitionFunction, ABC):
22+
r"""DiscretizedAcquistionFunction is an abstract base class for acquisition
23+
functions that are defined on discrete distributions. It wraps a model and
24+
implements a forward method that computes the acquisition function value at
25+
a given set of points.
26+
This class can be subclassed to define acquisiton functions for Riemann-
27+
distributed posteriors.
28+
The acquisition function must have the form $$acq(x) = \int p(y|x) ag(x)$$,
29+
where $$ag$$ is defined differently for each acquisition function.
30+
The ag_integrate method, which computes the integral of ag between two points, must
31+
be implemented by subclasses to define the specific acquisition functions.
32+
"""
33+
34+
def __init__(self, model: Model) -> None:
35+
r"""
36+
Initialize the DiscretizedAcquistionFunction
37+
38+
Args:
39+
model: A fitted model that is used to compute the posterior
40+
distribution over the outcomes of interest.
41+
The model should be a `PFNModel`.
42+
"""
43+
44+
super().__init__(model=model)
45+
46+
@t_batch_mode_transform(expected_q=1)
47+
def forward(self, X: Tensor) -> Tensor:
48+
r"""Evaluate the acquisition function on the candidate set X.
49+
50+
Args:
51+
X: A `(b) x q x d`-dim Tensor of `(b)` t-batches with `q` `d`-dim
52+
design points each.
53+
54+
Returns:
55+
A `(b)`-dim Tensor of the acquisition function at the given
56+
design points `X`.
57+
"""
58+
self.to(device=X.device)
59+
60+
discrete_posterior = self.model.posterior(X)
61+
result = discrete_posterior.integrate(self.ag_integrate)
62+
# remove q dimension
63+
return result.squeeze(-1)
64+
65+
@abstractmethod
66+
def ag_integrate(self, lower_bound: Tensor, upper_bound: Tensor) -> Tensor:
67+
r"""
68+
This function calculates the integral that computes the acquisition function
69+
without the posterior factor from lower_bound to upper_bound.
70+
That is, our acquisition function is assumed to have the form
71+
\int ag(x) * p(x) dx,
72+
and this function calculates \int_{lower_bound}^{upper_bound} ag(x) dx.
73+
The `integrate` method of the posterior (`BoundedRiemannPosterior`)
74+
then computes the final acquisition value.
75+
76+
Args:
77+
lower_bound: lower bound of integral
78+
upper_bound: upper bound of integral
79+
80+
Returns:
81+
A `(b)`-dim Tensor of acquisition function derivatives at the given
82+
design points `X`.
83+
"""
84+
pass # pragma: no cover
85+
86+
r"""DiscretizedExpectedImprovement is an acquisition function that computes
87+
the expected improvement over the current best observed value for a Riemann
88+
distribution."""
89+
90+
91+
class DiscretizedExpectedImprovement(DiscretizedAcquistionFunction):
92+
r"""DiscretizedExpectedImprovement is an acquisition function that
93+
computes the expected improvement over the current best observed value
94+
for a Riemann distribution.
95+
"""
96+
97+
def __init__(self, model: Model, best_f: Tensor) -> None:
98+
r"""
99+
Initialize the DiscretizedExpectedImprovement
100+
101+
Args:
102+
model: A fitted model that is used to compute the posterior
103+
distribution over the outcomes of interest.
104+
The model should be a `PFNModel`.
105+
best_f: A tensor representing the current best observed value.
106+
"""
107+
super().__init__(model)
108+
self.register_buffer("best_f", torch.as_tensor(best_f))
109+
110+
def ag_integrate(self, lower_bound: Tensor, upper_bound: Tensor) -> Tensor:
111+
r"""
112+
As Expected improvement has ag(y) = (y - best_f).clamp(min=0), and
113+
is defined as \int ag(y) * p(y) dy, we can calculate the integral
114+
of ag(y) like so:
115+
We just calculate ag(y) at beginning and end, and since the function has
116+
a gradient of 1 or 0, we can just take the average of the two.
117+
118+
Args:
119+
lower_bound: lower bound of integral
120+
upper_bound: upper bound of integral
121+
122+
Returns:
123+
A `(b)`-dim Tensor of acquisition function derivatives at the given
124+
design points `X`.
125+
"""
126+
max_lower_bound_and_f = torch.max(self.best_f, lower_bound)
127+
bucket_average = (upper_bound + max_lower_bound_and_f) / 2
128+
improvement = bucket_average - self.best_f
129+
130+
return improvement.clamp_min(0)
131+
132+
133+
class DiscretizedProbabilityOfImprovement(DiscretizedAcquistionFunction):
134+
r"""DiscretizedProbabilityOfImprovement is an acquisition function that
135+
computes the probability of improvement over the current best observed value
136+
for a Riemann distribution.
137+
"""
138+
139+
def __init__(self, model: Model, best_f: Tensor) -> None:
140+
r"""
141+
Initialize the DiscretizedProbabilityOfImprovement
142+
143+
Args:
144+
model: A fitted model that is used to compute the posterior
145+
distribution over the outcomes of interest.
146+
The model should be a `PFNModel`.
147+
best_f: A tensor representing the current best observed value.
148+
"""
149+
150+
super().__init__(model)
151+
self.register_buffer("best_f", torch.as_tensor(best_f))
152+
153+
def ag_integrate(self, lower_bound: Tensor, upper_bound: Tensor) -> Tensor:
154+
r"""
155+
PI is defined as \int ag(y) * p(y) dy, where ag(y) = I(y - best_f)
156+
and I being the indicator function.
157+
158+
So all we need to do is calculate the portion between the bounds
159+
that is larger than best_f.
160+
We do this by comparing how much higher the upper bound is than best_f,
161+
compared to the size of the bucket.
162+
Then we clamp at one if best_f is below lower_bound and at zero if
163+
best_f is above upper_bound.
164+
165+
Args:
166+
lower_bound: lower bound of integral
167+
upper_bound: upper bound of integral
168+
169+
Returns:
170+
A `(b)`-dim Tensor of acquisition function derivatives at the given
171+
design points `X`.
172+
"""
173+
proportion = (upper_bound - self.best_f) / (upper_bound - lower_bound)
174+
return proportion.clamp(0, 1)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
r"""
8+
This module defines the botorch model for PFNs (`PFNModel`) and it
9+
provides handy helpers to download pretrained, public PFNs
10+
with `download_model` and model paths with `ModelPaths`.
11+
For the latter to work `pfns4bo` must be installed.
12+
"""
13+
14+
from __future__ import annotations
15+
16+
from typing import Optional, Union
17+
18+
import torch.nn as nn
19+
20+
from botorch.acquisition.objective import PosteriorTransform
21+
from botorch.exceptions.errors import UnsupportedError
22+
from botorch.models.model import Model
23+
24+
from botorch_community.posteriors.riemann import BoundedRiemannPosterior
25+
from torch import Tensor
26+
27+
28+
class PFNModel(Model):
29+
"""Prior-data Fitted Network"""
30+
31+
def __init__(
32+
self,
33+
train_X: Tensor,
34+
train_Y: Tensor,
35+
model: nn.Module,
36+
) -> None:
37+
"""Initialize a PFNModel.
38+
39+
Args:
40+
train_X: A `batch_shape x n x d` tensor of training features.
41+
train_Y: A `batch_shape x n x m` tensor of training observations.
42+
model: A PFN model with the following
43+
forward(train_X, train_Y, X) -> logit predictions of shape
44+
`n x b x c` where c is the number of discrete buckets
45+
borders: A `c+1`-dim tensor of bucket borders
46+
"""
47+
super().__init__()
48+
self.train_X = train_X
49+
self.train_Y = train_Y
50+
self.pfn = model.to(train_X)
51+
52+
def posterior(
53+
self,
54+
X: Tensor,
55+
output_indices: Optional[list[int]] = None,
56+
observation_noise: Union[bool, Tensor] = False,
57+
posterior_transform: Optional[PosteriorTransform] = None,
58+
) -> BoundedRiemannPosterior:
59+
r"""Computes the posterior over model outputs at the provided points.
60+
61+
Note: The input transforms should be applied here using
62+
`self.transform_inputs(X)` after the `self.eval()` call and before
63+
any `model.forward` or `model.likelihood` calls.
64+
65+
Args:
66+
X: A `b x q x d`-dim Tensor, where `d` is the dimension of the
67+
feature space, `q` is the number of points considered jointly,
68+
and `b` is the batch dimension.
69+
We only allow `q=1` for PFNModel, so q can also be omitted, i.e.
70+
`b x d`-dim Tensor.
71+
**Currently not supported for PFNModel**.
72+
output_indices: **Currenlty not supported for PFNModel.**
73+
observation_noise: **Currently not supported for PFNModel**.
74+
posterior_transform: **Currently not supported for PFNModel**.
75+
76+
Returns:
77+
A `BoundedRiemannPosterior` object, representing a batch of `b` joint
78+
distributions over `q` points and `m` outputs each.
79+
"""
80+
self.pfn.eval()
81+
if output_indices is not None:
82+
raise RuntimeError(
83+
"output_indices is not None. PFNModel should not "
84+
"be a multi-output model."
85+
)
86+
if observation_noise:
87+
raise UnsupportedError("observation_noise is not supported for PFNModel.")
88+
if posterior_transform is not None:
89+
raise UnsupportedError("posterior_transform is not supported for PFNModel.")
90+
91+
if len(X.shape) > 2 and X.shape[-2] > 1:
92+
raise NotImplementedError("q must be 1 for PFNModel.") # add support later
93+
94+
# flatten batch dimensions for PFN input
95+
logits = self.pfn(self.train_X, self.train_Y, X)
96+
97+
probabilities = logits.softmax(dim=-1)
98+
99+
return BoundedRiemannPosterior(self.pfn.criterion.borders, probabilities)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import gzip
8+
import io
9+
import os
10+
from enum import Enum
11+
from typing import Optional
12+
13+
try:
14+
import requests
15+
except ImportError: # pragma: no cover
16+
raise ImportError(
17+
"The `requests` library is required to run `download_model`. "
18+
"You can install it using pip: `pip install requests`"
19+
)
20+
21+
22+
import torch
23+
import torch.nn as nn
24+
25+
26+
class ModelPaths(Enum):
27+
"""Enum for PFN models"""
28+
29+
pfns4bo_hebo = (
30+
"https://github.com/automl/PFNs4BO/raw/refs/heads/main/pfns4bo"
31+
"/final_models/hebo_morebudget_9_unused_features_3_userpriorperdim2_8.pt.gz"
32+
)
33+
pfns4bo_bnn = (
34+
"https://github.com/automl/PFNs4BO/raw/refs/heads/main/pfns4bo"
35+
"/final_models/model_sampled_warp_simple_mlp_for_hpob_46.pt.gz"
36+
)
37+
pfns4bo_hebo_userprior = (
38+
"https://github.com/automl/PFNs4BO/raw/refs/heads/main/pfns4bo"
39+
"/final_models/hebo_morebudget_9_unused_features_3_userpriorperdim2_8.pt.gz"
40+
)
41+
42+
43+
def download_model(
44+
model_path: str | ModelPaths,
45+
proxies: Optional[dict[str, str]] = None,
46+
cache_dir: Optional[str] = None,
47+
) -> nn.Module:
48+
"""Download and load PFN model weights from a URL.
49+
50+
Args:
51+
model_path: A string representing the URL of the model to load or a ModelPaths
52+
enum.
53+
proxies: An optional dictionary mapping from network protocols, e.g. ``http``,
54+
to proxy addresses.
55+
cache_dir: The cache dir to use, if not specified we will use
56+
``/tmp/botorch_pfn_models``
57+
58+
Returns:
59+
A PFN model.
60+
"""
61+
if isinstance(model_path, ModelPaths):
62+
model_path = model_path.value
63+
64+
if cache_dir is None:
65+
cache_dir = "/tmp/botorch_pfn_models"
66+
67+
os.makedirs(cache_dir, exist_ok=True)
68+
cache_path = os.path.join(cache_dir, model_path.split("/")[-1])
69+
70+
if not os.path.exists(cache_path):
71+
# Download the model weights
72+
response = requests.get(model_path, proxies=proxies or None)
73+
response.raise_for_status()
74+
75+
# Decompress the gzipped model weights
76+
with gzip.GzipFile(fileobj=io.BytesIO(response.content)) as gz:
77+
model = torch.load(gz, map_location=torch.device("cpu"))
78+
79+
# Save the model to cache
80+
torch.save(model, cache_path)
81+
print("saved at: ", cache_path)
82+
else:
83+
# Load the model from cache
84+
model = torch.load(cache_path, map_location=torch.device("cpu"))
85+
print("loaded from cache: ", cache_path)
86+
87+
return model

0 commit comments

Comments
 (0)