Commit f230253

Merge pull request #593 from alan-turing-institute/add-gaussian-mlp
Add Gaussian MLP to experimental
2 parents 5ae22c6 + efd3223 commit f230253

Lines changed: 200 additions & 0 deletions
@@ -0,0 +1,200 @@
import torch
from autoemulate.core.device import TorchDeviceMixin
from autoemulate.core.types import DeviceLike, GaussianLike, TensorLike
from autoemulate.data.utils import set_random_seed
from autoemulate.emulators.base import GaussianEmulator
from autoemulate.emulators.nn.mlp import MLP
from autoemulate.transforms.standardize import StandardizeTransform
from autoemulate.transforms.utils import make_positive_definite
from torch import nn
from torch.optim.lr_scheduler import LRScheduler


class GaussianMLP(GaussianEmulator, MLP):
    """Multi-Layer Perceptron (MLP) emulator with Gaussian outputs."""

    def __init__(
        self,
        x: TensorLike,
        y: TensorLike,
        standardize_x: bool = True,
        standardize_y: bool = True,
        activation_cls: type[nn.Module] = nn.ReLU,
        epochs: int = 500,
        batch_size: int = 16,
        layer_dims: list[int] | None = None,
        weight_init: str = "default",
        scale: float = 1.0,
        full_covariance: bool = False,
        bias_init: str = "default",
        dropout_prob: float | None = None,
        lr: float = 1e-2,
        random_seed: int | None = None,
        device: DeviceLike | None = None,
        scheduler_cls: type[LRScheduler] | None = None,
        scheduler_params: dict | None = None,
    ):
        """
        Multi-Layer Perceptron (MLP) emulator with Gaussian outputs.

        GaussianMLP extends the standard MLP to output Gaussian distributions with
        either diagonal or full covariance matrices, allowing for uncertainty
        quantification in predictions.

        Parameters
        ----------
        x: TensorLike
            Input features.
        y: TensorLike
            Target values.
        standardize_x: bool
            Whether to standardize the input features. Defaults to True.
        standardize_y: bool
            Whether to standardize the target values. Defaults to True.
        epochs: int
            Number of training epochs. Defaults to 500.
        batch_size: int
            Batch size for training. Defaults to 16.
        activation_cls: type[nn.Module]
            Activation function to use in the hidden layers. Defaults to `nn.ReLU`.
        layer_dims: list[int] | None
            Dimensions of the hidden layers. If None, defaults to
            [32 * num_params, 16 * num_params], where num_params is the number of
            Gaussian output parameters. Defaults to None.
        weight_init: str
            Weight initialization method. Options are "default", "normal", "uniform",
            "zeros", "ones", "xavier_uniform", "xavier_normal", "kaiming_uniform",
            "kaiming_normal". Defaults to "default".
        scale: float
            Scale parameter for weight initialization methods. Used as:
            - gain for Xavier methods
            - std for normal distribution
            - bound for uniform distribution (range: [-scale, scale])
            - ignored for Kaiming methods (uses optimal scaling)
            Defaults to 1.0.
        full_covariance: bool
            If True, the emulator predicts full covariance matrices for the outputs.
            If False, only the variance is predicted. Defaults to False.
        bias_init: str
            Bias initialization method. Options: "zeros", "default":
            - "zeros" initializes biases to zero
            - "default" uses PyTorch's default uniform initialization
        dropout_prob: float | None
            Dropout probability for regularization. If None, no dropout is applied.
            Defaults to None.
        lr: float
            Learning rate for the optimizer. Defaults to 1e-2.
        random_seed: int | None
            Random seed for reproducibility. If None, no seed is set. Defaults to None.
        device: DeviceLike | None
            Device to run the model on (e.g., "cpu", "cuda", "mps"). Defaults to None.
        scheduler_cls: type[LRScheduler] | None
            Learning rate scheduler class. If None, no scheduler is used. Defaults to
            None.
        scheduler_params: dict | None
            Additional keyword arguments passed to the scheduler.

        Raises
        ------
        ValueError
            If `x` or `y` is not a 2D matrix.
        """
        TorchDeviceMixin.__init__(self, device=device)
        nn.Module.__init__(self)

        if random_seed is not None:
            set_random_seed(seed=random_seed)

        # Ensure x and y are tensors with correct dimensions
        x, y = self._convert_to_tensors(x, y)

        # Construct the MLP layers
        # Total params required for last layer
        num_params = (
            y.shape[1] + (y.shape[1] * (y.shape[1] + 1)) // 2  # mean + tril covariance
            if full_covariance
            else 2 * y.shape[1]  # mean + variance (diag covariance)
        )
        layer_dims = (
            [x.shape[1], *layer_dims]
            if layer_dims
            else [x.shape[1], 32 * num_params, 16 * num_params]
        )
        layers = []
        for idx, dim in enumerate(layer_dims[1:]):
            layers.append(nn.Linear(layer_dims[idx], dim, device=self.device))
            layers.append(activation_cls())
            if dropout_prob is not None:
                layers.append(nn.Dropout(p=dropout_prob))

        # Add final layer without activation
        layers.append(nn.Linear(layer_dims[-1], num_params, device=self.device))
        self.nn = nn.Sequential(*layers)

        # Finalize initialization
        self._initialize_weights(weight_init, scale, bias_init)
        self.x_transform = StandardizeTransform() if standardize_x else None
        self.y_transform = StandardizeTransform() if standardize_y else None
        self.epochs = epochs
        self.lr = lr
        self.num_tasks = y.shape[1]
        self.batch_size = batch_size
        self.full_covariance = full_covariance
        self.optimizer = self.optimizer_cls(self.nn.parameters(), lr=lr)  # type: ignore  # noqa: PGH003
        self.scheduler_cls = scheduler_cls
        self.scheduler_params = scheduler_params or {}
        self.scheduler_setup(self.scheduler_params)
        self.to(device)

    def forward(self, x):
        """Forward pass for the Gaussian MLP."""
        y = self.nn(x)
        mean = y[..., : self.num_tasks]

        if self.full_covariance:
            # Use Cholesky decomposition to guarantee a PSD covariance matrix
            num_chol_params = (self.num_tasks * (self.num_tasks + 1)) // 2
            chol_params = y[..., self.num_tasks : self.num_tasks + num_chol_params]

            # Assign params to matrix
            scale_tril = torch.zeros(
                *y.shape[:-1], self.num_tasks, self.num_tasks, device=y.device
            )
            tril_indices = torch.tril_indices(
                self.num_tasks, self.num_tasks, device=y.device
            )
            scale_tril[..., tril_indices[0], tril_indices[1]] = chol_params

            # Ensure positive variance
            diag_idxs = torch.arange(self.num_tasks)
            diag = (
                torch.nn.functional.softplus(scale_tril[..., diag_idxs, diag_idxs])
                + 1e-6
            )
            scale_tril[..., diag_idxs, diag_idxs] = diag

            covariance_matrix = scale_tril @ scale_tril.transpose(-1, -2)

            # TODO: for large covariance matrices, numerical instability remains
            return GaussianLike(mean, make_positive_definite(covariance_matrix))

        # Diagonal covariance case
        return GaussianLike(
            mean,
            torch.diag_embed(
                torch.nn.functional.softplus(y[..., self.num_tasks :]) + 1e-6
            ),
        )

    def _predict(self, x: TensorLike, with_grad: bool) -> GaussianLike:
        """Predict method that returns a GaussianLike distribution.

        The implementation mirrors the PyTorchBackend base class but is repeated
        here to satisfy the type signature.
        """
        self.eval()
        with torch.set_grad_enabled(with_grad):
            return self(x)

    def loss_func(self, y_pred, y_true):
        """Negative log likelihood loss function."""
        return -y_pred.log_prob(y_true).mean()
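
Note on the parameterization used in forward: in the full-covariance case the final linear layer emits num_tasks mean values plus num_tasks * (num_tasks + 1) / 2 values that are scattered into a lower-triangular factor, so for three outputs the layer has 3 + 6 = 9 units versus 2 * 3 = 6 in the diagonal case. The standalone sketch below (plain PyTorch, illustrative only, not part of this commit) reproduces that construction to show why applying softplus to the diagonal and forming L @ L^T yields a symmetric, positive-definite covariance matrix.

import torch

d = 3                                   # number of outputs (tasks)
num_chol_params = d * (d + 1) // 2      # 6 lower-triangular entries for d = 3

raw = torch.randn(num_chol_params)      # stand-in for the network's covariance outputs

# Scatter the raw parameters into a lower-triangular matrix
scale_tril = torch.zeros(d, d)
rows, cols = torch.tril_indices(d, d)
scale_tril[rows, cols] = raw

# Softplus keeps the diagonal (and hence the variances) strictly positive
idx = torch.arange(d)
scale_tril[idx, idx] = torch.nn.functional.softplus(scale_tril[idx, idx]) + 1e-6

cov = scale_tril @ scale_tril.T         # L @ L^T is symmetric positive definite
print(torch.linalg.eigvalsh(cov))       # all eigenvalues are positive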
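
A minimal usage sketch, also not part of the commit: the import path below is assumed (the file's location is not shown in this diff), training is assumed to be handled by the base class's fit loop, and GaussianLike is assumed to expose the usual torch.distributions mean / covariance_matrix attributes.

import torch

# Hypothetical import path -- not shown in this diff.
from autoemulate.emulators.nn.gaussian_mlp import GaussianMLP

x = torch.randn(200, 3)  # 200 samples, 3 input features
y = torch.randn(200, 2)  # 2 output dimensions

# Diagonal covariance is the default; full_covariance=True predicts full matrices.
emulator = GaussianMLP(x, y, full_covariance=True, random_seed=0)

# In practice the emulator would be trained first via the base class's
# training loop (not shown in this file) before predictions are made.
dist = emulator._predict(x[:5], with_grad=False)
print(dist.mean.shape)               # expected: torch.Size([5, 2])
print(dist.covariance_matrix.shape)  # expected: torch.Size([5, 2, 2])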
