Skip to content

Commit b69be96

Browse files
authored
Merge pull request #460 from Thomas-Christie/expected-improvement
Add expected improvement utility function
2 parents fc45b9f + 8fb0f9a commit b69be96

25 files changed

+385
-206
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ repos:
4040
types: [python]
4141
- id: ruff
4242
name: ruff
43-
entry: ruff
43+
entry: ruff check
4444
args: ["--exit-non-zero-on-fix"]
4545
require_serial: true
4646
language: system

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@ help: ## Display this help
1616

1717
##@ Formatting
1818
black: ## Format code in-place using black.
19-
black ${PKGROOT}/ tests/ -l 79 .
19+
black ${PKGROOT}/ tests/ -l 88 .
2020

2121
isort: ## Format imports in-place using isort.
2222
isort ${PKGROOT}/ tests/
2323

2424
format: ## Code styling - black, isort
25-
black ${PKGROOT}/ tests/ -l 100 .
25+
black ${PKGROOT}/ tests/ -l 88 .
2626
@printf "\033[1;34mBlack passes!\033[0m\n\n"
2727
isort ${PKGROOT}/ tests/
2828
@printf "\033[1;34misort passes!\033[0m\n\n"

docs/examples/bayesian_optimisation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -728,7 +728,7 @@ def obtain_log_regret_statistics(
728728
#
729729
# - **Expected Improvement (EI)** ([Močkus, 1974](https://link.springer.com/chapter/10.1007/3-540-07165-2_55)) - EI goes beyond PI by not only considering the
730730
# probability of improving on the current best observed point, but also taking into
731-
# account the \textit{magnitude} of improvement. Mathematically, this is defined as
731+
# account the *magnitude* of improvement. Mathematically, this is defined as
732732
# follows:
733733
# $$
734734
# \begin{aligned}

docs/examples/decision_making.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,8 +240,8 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
240240
# %% [markdown]
241241

242242
# It is worth noting that `ThompsonSampling` is not the only utility function we could use,
243-
# since our module also provides e.g. `ProbabilityOfImprovement`,
244-
# which was briefly discussed in [our previous introduction to Bayesian optimisation](https://docs.jaxgaussianprocesses.com/examples/bayesian_optimisation/).
243+
# since our module also provides e.g. `ProbabilityOfImprovement` and `ExpectedImprovement`,
244+
# which were briefly discussed in [our previous introduction to Bayesian optimisation](https://docs.jaxgaussianprocesses.com/examples/bayesian_optimisation/).
245245

246246

247247
# %% [markdown]

gpjax/decision_making/test_functions/continuous_functions.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,24 +12,24 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
# ==============================================================================
15-
from abc import (
16-
ABC,
17-
abstractmethod,
18-
)
15+
from abc import abstractmethod
1916
from dataclasses import dataclass
2017

2118
import jax.numpy as jnp
2219
from jaxtyping import (
2320
Array,
2421
Float,
22+
Num,
2523
)
24+
import tensorflow_probability.substrates.jax as tfp
2625

2726
from gpjax.dataset import Dataset
2827
from gpjax.decision_making.search_space import ContinuousSearchSpace
28+
from gpjax.gps import AbstractMeanFunction
2929
from gpjax.typing import KeyArray
3030

3131

32-
class AbstractContinuousTestFunction(ABC):
32+
class AbstractContinuousTestFunction(AbstractMeanFunction):
3333
"""
3434
Abstract base class for continuous test functions.
3535
@@ -43,19 +43,28 @@ class AbstractContinuousTestFunction(ABC):
4343
minimizer: Float[Array, "1 D"]
4444
minimum: Float[Array, "1 1"]
4545

46-
def generate_dataset(self, num_points: int, key: KeyArray) -> Dataset:
46+
def generate_dataset(
47+
self, num_points: int, key: KeyArray, obs_stddev: float = 0.0
48+
) -> Dataset:
4749
"""
4850
Generate a toy dataset from the test function.
4951
5052
Args:
5153
num_points (int): Number of points to sample.
5254
key (KeyArray): JAX PRNG key.
55+
obs_stddev (float): (Optional) standard deviation of Gaussian distributed
56+
noise added to observations.
5357
5458
Returns:
5559
Dataset: Dataset of points sampled from the test function.
5660
"""
5761
X = self.search_space.sample(num_points=num_points, key=key)
58-
y = self.evaluate(X)
62+
gaussian_noise = tfp.distributions.Normal(
63+
jnp.zeros(num_points), obs_stddev * jnp.ones(num_points)
64+
)
65+
y = self.evaluate(X) + jnp.transpose(
66+
gaussian_noise.sample(sample_shape=[1], seed=key)
67+
)
5968
return Dataset(X=X, y=y)
6069

6170
def generate_test_points(
@@ -73,6 +82,9 @@ def generate_test_points(
7382
"""
7483
return self.search_space.sample(num_points=num_points, key=key)
7584

85+
def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N 1"]:
86+
return self.evaluate(x)
87+
7688
@abstractmethod
7789
def evaluate(self, x: Float[Array, "N D"]) -> Float[Array, "N 1"]:
7890
"""

gpjax/decision_making/test_functions/non_conjugate_functions.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,15 @@
1717

1818
import jax.numpy as jnp
1919
import jax.random as jr
20-
from jaxtyping import (
21-
Array,
22-
Float,
23-
Integer,
24-
)
2520

2621
from gpjax.dataset import Dataset
2722
from gpjax.decision_making.search_space import ContinuousSearchSpace
28-
from gpjax.typing import KeyArray
23+
from gpjax.typing import (
24+
Array,
25+
Float,
26+
Int,
27+
KeyArray,
28+
)
2929

3030

3131
@dataclass
@@ -74,7 +74,7 @@ def generate_test_points(
7474
return self.search_space.sample(num_points=num_points, key=key)
7575

7676
@abstractmethod
77-
def evaluate(self, x: Float[Array, "N 1"]) -> Integer[Array, "N 1"]:
77+
def evaluate(self, x: Float[Array, "N 1"]) -> Int[Array, "N 1"]:
7878
"""
7979
Evaluate the test function at a set of points. Function taken from
8080
https://docs.jaxgaussianprocesses.com/examples/poisson/#dataset.
@@ -83,7 +83,7 @@ def evaluate(self, x: Float[Array, "N 1"]) -> Integer[Array, "N 1"]:
8383
x (Float[Array, 'N 1']): Points to evaluate the test function at.
8484
8585
Returns:
86-
Integer[Array, 'N 1']: Values of the test function at the points.
86+
Int[Array, 'N 1']: Values of the test function at the points.
8787
"""
8888
key = jr.key(42)
8989
f = lambda x: 2.0 * jnp.sin(3 * x) + 0.5 * x

gpjax/decision_making/utility_functions/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
SinglePointUtilityFunction,
1919
UtilityFunction,
2020
)
21+
from gpjax.decision_making.utility_functions.expected_improvement import (
22+
ExpectedImprovement,
23+
)
2124
from gpjax.decision_making.utility_functions.probability_of_improvement import (
2225
ProbabilityOfImprovement,
2326
)
@@ -27,6 +30,7 @@
2730
"UtilityFunction",
2831
"AbstractUtilityFunctionBuilder",
2932
"AbstractSinglePointUtilityFunctionBuilder",
33+
"ExpectedImprovement",
3034
"SinglePointUtilityFunction",
3135
"ThompsonSampling",
3236
"ProbabilityOfImprovement",
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Copyright 2024 The JaxGaussianProcesses Contributors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
from dataclasses import dataclass
16+
from functools import partial
17+
18+
from beartype.typing import Mapping
19+
import jax.numpy as jnp
20+
import tensorflow_probability.substrates.jax as tfp
21+
22+
from gpjax.dataset import Dataset
23+
from gpjax.decision_making.utility_functions.base import (
24+
AbstractSinglePointUtilityFunctionBuilder,
25+
SinglePointUtilityFunction,
26+
)
27+
from gpjax.decision_making.utils import (
28+
OBJECTIVE,
29+
get_best_latent_observation_val,
30+
)
31+
from gpjax.gps import ConjugatePosterior
32+
from gpjax.typing import (
33+
Array,
34+
Float,
35+
KeyArray,
36+
)
37+
38+
39+
@dataclass
40+
class ExpectedImprovement(AbstractSinglePointUtilityFunctionBuilder):
41+
"""
42+
Expected Improvement acquisition function as introduced by [Močkus,
43+
1974](https://link.springer.com/chapter/10.1007/3-540-07165-2_55). The "best"
44+
incumbent value is defined as the lowest posterior mean value evaluated at the
45+
previously observed points. This enables the acquisition function to be utilised with noisy observations.
46+
"""
47+
48+
def build_utility_function(
49+
self,
50+
posteriors: Mapping[str, ConjugatePosterior],
51+
datasets: Mapping[str, Dataset],
52+
key: KeyArray,
53+
) -> SinglePointUtilityFunction:
54+
r"""
55+
Build the Expected Improvement acquisition function. This computes the expected
56+
improvement over the "best" of the previously observed points, utilising the
57+
posterior distribution of the surrogate model. For posterior distribution
58+
$`f(\cdot)`$, and best incumbent value $`\eta`$, this is defined
59+
as:
60+
```math
61+
\alpha_{\text{EI}}(\mathbf{x}) = \mathbb{E}\left[\max(0, \eta - f(\mathbf{x}))\right]
62+
```
63+
64+
Args:
65+
posteriors (Mapping[str, ConjugatePosterior]): Dictionary of posteriors to
66+
be used to form the utility function. One of the posteriors must correspond to the
67+
`OBJECTIVE` key, as we utilise the objective posterior to form the utility
68+
function.
69+
datasets (Mapping[str, Dataset]): Dictionary of datasets used to form the
70+
utility function. Keys in `datasets` should correspond to keys in
71+
`posteriors`. One of the datasets must correspond to the `OBJECTIVE` key.
72+
key (KeyArray): JAX PRNG key used for random number generation.
73+
74+
Returns:
75+
SinglePointUtilityFunction: The Expected Improvement acquisition function to
76+
be *maximised* in order to decide which point to query next.
77+
"""
78+
self.check_objective_present(posteriors, datasets)
79+
objective_posterior = posteriors[OBJECTIVE]
80+
objective_dataset = datasets[OBJECTIVE]
81+
82+
if not isinstance(objective_posterior, ConjugatePosterior):
83+
raise ValueError(
84+
"Objective posterior must be a ConjugatePosterior to compute the Expected Improvement."
85+
)
86+
87+
if (
88+
objective_dataset.X is None
89+
or objective_dataset.n == 0
90+
or objective_dataset.y is None
91+
):
92+
raise ValueError("Objective dataset must contain at least one item")
93+
94+
eta = get_best_latent_observation_val(objective_posterior, objective_dataset)
95+
return partial(
96+
_expected_improvement, objective_posterior, objective_dataset, eta
97+
)
98+
99+
100+
def _expected_improvement(
101+
objective_posterior: ConjugatePosterior,
102+
objective_dataset: Dataset,
103+
eta: Float[Array, ""],
104+
x: Float[Array, "N D"],
105+
) -> Float[Array, "N 1"]:
106+
latent_dist = objective_posterior(x, objective_dataset)
107+
mean = latent_dist.mean()
108+
var = latent_dist.variance()
109+
normal = tfp.distributions.Normal(mean, jnp.sqrt(var))
110+
return jnp.expand_dims(
111+
((eta - mean) * normal.cdf(eta) + var * normal.prob(eta)), -1
112+
)

gpjax/decision_making/utility_functions/probability_of_improvement.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@
2323
AbstractSinglePointUtilityFunctionBuilder,
2424
SinglePointUtilityFunction,
2525
)
26-
from gpjax.decision_making.utils import OBJECTIVE
26+
from gpjax.decision_making.utils import (
27+
OBJECTIVE,
28+
get_best_latent_observation_val,
29+
)
2730
from gpjax.gps import ConjugatePosterior
2831
from gpjax.typing import (
2932
Array,
@@ -107,14 +110,9 @@ def build_utility_function(
107110
)
108111

109112
def probability_of_improvement(x_test: Num[Array, "N D"]):
110-
# Computing the posterior mean for the training dataset
111-
# for computing the best_y value (as the minimum
112-
# posterior mean of the objective function)
113-
predictive_dist_for_training = objective_posterior.predict(
114-
objective_dataset.X, objective_dataset
113+
best_y = get_best_latent_observation_val(
114+
objective_posterior, objective_dataset
115115
)
116-
best_y = predictive_dist_for_training.mean().min()
117-
118116
predictive_dist = objective_posterior.predict(x_test, objective_dataset)
119117

120118
normal_dist = tfp.distributions.Normal(

gpjax/decision_making/utility_functions/thompson_sampling.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,7 @@
2222
SinglePointUtilityFunction,
2323
)
2424
from gpjax.decision_making.utils import OBJECTIVE
25-
from gpjax.gps import (
26-
ConjugatePosterior,
27-
NonConjugatePosterior,
28-
)
25+
from gpjax.gps import ConjugatePosterior
2926
from gpjax.typing import KeyArray
3027

3128

@@ -59,7 +56,7 @@ def __post_init__(self):
5956

6057
def build_utility_function(
6158
self,
62-
posteriors: Mapping[str, ConjugatePosterior | NonConjugatePosterior],
59+
posteriors: Mapping[str, ConjugatePosterior],
6360
datasets: Mapping[str, Dataset],
6461
key: KeyArray,
6562
) -> SinglePointUtilityFunction:
@@ -69,8 +66,8 @@ def build_utility_function(
6966
are *maximised*.
7067
7168
Args:
72-
posteriors (Mapping[str, AbstractPosterior]): Dictionary of posteriors to be
73-
used to form the utility function. One of the posteriors must correspond
69+
posteriors (Mapping[str, ConjugatePosterior]): Dictionary of posteriors to
70+
be used to form the utility function. One of the posteriors must correspond
7471
to the `OBJECTIVE` key, as we sample from the objective posterior to form
7572
the utility function.
7673
datasets (Mapping[str, Dataset]): Dictionary of datasets which may be used

0 commit comments

Comments
 (0)