
Commit 0f25b70

Merge pull request #870 from alan-turing-institute/868-beta-model
Add Zero-One Inflated Beta MLP (#868)
2 parents: f230253 + 84fd9e7

2 files changed: +251 -1 lines

autoemulate/emulators/nn/mlp.py

Lines changed: 6 additions & 1 deletion
@@ -33,6 +33,7 @@ def __init__(
         bias_init: str = "default",
         dropout_prob: float | None = None,
         lr: float = 1e-2,
+        params_size: int = 1,
         random_seed: int | None = None,
         device: DeviceLike | None = None,
         scheduler_cls: type[LRScheduler] | None = None,
@@ -75,6 +76,8 @@ def __init__(
             Defaults to None.
         lr: float
             Learning rate for the optimizer. Defaults to 1e-2.
+        params_size: int
+            Number of parameters to predict per output dimension. Defaults to 1.
         random_seed: int | None
             Random seed for reproducibility. If None, no seed is set. Defaults to None.
         device: DeviceLike | None
@@ -115,7 +118,9 @@ def __init__(
 
         # Add final layer without activation
         num_tasks = y.shape[1]
-        layers.append(nn.Linear(self.layer_dims[-1], num_tasks, device=self.device))
+        layers.append(
+            nn.Linear(self.layer_dims[-1], num_tasks * params_size, device=self.device)
+        )
         self.nn = nn.Sequential(*layers)
 
         # Finalize initialization
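
The only change to the base MLP is the new `params_size` argument, which lets the final linear layer emit several distribution parameters per output dimension (`num_tasks * params_size` units in total); the new file below passes `params_size=5` for the five Zero-One Inflated Beta parameters. A minimal sketch of the resulting head shape, written independently of the autoemulate classes (the layer sizes and the reshape are illustrative assumptions, not part of the commit):

```python
import torch
from torch import nn

# Illustrative sizes only: 4 input features, 3 output tasks, 5 parameters per task.
num_tasks, params_size = 3, 5
hidden = nn.Sequential(nn.Linear(4, 16), nn.ReLU())
head = nn.Linear(16, num_tasks * params_size)   # mirrors num_tasks * params_size above

x = torch.randn(8, 4)                           # batch of 8 inputs
out = head(hidden(x))                           # shape [8, 15]
params = out.view(-1, num_tasks, params_size)   # one block of 5 parameters per task
print(out.shape, params.shape)                  # torch.Size([8, 15]) torch.Size([8, 3, 5])
```
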
Lines changed: 245 additions & 0 deletions
@@ -0,0 +1,245 @@
import torch
import torch.nn.functional as F
from autoemulate.core.types import DeviceLike, TensorLike
from autoemulate.emulators.nn.mlp import MLP
from torch import nn
from torch.optim.lr_scheduler import LRScheduler


class ZeroOneInflatedBeta(torch.distributions.Distribution):
    """ZeroOneInflatedBeta."""

    arg_constraints = {  # type: ignore # noqa: PGH003, RUF012
        "pi0": torch.distributions.constraints.unit_interval,  # type: ignore # noqa: PGH003
        "pi1": torch.distributions.constraints.unit_interval,  # type: ignore # noqa: PGH003
        "concentration1": torch.distributions.constraints.positive,  # type: ignore # noqa: PGH003
        "concentration0": torch.distributions.constraints.positive,  # type: ignore # noqa: PGH003
    }
    support = torch.distributions.constraints.unit_interval  # type: ignore # noqa: PGH003

    def __init__(self, pi0, pi1, concentration1, concentration0, validate_args=None):
        self.pi0 = pi0
        self.pi1 = pi1
        self.concentration1 = concentration1
        self.concentration0 = concentration0
        self.beta = torch.distributions.Beta(concentration1, concentration0)

        # Ensure pi0 + pi1 <= 1 elementwise
        if ((self.pi0 + self.pi1) > torch.ones_like(self.pi0)).any():
            msg = "pi0 + pi1 must be <= 1"
            raise ValueError(msg)

        super().__init__(validate_args=validate_args)

    def log_prob(self, value):
        """Log prob."""
        EPS = 1e-12
        # Ensure value can broadcast with parameters. If value is 1D [N], make it [N, 1]
        # to align with [N, num_tasks].
        value_in = value
        squeeze_back = False
        while value_in.dim() < self.pi0.dim():
            value_in = value_in.unsqueeze(-1)
            squeeze_back = True

        # Clamp continuous values away from boundaries for Beta support
        v = value_in.clamp(EPS, 1 - EPS)

        # Broadcast pi0, pi1 to value shape
        pi0_b = self.pi0.expand_as(value_in)
        pi1_b = self.pi1.expand_as(value_in)

        # Mixture log probs
        logp0 = torch.log(pi0_b + EPS)
        logp1 = torch.log(pi1_b + EPS)
        mix_log = torch.log(1 - pi0_b - pi1_b + EPS)
        beta_lp = self.beta.log_prob(v)

        cont = mix_log + beta_lp
        logp = torch.where(
            value_in == 0, logp0, torch.where(value_in == 1, logp1, cont)
        )

        if squeeze_back and logp.shape[-1] == 1:
            logp = logp.squeeze(-1)
        return logp

    @property
    def mean(self):
        """Mixture mean: pi1*1 + (1-pi0-pi1)*beta.mean."""
        p0 = self.pi0
        p1 = self.pi1
        p_cont = 1 - p0 - p1
        return p1 + p_cont * self.beta.mean

    @property
    def variance(self):
        """Mixture variance computed from mixture second moment."""
        p0 = self.pi0
        p1 = self.pi1
        p_cont = 1 - p0 - p1
        # E[X] for mixture
        mean = p1 + p_cont * self.beta.mean
        # E[X^2] for Beta is var + mean^2
        beta_second = self.beta.variance + self.beta.mean**2
        second_moment = p1 + p_cont * beta_second
        return second_moment - mean**2

    def sample(self, sample_shape=None):
        """Sample."""
        # Sample from categorical: {0, 1, Beta}
        if sample_shape is None:
            sample_shape = torch.Size()
        ones = torch.ones_like(self.pi0)
        probs = torch.stack([self.pi0, self.pi1, ones - self.pi0 - self.pi1], dim=-1)
        cat = torch.distributions.Categorical(probs=probs)
        choice = cat.sample(sample_shape)

        beta_samples = self.beta.sample(sample_shape)

        # Assign values based on choice
        return torch.where(
            choice == 0,
            torch.zeros_like(beta_samples),
            torch.where(choice == 1, torch.ones_like(beta_samples), beta_samples),
        )
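
For reference, the `mean` and `variance` properties above are the standard moments of this three-component mixture (point masses at 0 and 1 plus a Beta component). Writing \(\alpha\) for `concentration1`, \(\beta\) for `concentration0`, and \(\pi_c = 1 - \pi_0 - \pi_1\) for the weight of the continuous part:

```latex
% Moments of the zero-one inflated Beta mixture.
\begin{aligned}
\mathbb{E}[X]   &= \pi_1 \cdot 1 + \pi_c \, \frac{\alpha}{\alpha + \beta},\\
\mathbb{E}[X^2] &= \pi_1 + \pi_c \left(\operatorname{Var}_{\text{Beta}} + \left(\frac{\alpha}{\alpha + \beta}\right)^{2}\right),\\
\operatorname{Var}[X] &= \mathbb{E}[X^2] - \mathbb{E}[X]^2,
\end{aligned}
```

which is exactly the second-moment computation used in `variance`.
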

class ZOIBMLP(MLP):
    """Zero-One Inflated Beta distribution Multi-Layer Perceptron (MLP) emulator."""

    supports_uq: bool = True

    def __init__(
        self,
        x: TensorLike,
        y: TensorLike,
        standardize_x: bool = True,
        activation_cls: type[nn.Module] = nn.ReLU,
        loss_fn_cls: type[nn.Module] = nn.MSELoss,
        epochs: int = 100,
        batch_size: int = 16,
        layer_dims: list[int] | None = None,
        weight_init: str = "default",
        scale: float = 1.0,
        bias_init: str = "default",
        dropout_prob: float | None = None,
        lr: float = 1e-2,
        random_seed: int | None = None,
        device: DeviceLike | None = None,
        scheduler_cls: type[LRScheduler] | None = None,
        scheduler_params: dict | None = None,
    ):
        """
        Zero-One Inflated Beta Distribution Multi-Layer Perceptron (MLP) emulator.

        Parameters
        ----------
        x: TensorLike
            Input features.
        y: TensorLike
            Target values.
        activation_cls: type[nn.Module]
            Activation function to use in the hidden layers. Defaults to `nn.ReLU`.
        layer_dims: list[int] | None
            Dimensions of the hidden layers. If None, defaults to [32, 16].
            Defaults to None.
        weight_init: str
            Weight initialization method. Options are "default", "normal", "uniform",
            "zeros", "ones", "xavier_uniform", "xavier_normal", "kaiming_uniform",
            "kaiming_normal". Defaults to "default".
        scale: float
            Scale parameter for weight initialization methods. Used as:
            - gain for Xavier methods
            - std for normal distribution
            - bound for uniform distribution (range: [-scale, scale])
            - ignored for Kaiming methods (uses optimal scaling)
            Defaults to 1.0.
        bias_init: str
            Bias initialization method. Options: "zeros", "default":
            - "zeros" initializes biases to zero
            - "default" uses PyTorch's default uniform initialization
        dropout_prob: float | None
            Dropout probability for regularization. If None, no dropout is applied.
            Defaults to None.
        lr: float
            Learning rate for the optimizer. Defaults to 1e-2.
        random_seed: int | None
            Random seed for reproducibility. If None, no seed is set. Defaults to None.
        device: DeviceLike | None
            Device to run the model on (e.g., "cpu", "cuda", "mps"). Defaults to None.
        scheduler_cls: type[LRScheduler] | None
            Learning rate scheduler class. If None, no scheduler is used. Defaults to
            None.
        scheduler_params: dict | None
            Additional keyword arguments related to the scheduler.

        Raises
        ------
        ValueError
            If the input dimensions of `x` and `y` are not matrices.
        """
        MLP.__init__(
            self,
            x,
            y,
            standardize_x,
            False,  # Don't standardize y for ZOIB
            activation_cls,
            loss_fn_cls,
            epochs,
            batch_size,
            layer_dims,
            weight_init,
            scale,
            bias_init,
            dropout_prob,
            lr,
            5,  # params_size=5 for the Zero-One Inflated Beta distribution
            random_seed,
            device,
            scheduler_cls,
            scheduler_params,
        )

    def loss_func(self, y_pred, y_true):  # noqa: D102
        return -y_pred.log_prob(y_true).mean()

    def forward(self, x: TensorLike) -> ZeroOneInflatedBeta:
        """Forward pass for the MLP."""
        EPS = 1e-6
        output = self.nn(x)
        probs = F.softmax(output[..., 2:5], dim=-1)
        return ZeroOneInflatedBeta(
            pi0=probs[..., :1],
            pi1=probs[..., 1:2],
            concentration0=F.softplus(output[..., :1]) + EPS,
            concentration1=F.softplus(output[..., 1:2]) + EPS,
        )

    def predict_mean_and_variance(
        self,
        x: TensorLike,
        with_grad: bool = False,
        n_samples: int = 1000,  # noqa: ARG002
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Predict the mean and variance of the output for given input.

        Parameters
        ----------
        x: TensorLike
            Input features as numpy array or PyTorch tensor.

        Returns
        -------
        mean: torch.Tensor
            Predicted mean values.
        variance: torch.Tensor
            Predicted variance values.
        """
        self.eval()  # Set model to evaluation mode
        with torch.set_grad_enabled(with_grad):
            beta_dist = self.predict(x)
            assert isinstance(beta_dist, ZeroOneInflatedBeta)
            return beta_dist.mean, beta_dist.variance
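
A small self-contained sanity check of the distribution class (the module path of the new file is not shown in this diff, so the snippet assumes `ZeroOneInflatedBeta` as defined above is in scope; the parameter values are arbitrary illustrations):

```python
import torch

pi0 = torch.tensor([0.10])   # P(X = 0)
pi1 = torch.tensor([0.20])   # P(X = 1)
alpha = torch.tensor([2.0])  # concentration1
beta = torch.tensor([5.0])   # concentration0

dist = ZeroOneInflatedBeta(pi0=pi0, pi1=pi1, concentration1=alpha, concentration0=beta)

# Log-probabilities at the point masses and at an interior point
print(dist.log_prob(torch.tensor([0.0])))  # ~log(0.10)
print(dist.log_prob(torch.tensor([1.0])))  # ~log(0.20)
print(dist.log_prob(torch.tensor([0.3])))  # log(0.70) + Beta(2, 5).log_prob(0.3)

# Closed-form moments vs. Monte Carlo estimates
samples = dist.sample(torch.Size([100_000]))   # draws from {0, 1, Beta}
print(dist.mean, samples.mean(dim=0))          # mean is 0.2 + 0.7 * 2/7 = 0.4
print(dist.variance, samples.var(dim=0))
```

`ZOIBMLP` ties this together: `forward()` maps the five network outputs per input to (`pi0`, `pi1`, `concentration0`, `concentration1`) via a softmax over the last three columns and a softplus over the first two, `loss_func` is the negative mean log-likelihood, and `predict_mean_and_variance` returns the closed-form mixture moments of the predicted distribution.
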
