
Commit 48129e4

Merge pull request #259 from alan-turing-institute/attention
Attention
2 parents 57ee46f + c4c1a09

10 files changed, +476 −10 lines

autoemulate/compare.py (+1 −1)

@@ -110,7 +110,7 @@ def setup(
         self.model_names = self.model_registry.get_model_names(models, is_core=True)
         self.models = _process_models(
             model_registry=self.model_registry,
-            models=list(self.model_names.keys()),
+            model_names=list(self.model_names.keys()),
             y=self.y,
             scale=scale,
             scaler=scaler,

autoemulate/emulators/__init__.py (+6)

@@ -1,5 +1,6 @@
 from ..model_registry import ModelRegistry
 from .conditional_neural_process import ConditionalNeuralProcess
+from .conditional_neural_process_attn import AttentiveConditionalNeuralProcess
 from .gaussian_process import GaussianProcess
 from .gaussian_process_mogp import GaussianProcessMOGP
 from .gaussian_process_mt import GaussianProcessMT
@@ -38,6 +39,11 @@


 # non-core models
+model_registry.register_model(
+    AttentiveConditionalNeuralProcess().model_name,
+    AttentiveConditionalNeuralProcess,
+    is_core=False,
+)
 model_registry.register_model(
     GaussianProcessMT().model_name, GaussianProcessMT, is_core=False
 )
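
As an aside, a minimal sketch of how this registration could be checked from user code; it assumes the module-level model_registry instance in this file is importable and that the attentive model registers under a non-core name (the exact name string is not shown in this diff):

from autoemulate.emulators import model_registry

all_models = model_registry.get_model_names()               # every registered model
core_models = model_registry.get_model_names(is_core=True)  # core models only
print(set(all_models) - set(core_models))                   # non-core set, should now include the attentive CNP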

autoemulate/emulators/conditional_neural_process.py (+3 −4)

@@ -16,6 +16,7 @@
 from torch import nn

 from autoemulate.emulators.neural_networks.cnp_module import CNPModule
+from autoemulate.emulators.neural_networks.cnp_module_attn import AttnCNPModule
 from autoemulate.emulators.neural_networks.datasets import cnp_collate_fn
 from autoemulate.emulators.neural_networks.datasets import CNPDataset
 from autoemulate.emulators.neural_networks.losses import CNPLoss
@@ -140,9 +141,6 @@ def __init__(
         self.activation = activation
         self.optimizer = optimizer
         self.normalize_y = normalize_y
-        if attention:
-            warnings.warn("Attention is not implemented yet, setting to False.")
-            attention = False
         self.attention = attention
         self.device = device
         self.random_state = random_state
@@ -181,8 +179,9 @@ def fit(self, X, y):
         if self.random_state is not None:
             set_random_seed(self.random_state)

+        module = CNPModule if not self.attention else AttnCNPModule
         self.model_ = NeuralNetRegressor(
-            CNPModule,
+            module,
             module__input_dim=self.input_dim_,
             module__output_dim=self.output_dim_,
             module__hidden_dim=self.hidden_dim,
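
With the warning removed, attention=True is now honoured and swaps CNPModule for AttnCNPModule. A minimal sketch on random data, assuming the sklearn-style fit/predict interface of the estimator:

import numpy as np
from autoemulate.emulators import ConditionalNeuralProcess

X = np.random.rand(200, 2)   # (n_samples, n_features)
y = np.random.rand(200, 1)   # (n_samples, n_outputs)

cnp = ConditionalNeuralProcess(attention=True, max_epochs=5, random_state=42)
cnp.fit(X, y)                # trains the attentive module instead of warning and falling back
y_pred = cnp.predict(X)      # assumed sklearn-style predict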
autoemulate/emulators/conditional_neural_process_attn.py (new file, +51)

@@ -0,0 +1,51 @@
+import torch
+import torch.nn as nn
+from sklearn.base import BaseEstimator
+from sklearn.base import RegressorMixin
+
+from autoemulate.emulators.conditional_neural_process import ConditionalNeuralProcess
+from autoemulate.utils import set_random_seed
+
+
+class AttentiveConditionalNeuralProcess(ConditionalNeuralProcess):
+    def __init__(
+        self,
+        # architecture
+        hidden_dim=64,
+        latent_dim=64,
+        hidden_layers_enc=3,
+        hidden_layers_dec=3,
+        # data per episode
+        min_context_points=3,
+        max_context_points=10,
+        n_episode=32,
+        # training
+        max_epochs=100,
+        lr=5e-3,
+        batch_size=16,
+        activation=nn.ReLU,
+        optimizer=torch.optim.AdamW,
+        normalize_y=True,
+        # misc
+        device="cpu",
+        random_state=None,
+        attention=True,
+    ):
+        super().__init__(
+            hidden_dim=hidden_dim,
+            latent_dim=latent_dim,
+            hidden_layers_enc=hidden_layers_enc,
+            hidden_layers_dec=hidden_layers_dec,
+            min_context_points=min_context_points,
+            max_context_points=max_context_points,
+            n_episode=n_episode,
+            max_epochs=max_epochs,
+            lr=lr,
+            batch_size=batch_size,
+            activation=activation,
+            optimizer=optimizer,
+            normalize_y=normalize_y,
+            device=device,
+            random_state=random_state,
+            attention=attention,
+        )
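
The subclass adds no new logic; it only flips the attention default to True so the attentive variant can be registered under its own model_name. A sketch of the equivalence, assuming the inherited constructor handles everything else:

from autoemulate.emulators import AttentiveConditionalNeuralProcess
from autoemulate.emulators import ConditionalNeuralProcess

attn_cnp = AttentiveConditionalNeuralProcess()          # attention defaults to True
plain_cnp = ConditionalNeuralProcess(attention=True)    # equivalent configuration
assert attn_cnp.attention and plain_cnp.attention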
autoemulate/emulators/neural_networks/cnp_module_attn.py (new file, +151)

@@ -0,0 +1,151 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from scipy.stats import loguniform
+from skopt.space import Categorical
+from skopt.space import Real
+
+
+class Encoder(nn.Module):
+    """
+    Deterministic encoder for conditional neural process model.
+    """
+
+    def __init__(
+        self,
+        input_dim,
+        output_dim,
+        hidden_dim,
+        latent_dim,
+        hidden_layers_enc,
+        activation,
+        context_mask=None,
+    ):
+        super().__init__()
+        layers = [nn.Linear(input_dim + output_dim, hidden_dim), activation()]
+        for _ in range(hidden_layers_enc):
+            layers.extend([nn.Linear(hidden_dim, hidden_dim), activation()])
+        layers.append(nn.Linear(hidden_dim, latent_dim))
+        self.net = nn.Sequential(*layers)
+
+        self.x_encoder = nn.Linear(input_dim, latent_dim)
+
+        self.crossattn = nn.MultiheadAttention(
+            embed_dim=latent_dim, num_heads=4, batch_first=True
+        )
+
+    def forward(self, x_context, y_context, x_target, context_mask=None):
+        """
+        Encode context
+
+        Parameters
+        ----------
+        x_context: (batch_size, n_context_points, input_dim)
+        y_context: (batch_size, n_context_points, output_dim)
+        context_mask: (batch_size, n_context_points)
+
+        Returns
+        -------
+        r: (batch_size, n_points, latent_dim)
+        """
+        # context self attention
+        x = torch.cat([x_context, y_context], dim=-1)
+        r = self.net(x)
+        # q, k, v
+        x_target_enc = self.x_encoder(x_target)
+        x_context_enc = self.x_encoder(x_context)
+        if context_mask is not None:
+            r, _ = self.crossattn(
+                x_target_enc,
+                x_context_enc,
+                r,
+                need_weights=False,
+                key_padding_mask=context_mask,
+            )
+        else:
+            r, _ = self.crossattn(x_target_enc, x_context_enc, r, need_weights=False)
+        return r
+
+
+class Decoder(nn.Module):
+    def __init__(
+        self,
+        input_dim,
+        latent_dim,
+        hidden_dim,
+        output_dim,
+        hidden_layers_dec,
+        activation,
+    ):
+        super().__init__()
+        layers = [nn.Linear(latent_dim + input_dim, hidden_dim), activation()]
+        for _ in range(hidden_layers_dec):
+            layers.extend([nn.Linear(hidden_dim, hidden_dim), activation()])
+        self.net = nn.Sequential(*layers)
+        self.mean_head = nn.Linear(hidden_dim, output_dim)
+        self.logvar_head = nn.Linear(hidden_dim, output_dim)
+
+    def forward(self, r, x_target):
+        """
+        Decode using representation r and target points x_target
+
+        Parameters
+        ----------
+        r: (batch_size, n_points, latent_dim)
+        x_target: (batch_size, n_points, input_dim)
+
+        Returns
+        -------
+        mean: (batch_size, n_points, output_dim)
+        logvar: (batch_size, n_points, output_dim)
+        """
+        x = torch.cat([r, x_target], dim=-1)
+        hidden = self.net(x)
+        mean = self.mean_head(hidden)
+        logvar = self.logvar_head(hidden)
+
+        return mean, logvar
+
+
+class AttnCNPModule(nn.Module):
+    def __init__(
+        self,
+        input_dim,
+        output_dim,
+        hidden_dim,
+        latent_dim,
+        hidden_layers_enc,
+        hidden_layers_dec,
+        activation=nn.ReLU,
+    ):
+        super().__init__()
+        self.encoder = Encoder(
+            input_dim, output_dim, hidden_dim, latent_dim, hidden_layers_enc, activation
+        )
+        self.decoder = Decoder(
+            input_dim, latent_dim, hidden_dim, output_dim, hidden_layers_dec, activation
+        )
+
+    def forward(self, X_context, y_context, X_target=None, context_mask=None):
+        """
+
+        Parameters
+        ----------
+        X_context: (batch_size, n_context_points, input_dim)
+        y_context: (batch_size, n_context_points, output_dim)
+        X_target: (batch_size, n_target_points, input_dim)
+        context_mask: (batch_size, n_context_points), currently unused,
+            as we pad with 0's and don't have attention, layernorm yet.
+
+        Returns
+        -------
+        mean: (batch_size, n_points, output_dim)
+        logvar: (batch_size, n_points, output_dim)
+        """
+        # inverse context_mask
+        if context_mask is not None:
+            context_mask = ~context_mask
+        r = self.encoder(X_context, y_context, X_target)
+        mean, logvar = self.decoder(r, X_target)
+        return mean, logvar
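
A standalone shape check for the new module, using random tensors; the dimensions follow the docstrings above (this is an illustrative sketch, not part of the commit):

import torch
import torch.nn as nn
from autoemulate.emulators.neural_networks.cnp_module_attn import AttnCNPModule

module = AttnCNPModule(
    input_dim=2, output_dim=1, hidden_dim=64, latent_dim=64,
    hidden_layers_enc=3, hidden_layers_dec=3, activation=nn.ReLU,
)
X_context = torch.rand(32, 10, 2)  # (batch_size, n_context_points, input_dim)
y_context = torch.rand(32, 10, 1)  # (batch_size, n_context_points, output_dim)
X_target = torch.rand(32, 20, 2)   # (batch_size, n_target_points, input_dim)

mean, logvar = module(X_context, y_context, X_target)
print(mean.shape, logvar.shape)    # both torch.Size([32, 20, 1])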

autoemulate/model_processing.py (+5 −3)

@@ -65,14 +65,16 @@ def _wrap_models_in_pipeline(models, scale, scaler, reduce_dim, dim_reducer):
     return models_piped


-def _process_models(model_registry, models, y, scale, scaler, reduce_dim, dim_reducer):
+def _process_models(
+    model_registry, model_names, y, scale, scaler, reduce_dim, dim_reducer
+):
     """Get and process models.

     Parameters
     ----------
     model_registry : ModelRegistry
         An instance of the ModelRegistry class.
-    models : list
+    model_names : list
         List of model names.
     y : array-like, shape (n_samples, n_outputs)
         Simulation output.
@@ -86,7 +88,7 @@ def _process_models(model_registry, models, y, scale, scaler, reduce_dim, dim_re
     models : list
         List of model instances.
     """
-    models = model_registry.get_models(models)
+    models = model_registry.get_models(model_names)
     models_multi = _turn_models_into_multioutput(models, y)
     models_scaled = _wrap_models_in_pipeline(
         models_multi, scale, scaler, reduce_dim, dim_reducer

autoemulate/model_registry.py (+2 −2)

@@ -23,7 +23,7 @@ def get_model_names(self, models=None, is_core=False):
         models : str or list of str
             The name(s) of the model(s) to get long and short names for.
         is_core : bool
-            Whether to return only core model names.
+            Whether to return only core model names in case `models` is None.

         Returns
         -------
@@ -61,7 +61,7 @@ def get_model_names(self, models=None, is_core=False):
                 k: v for k, v in model_names.items() if k in models or v in models
             }

-        if is_core:
+        if models is None and is_core:
             model_names = {
                 k: v for k, v in model_names.items() if k in self.core_model_names
             }
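
The second hunk changes when the core filter applies: explicitly requested models are now returned even if they are not core models, and is_core only restricts the result when no models are requested (as in compare.setup, which passes is_core=True). A hedged sketch; the registered name string below is an assumption:

from autoemulate.emulators import model_registry

# No explicit request: is_core=True still restricts the result to core models.
core_only = model_registry.get_model_names(models=None, is_core=True)

# Explicit request: a non-core model is no longer filtered out.
requested = model_registry.get_model_names(
    models=["AttentiveConditionalNeuralProcess"],  # assumed name string
    is_core=True,
)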
