
Commit c075a99

feat: interface for unconditional flow training (#1470)
* first commit
* rearranging I
* rearranging II
* adding unitary test and a few changes
* ...
1 parent 9c59a89 commit c075a99
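
The commit adds a MarginalTrainer that fits an unconditional density estimator to a set of samples. Below is a minimal usage sketch based on the interface added in this commit; the toy data, the hyperparameter values, and the variable names are illustrative assumptions, not part of the PR:

import torch

from sbi.inference.trainers.marginal.marginal_base import MarginalTrainer

# Toy unconditional training data: samples x ~ p(x) (illustrative only).
x = torch.randn(1000, 2)

# "MAF" is the default density-estimator string of the new trainer.
trainer = MarginalTrainer(density_estimator="MAF", device="cpu")
trainer.append_samples(x)

# Returns an UnconditionalDensityEstimator approximating p(x). Training stops
# early once the validation loss has not improved for `stop_after_epochs` epochs.
estimator = trainer.train(training_batch_size=200, stop_after_epochs=20)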

File tree: 10 files changed (+862, -15 lines)


sbi/inference/trainers/base.py

Lines changed: 1 addition & 1 deletion
@@ -341,7 +341,7 @@ def get_dataloaders(
         num_validation_examples = num_examples - num_training_examples
 
         if not resume_training:
-            # Seperate indicies for training and validation
+            # Separate indices for training and validation
             permuted_indices = torch.randperm(num_examples)
             self.train_indices, self.val_indices = (
                 permuted_indices[:num_training_examples],
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from sbi.inference.trainers.marginal.marginal_base import MarginalTrainer
Lines changed: 363 additions & 0 deletions
@@ -0,0 +1,363 @@
+# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
+# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
+
+import time
+from copy import deepcopy
+from datetime import datetime
+from pathlib import Path
+from typing import Callable, Optional, Tuple, Union
+
+import torch
+from torch import Tensor
+from torch.nn.utils.clip_grad import clip_grad_norm_
+from torch.optim.adam import Adam
+from torch.utils import data
+from torch.utils.data.sampler import SubsetRandomSampler
+from torch.utils.tensorboard.writer import SummaryWriter
+
+from sbi.neural_nets.estimators import UnconditionalDensityEstimator
+from sbi.neural_nets.estimators.shape_handling import (
+    reshape_to_batch_event,
+)
+from sbi.neural_nets.factory import marginal_nn
+from sbi.utils import check_estimator_arg, get_log_root
+from sbi.utils.torchutils import assert_all_finite, process_device
+
+
+class MarginalTrainer:
+    def __init__(
+        self,
+        density_estimator: Union[str, Callable] = "MAF",
+        device: str = "cpu",
+        summary_writer: Optional[SummaryWriter] = None,
+        show_progress_bars: bool = True,
+    ):
+        """Utility class for training a marginal (unconditional) density estimator."""
+
+        self._device = process_device(device)
+        self._neural_net = None
+
+        self._show_progress_bars = show_progress_bars
+        self._val_loss = float("Inf")
+
+        self._summary_writer = (
+            self._default_summary_writer() if summary_writer is None else summary_writer
+        )
+
+        # Logging during training (by SummaryWriter).
+        self._summary = dict(
+            epochs_trained=[],
+            best_validation_loss=[],
+            validation_loss=[],
+            training_loss=[],
+            epoch_durations_sec=[],
+        )
+
+        check_estimator_arg(density_estimator)
+        if isinstance(density_estimator, str):
+            self._build_neural_net = marginal_nn(model=density_estimator)
+        else:
+            self._build_neural_net = density_estimator
+
+    def get_dataloaders(
+        self,
+        training_batch_size: int = 200,
+        validation_fraction: float = 0.1,
+        dataloader_kwargs: Optional[dict] = None,
+    ) -> Tuple[data.DataLoader, data.DataLoader]:
+        x = self.get_samples()
+        dataset = data.TensorDataset(x)
+
+        # Get total number of training examples.
+        num_examples = x.size(0)
+        # Select random train and validation splits from the samples x.
+        num_training_examples = int((1 - validation_fraction) * num_examples)
+        num_validation_examples = num_examples - num_training_examples
+
+        # Separate indices for training and validation
+        permuted_indices = torch.randperm(num_examples)
+        self.train_indices, self.val_indices = (
+            permuted_indices[:num_training_examples],
+            permuted_indices[num_training_examples:],
+        )
+
+        train_loader_kwargs = {
+            "batch_size": min(training_batch_size, num_training_examples),
+            "drop_last": True,
+            "sampler": SubsetRandomSampler(self.train_indices.tolist()),
+        }
+        val_loader_kwargs = {
+            "batch_size": min(training_batch_size, num_validation_examples),
+            "shuffle": False,
+            "drop_last": True,
+            "sampler": SubsetRandomSampler(self.val_indices.tolist()),
+        }
+        if dataloader_kwargs is not None:
+            train_loader_kwargs = dict(train_loader_kwargs, **dataloader_kwargs)
+            val_loader_kwargs = dict(val_loader_kwargs, **dataloader_kwargs)
+
+        train_loader = data.DataLoader(dataset, **train_loader_kwargs)
+        val_loader = data.DataLoader(dataset, **val_loader_kwargs)
+
+        return train_loader, val_loader
+
+    def append_samples(self, x) -> "MarginalTrainer":
+        self._x = x
+        return self
+
+    def get_samples(self) -> Tensor:
+        return self._x
+
+    def loss(self, x: Tensor) -> Tensor:
+        """Return loss.
+
+        The loss is the negative log prob.
+
+        Returns:
+            Negative log prob.
+        """
+        if self._neural_net is None:
+            raise ValueError(
+                "Neural network has not been initialized. Please call `train` first."
+            )
+        else:
+            x = reshape_to_batch_event(x, event_shape=self._neural_net.input_shape)
+            loss = self._neural_net.loss(x)
+            assert_all_finite(loss, "loss")
+            return loss
+
+    def train(
+        self,
+        training_batch_size: int = 200,
+        learning_rate: float = 5e-4,
+        validation_fraction: float = 0.1,
+        stop_after_epochs: int = 20,
+        max_num_epochs: int = 2**31 - 1,
+        clip_max_norm: Optional[float] = 5.0,
+        dataloader_kwargs: Optional[dict] = None,
+    ) -> UnconditionalDensityEstimator:
+        r"""Return density estimator that approximates the distribution $p(x)$.
+
+        Args:
+            training_batch_size: Training batch size.
+            learning_rate: Learning rate for Adam optimizer.
+            validation_fraction: The fraction of data to use for validation.
+            stop_after_epochs: The number of epochs to wait for improvement on the
+                validation set before terminating training.
+            max_num_epochs: Maximum number of epochs to run. If reached, we stop
+                training even when the validation loss is still decreasing. Otherwise,
+                we train until validation loss increases (see also `stop_after_epochs`).
+            clip_max_norm: Value at which to clip the total gradient norm in order to
+                prevent exploding gradients. Use None for no clipping.
+            show_train_summary: Whether to print the number of epochs and validation
+                loss after the training.
+            dataloader_kwargs: Additional or updated kwargs to be passed to the training
+                and validation dataloaders (e.g., a collate_fn).
+
+        Returns:
+            Density estimator that approximates the distribution $p(x)$.
+        """
+
+        # Fake round setting, just for compatibility with NeuralInference.
+        self._round = 0
+
+        train_loader, val_loader = self.get_dataloaders(
+            training_batch_size,
+            validation_fraction,
+            dataloader_kwargs=dataloader_kwargs,
+        )
+
+        if self._neural_net is None:
+            # Get x to initialize NN
+            x = self.get_samples()
+            # Use only training data for building the neural net (z-scoring transforms)
+
+            self._neural_net = self._build_neural_net(
+                x[self.train_indices].to("cpu"),
+            )
+
+        self.optimizer = Adam(list(self._neural_net.parameters()), lr=learning_rate)
+        self.epoch, self._val_loss = 0, float("Inf")
+
+        while self.epoch <= max_num_epochs and not self._converged(
+            self.epoch, stop_after_epochs
+        ):
+            # Train for a single epoch.
+            self._neural_net.train()
+            train_loss_sum = 0
+            epoch_start_time = time.time()
+            for batch in train_loader:
+                self.optimizer.zero_grad()
+                # Get batches on current device.
+                x_batch = batch[0].to(self._device)
+
+                train_losses = self.loss(x_batch)
+                train_loss = torch.mean(train_losses)
+                train_loss_sum += train_losses.sum().item()
+
+                train_loss.backward()
+                if clip_max_norm is not None:
+                    clip_grad_norm_(
+                        self._neural_net.parameters(), max_norm=clip_max_norm
+                    )
+                self.optimizer.step()
+
+            self.epoch += 1
+
+            train_loss_average = train_loss_sum / (
+                len(train_loader) * train_loader.batch_size  # type: ignore
+            )
+            self._summary["training_loss"].append(train_loss_average)
+
+            # Calculate validation performance.
+            self._neural_net.eval()
+            val_loss_sum = 0
+
+            with torch.no_grad():
+                for batch in val_loader:
+                    x_batch = batch[0].to(self._device)
+                    # Evaluate the loss (negative log prob) on the validation batch.
+                    val_losses = self.loss(x_batch)
+                    val_loss_sum += val_losses.sum().item()
+
+            # Take mean over all validation samples.
+            self._val_loss = val_loss_sum / (
+                len(val_loader) * val_loader.batch_size  # type: ignore
+            )
+            # Log validation loss for every epoch.
+            self._summary["validation_loss"].append(self._val_loss)
+            self._summary["epoch_durations_sec"].append(time.time() - epoch_start_time)
+
+            self._maybe_show_progress(self._show_progress_bars, self.epoch)
+
+        # Update summary.
+        self._summary["epochs_trained"].append(self.epoch)
+        self._summary["best_validation_loss"].append(self._best_val_loss)
+
+        # Update tensorboard and summary dict.
+        self._summarize(round_=self._round)
+
+        # Avoid keeping the gradients in the resulting network, which can
+        # cause memory leakage when benchmarking.
+        self._neural_net.zero_grad(set_to_none=True)
+
+        return deepcopy(self._neural_net)
+
+    def _default_summary_writer(self) -> SummaryWriter:
+        """Return summary writer logging to a method-specific directory."""
+
+        method = self.__class__.__name__
+        logdir = Path(
+            get_log_root(), method, datetime.now().isoformat().replace(":", "_")
+        )
+        return SummaryWriter(logdir)
+
+    def _converged(self, epoch: int, stop_after_epochs: int) -> bool:
+        """Return whether the training converged yet and save best model state so far.
+
+        Checks for improvement in validation performance over previous epochs.
+
+        Args:
+            epoch: Current epoch in training.
+            stop_after_epochs: How many fruitless epochs to let pass before stopping.
+
+        Returns:
+            Whether the training has stopped improving, i.e. has converged.
+        """
+        converged = False
+
+        assert self._neural_net is not None
+        neural_net = self._neural_net
+
+        # (Re)-start the epoch count with the first epoch or any improvement.
+        if epoch == 0 or self._val_loss < self._best_val_loss:
+            self._best_val_loss = self._val_loss
+            self._epochs_since_last_improvement = 0
+            self._best_model_state_dict = deepcopy(neural_net.state_dict())
+        else:
+            self._epochs_since_last_improvement += 1
+
+        # If no validation improvement over many epochs, stop training.
+        if self._epochs_since_last_improvement > stop_after_epochs - 1:
+            neural_net.load_state_dict(self._best_model_state_dict)
+            converged = True
+
+        return converged
+
+    @staticmethod
+    def _maybe_show_progress(show: bool, epoch: int) -> None:
+        if show:
+            # end="\r" deletes the print statement when a new one appears.
+            # https://stackoverflow.com/questions/3419984/. `\r` in the beginning due
+            # to #330.
+            print("\r", f"Training neural network. Epochs trained: {epoch}", end="")
+
+    def _summarize(
+        self,
+        round_: int,
+    ) -> None:
+        """Update the summary_writer with statistics for a given round.
+
+        During training several performance statistics are added to the summary, e.g.,
+        using `self._summary['key'].append(value)`. This function writes these values
+        into the summary writer object.
+
+        Args:
+            round_: Index of the round.
+
+        Scalar tags:
+            - epochs_trained:
+                number of epochs trained
+            - best_validation_loss:
+                best validation loss (for each round).
+            - validation_loss:
+                validation loss for every epoch (for each round).
+            - training_loss:
+                training loss for every epoch (for each round).
+            - epoch_durations_sec:
+                epoch duration for every epoch (for each round).
+
+        """
+
+        # Add most recent training stats to summary writer.
+        self._summary_writer.add_scalar(
+            tag="epochs_trained",
+            scalar_value=self._summary["epochs_trained"][-1],
+            global_step=round_ + 1,
+        )
+
+        self._summary_writer.add_scalar(
+            tag="best_validation_loss",
+            scalar_value=self._summary["best_validation_loss"][-1],
+            global_step=round_ + 1,
+        )
+
+        # Add validation loss for every epoch.
+        # Offset with all previous epochs.
+        offset = (
+            torch.tensor(self._summary["epochs_trained"][:-1], dtype=torch.int)
+            .sum()
+            .item()
+        )
+        for i, vlp in enumerate(self._summary["validation_loss"][offset:]):
+            self._summary_writer.add_scalar(
+                tag="validation_loss",
+                scalar_value=vlp,
+                global_step=offset + i,
+            )
+
+        for i, tlp in enumerate(self._summary["training_loss"][offset:]):
+            self._summary_writer.add_scalar(
+                tag="training_loss",
+                scalar_value=tlp,
+                global_step=offset + i,
+            )
+
+        for i, eds in enumerate(self._summary["epoch_durations_sec"][offset:]):
+            self._summary_writer.add_scalar(
+                tag="epoch_durations_sec",
+                scalar_value=eds,
+                global_step=offset + i,
+            )
+
+        self._summary_writer.flush()
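
A hypothetical continuation of the usage sketch near the top of this page: after train() returns, the per-epoch statistics that the trainer appends to its _summary dict (and mirrors to TensorBoard via _summarize) can be read back directly; the keys below are the ones defined in __init__:

# `trainer` is the MarginalTrainer instance from the earlier sketch, after `train()`.
summary = trainer._summary
print("epochs trained:", summary["epochs_trained"][-1])
print("best validation loss:", summary["best_validation_loss"][-1])

# Per-epoch curves, e.g. for plotting training vs. validation loss.
train_curve = summary["training_loss"]
val_curve = summary["validation_loss"]
epoch_times = summary["epoch_durations_sec"]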

sbi/neural_nets/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -2,6 +2,7 @@
     classifier_nn,
     flowmatching_nn,
     likelihood_nn,
+    marginal_nn,
     posterior_nn,
     posterior_score_nn,
 )
@@ -24,4 +25,6 @@ def __getattr__(name):
         return posterior_nn
     elif name == "posterior_score_nn":
         return posterior_score_nn
+    elif name == "marginal_nn":
+        return marginal_nn
     raise AttributeError(f"Module '{__name__}' has no attribute '{name}'")
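
The __getattr__ hook above lazily re-exports marginal_nn from sbi.neural_nets. A minimal sketch of calling the factory directly, with the builder behaviour inferred from how MarginalTrainer uses it in this commit; the toy data and the "MAF" model string (the trainer's default) are illustrative assumptions:

import torch

from sbi.neural_nets import marginal_nn
from sbi.neural_nets.estimators.shape_handling import reshape_to_batch_event

# Toy samples used both to build and to evaluate the estimator.
x = torch.randn(1000, 2)

# marginal_nn(model=...) returns a builder; calling the builder on data constructs
# the unconditional density estimator (e.g., fitting z-scoring transforms on it).
build_fn = marginal_nn(model="MAF")
estimator = build_fn(x)

# Same loss evaluation as MarginalTrainer.loss: reshape to (batch, *event) and
# take the estimator's loss, i.e. the negative log prob per sample.
x_reshaped = reshape_to_batch_event(x, event_shape=estimator.input_shape)
loss = estimator.loss(x_reshaped)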
