
Commit 1888205

Authored by janfb, psteinb, and Camille Touron
refactoring noise_schedule and time schedule into base class (#1736)
* refactoring noise_schedule and time schedule into base class
  - created noise_schedule method to be overridden by derived classes
  - created times_schedules method to be overridden by derived classes
  - created test on times_schedules
  - improved docstrings
* added noise schedule test
* implemented beta schedule for variance-preserving estimators
  - added tests too
  - inspired by https://arxiv.org/abs/2206.00364
* code cosmetics triggered by ruff
* cloned VPScoreEstimator to yield an improved version; in addition:
  - added improved version to benchmarks (for later comparison)
  - created new class ImprovedVPScoreEstimator
* more realistic bounds for unit test
* typo fix and refactoring to understand how the VE estimator is implemented
* fixed wrong setup of pmean and pstd
* code reformatting
* use the time schedule for computing the validation scores
* propagate name change
* fix unit tests to respect new schedules
* comply with formatting
* attempted to implement EDM-like diffusion
  - without touching the forward function of ConditionalScoreEstimator
  - benchmarks show that this leads to very long training time without any performance improvements
* removed "improved" denoising network
* consolidated tests
* removed occurrences of vp++
* removed all mentions of edm
* ruff fixes
* WIP: use time schedule in loss function, address device issues
* call solve_schedule in validation step
* add solve_schedule method, call train_schedule in loss
* call the solve schedule during sampling with SDE
* add a solve_schedule function in the conditional vf estimator class to unify training in vftrainer class
* make the solve schedule deterministic
* corrections on solve schedule
* WIP: create solve schedule in base class
* modify arguments of solve schedule
* change train_schedule + docstring fixes + device handling
* include validation times nugget to avoid instabilities during training
* change the number of simulations for the ve option
* change device in solve schedule
* add noise schedule in VE subclass
* reshape noise schedule output in VE class
* add tests on train and solve schedule shapes, devices, bounds
* formatting and changing tests

---------

Co-authored-by: Peter Steinbach <p.steinbach@hzdr.de>
Co-authored-by: Camille Touron <ctouron@ptb-07008323.grenoble.inria.fr>
1 parent 937efc2 · commit 1888205

File tree: 6 files changed (+208, -56 lines)


sbi/inference/posteriors/vector_field_posterior.py

Lines changed: 3 additions & 6 deletions
```diff
@@ -313,8 +313,8 @@ def _sample_via_diffusion(
             corrector: The corrector for the diffusion-based sampler. Either of
                 [None].
             steps: Number of steps to take for the Euler-Maruyama method.
-            ts: Time points at which to evaluate the diffusion process. If None, a
-                linear grid between t_max and t_min is used.
+            ts: Time points at which to evaluate the diffusion process. If None, call
+                the solve schedule specific to the score estimator.
             max_sampling_batch_size: Maximum batch size for sampling.
             sample_with: Deprecated - use `.build_posterior(sample_with=...)` prior to
                 `.sample()`.
@@ -340,11 +340,8 @@ def _sample_via_diffusion(
         # Ensure we don't use larger batches than total samples needed
         effective_batch_size = min(effective_batch_size, total_samples_needed)

-        # TODO: the time schedule should be provided by the estimator, see issue #1437
        if ts is None:
-            t_max = self.vector_field_estimator.t_max
-            t_min = self.vector_field_estimator.t_min
-            ts = torch.linspace(t_max, t_min, steps)
+            ts = self.vector_field_estimator.solve_schedule(steps)
        ts = ts.to(self.device)

        # Initialize the diffusion sampler
```
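For orientation, the grid the sampler now receives comes from the estimator's `solve_schedule`, whose base-class default appears further down in this commit. A minimal sketch reproducing that default behavior (step count and time bounds are illustrative):

```python
import torch

# Default solve_schedule behavior: a deterministic, decreasing, uniform grid
# from t_max down to t_min (illustrative values).
steps, t_min, t_max = 5, 1e-3, 1.0
ts = torch.linspace(t_max, t_min, steps)
assert bool(ts[0] == t_max) and bool(ts[-1] == t_min)
assert bool((ts[:-1] > ts[1:]).all())  # strictly decreasing, as the SDE solver expects
```

Estimators that override `solve_schedule`, for example with non-uniform spacing, are now picked up by `_sample_via_diffusion` automatically.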

sbi/inference/trainers/vfpe/base_vf_inference.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -313,10 +313,10 @@ def default_calibration_kernel(x):
        )

        if isinstance(validation_times, int):
-            validation_times = torch.linspace(
-                self._neural_net.t_min + validation_times_nugget,
-                self._neural_net.t_max - validation_times_nugget,
+            validation_times = self._neural_net.solve_schedule(
                validation_times,
+                t_min=self._neural_net.t_min + validation_times_nugget,
+                t_max=self._neural_net.t_max - validation_times_nugget,
            )

        loss_args = LossArgsVF(
```
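The nugget keeps validation losses away from the exact boundaries t_min and t_max, where training was reported to be unstable. A small numeric sketch of the effect (all values illustrative):

```python
import torch

# Validation times are requested strictly inside (t_min, t_max) by shifting
# both endpoints inward by a small nugget (illustrative values).
t_min, t_max, nugget, n_val = 1e-3, 1.0, 1e-2, 4
val_times = torch.linspace(t_max - nugget, t_min + nugget, n_val)
assert bool(val_times.max() < t_max) and bool(val_times.min() > t_min)
```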

sbi/neural_nets/estimators/base.py

Lines changed: 29 additions & 0 deletions
```diff
@@ -479,6 +479,35 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
        """
        raise NotImplementedError("Diffusion is not implemented for this estimator.")

+    def solve_schedule(
+        self,
+        steps: int,
+        t_min: Optional[float] = None,
+        t_max: Optional[float] = None,
+    ) -> Tensor:
+        """
+        Time grid used during sampling (solving steps) and loss evaluation steps.
+
+        This grid is deterministic and decreasing. Can be overridden by subclasses.
+        By default, returns a uniform time stepping between t_max and t_min.
+
+        Args:
+            steps: Number of discretization steps.
+            t_min: The minimum time value. Defaults to self.t_min.
+            t_max: The maximum time value. Defaults to self.t_max.
+
+        Returns:
+            Tensor: A tensor of time steps within the range [t_max, t_min].
+        """
+        if t_min is None:
+            t_min = self.t_min
+        if t_max is None:
+            t_max = self.t_max
+
+        times = torch.linspace(t_max, t_min, steps, device=self._mean_base.device)
+
+        return times
+

 class UnconditionalEstimator(nn.Module, ABC):
     r"""Base class for unconditional estimators that estimate properties of
```

sbi/neural_nets/estimators/score_estimator.py

Lines changed: 109 additions & 46 deletions
```diff
@@ -66,6 +66,8 @@ def __init__(
        condition_shape: torch.Size,
        embedding_net: Optional[nn.Module] = None,
        weight_fn: Union[str, Callable] = "max_likelihood",
+        beta_min: float = 0.01,
+        beta_max: float = 10.0,
        mean_0: Union[Tensor, float] = 0.0,
        std_0: Union[Tensor, float] = 1.0,
        t_min: float = 1e-3,
@@ -111,6 +113,10 @@ def __init__(
            t_max=t_max,
        )

+        # Min/max values for noise variance beta
+        self.beta_min = beta_min
+        self.beta_max = beta_max
+
        # Set lambdas (variance weights) function.
        self._set_weight_fn(weight_fn)
        self.register_buffer("mean_0", mean_0.clone().detach())
```
```diff
@@ -228,7 +234,8 @@ def loss(
        Args:
            input: Input variable i.e. theta.
            condition: Conditioning variable.
-            times: SDE time variable in [t_min, t_max]. Uniformly sampled if None.
+            times: SDE time variable in [t_min, t_max]. If None, call the train_schedule
+                method specific to the score estimator.
            control_variate: Whether to use a control variate to reduce the variance of
                the stochastic loss estimator.
            control_variate_threshold: Threshold for the control variate. If the std
@@ -240,13 +247,11 @@ def loss(
            MSE between target score and network output, scaled by the weight function.

        """
-        # Sample diffusion times.
+        # Sample times from the Markov chain, use batch dimension
        if times is None:
-            times = (
-                torch.rand(input.shape[0], device=input.device)
-                * (self.t_max - self.t_min)
-                + self.t_min
-            )
+            times = self.train_schedule(input.shape[0])
+            times = times.to(input.device)
+
        # Sample noise.
        eps = torch.randn_like(input)
```

```diff
@@ -390,6 +395,86 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
        """
        raise NotImplementedError

+    def noise_schedule(self, times: Tensor) -> Tensor:
+        """
+        Create a mapping from time to noise magnitude (beta/sigma) used in the SDE.
+
+        This method acts as a fallback in case derived classes do not
+        implement it on their own. We implement here a linear beta schedule defined
+        by the input `times`, which represents the normalized time steps t ∈ [0, 1].
+
+        Args:
+            times: SDE times in [0, 1].
+
+        Returns:
+            Tensor: Generated beta schedule at a given time.
+        """
+        return self.beta_min + (self.beta_max - self.beta_min) * times
+
+    def train_schedule(
+        self,
+        num_samples: int,
+        t_min: Optional[float] = None,
+        t_max: Optional[float] = None,
+    ) -> Tensor:
+        """
+        Sample diffusion times used during training.
+
+        Can be overridden by subclasses. We implement here a uniform sampling of
+        time within the range [t_min, t_max]. The `times` tensor will be put on
+        the same device as the stored network.
+
+        Args:
+            num_samples: Number of samples to generate.
+            t_min: The minimum time value. Defaults to self.t_min.
+            t_max: The maximum time value. Defaults to self.t_max.
+
+        Returns:
+            Tensor: A tensor of time variables sampled in the range [t_min, t_max].
+        """
+        if t_min is None:
+            t_min = self.t_min
+        if t_max is None:
+            t_max = self.t_max
+
+        return (
+            torch.rand(num_samples, device=self._mean_base.device) * (t_max - t_min)
+            + t_min
+        )
+
+    def solve_schedule(
+        self,
+        num_steps: int,
+        t_min: Optional[float] = None,
+        t_max: Optional[float] = None,
+    ) -> Tensor:
+        """
+        Time grid used to solve the reverse SDE and evaluate the loss function.
+
+        This grid is deterministic and decreasing. We implement here a uniform time
+        stepping within the range [t_max, t_min]. The `times` tensor will be put on
+        the same device as the stored network.
+
+        Args:
+            num_steps: Number of time steps to generate.
+            t_min: The minimum time value. Defaults to self.t_min.
+            t_max: The maximum time value. Defaults to self.t_max.
+
+        Returns:
+            Tensor: A tensor of time steps within the range [t_max, t_min].
+        """
+        if t_min is None:
+            t_min = self.t_min
+        if t_max is None:
+            t_max = self.t_max
+
+        times = torch.linspace(t_max, t_min, num_steps, device=self._mean_base.device)
+
+        return times
+
    def _set_weight_fn(self, weight_fn: Union[str, Callable]):
        """Set the weight function.
```
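The three hooks split responsibilities cleanly: `noise_schedule` maps time to noise magnitude, `train_schedule` draws stochastic training times, and `solve_schedule` builds the deterministic solver grid. A sketch of the default linear beta schedule, using the new constructor defaults `beta_min=0.01`, `beta_max=10.0`:

```python
import torch

# Linear beta schedule: beta(t) = beta_min + (beta_max - beta_min) * t.
beta_min, beta_max = 0.01, 10.0
times = torch.linspace(0.0, 1.0, 5)
betas = beta_min + (beta_max - beta_min) * times
# Endpoints interpolate between beta_min and beta_max.
assert torch.isclose(betas[0], torch.tensor(beta_min))
assert torch.isclose(betas[-1], torch.tensor(beta_max))
```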
```diff
@@ -480,8 +565,6 @@ def __init__(
        t_min: float = 1e-3,
        t_max: float = 1.0,
    ) -> None:
-        self.beta_min = beta_min
-        self.beta_max = beta_max
        super().__init__(
            net,
            input_shape,
@@ -490,6 +573,8 @@ def __init__(
            weight_fn=weight_fn,
            mean_0=mean_0,
            std_0=std_0,
+            beta_min=beta_min,
+            beta_max=beta_max,
            t_min=t_min,
            t_max=t_max,
        )
```
```diff
@@ -525,17 +610,6 @@ def std_fn(self, times: Tensor) -> Tensor:
            std = std.unsqueeze(-1)
        return torch.sqrt(std)

-    def _beta_schedule(self, times: Tensor) -> Tensor:
-        """Linear beta schedule for mean scaling in variance preserving SDEs.
-
-        Args:
-            times: SDE time variable in [0,1].
-
-        Returns:
-            Beta schedule at a given time.
-        """
-        return self.beta_min + (self.beta_max - self.beta_min) * times
-
    def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
        """Drift function for variance preserving SDEs.
@@ -546,7 +620,7 @@ def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
        Returns:
            Drift function at a given time.
        """
-        phi = -0.5 * self._beta_schedule(times)
+        phi = -0.5 * self.noise_schedule(times)
        while len(phi.shape) < len(input.shape):
            phi = phi.unsqueeze(-1)
        return phi * input
@@ -561,7 +635,7 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
        Returns:
            Diffusion function at a given time.
        """
-        g = torch.sqrt(self._beta_schedule(times))
+        g = torch.sqrt(self.noise_schedule(times))
        while len(g.shape) < len(input.shape):
            g = g.unsqueeze(-1)
        return g
```
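Both VP coefficients are now derived from the shared `noise_schedule`: the drift is f(x, t) = -beta(t) * x / 2 and the diffusion is g(t) = sqrt(beta(t)). A numeric sketch of the shapes involved (values illustrative):

```python
import torch

beta_min, beta_max = 0.01, 10.0
t = torch.tensor([0.5])
beta_t = beta_min + (beta_max - beta_min) * t  # the shared noise_schedule
x = torch.ones(1, 3)

drift = -0.5 * beta_t.unsqueeze(-1) * x  # broadcast to the input shape
diffusion = torch.sqrt(beta_t)           # one value per time point

assert drift.shape == x.shape
assert diffusion.shape == t.shape
```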
```diff
@@ -604,14 +678,14 @@ def __init__(
        t_min: float = 1e-2,
        t_max: float = 1.0,
    ) -> None:
-        self.beta_min = beta_min
-        self.beta_max = beta_max
        super().__init__(
            net,
            input_shape,
            condition_shape,
            embedding_net=embedding_net,
            weight_fn=weight_fn,
+            beta_min=beta_min,
+            beta_max=beta_max,
            mean_0=mean_0,
            std_0=std_0,
            t_min=t_min,
@@ -649,18 +723,6 @@ def std_fn(self, times: Tensor) -> Tensor:
            std = std.unsqueeze(-1)
        return std

-    def _beta_schedule(self, times: Tensor) -> Tensor:
-        """Linear beta schedule for mean scaling in sub-variance preserving SDEs.
-        (Same as for variance preserving SDEs.)
-
-        Args:
-            times: SDE time variable in [0,1].
-
-        Returns:
-            Beta schedule at a given time.
-        """
-        return self.beta_min + (self.beta_max - self.beta_min) * times
-
    def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
        """Drift function for sub-variance preserving SDEs.
@@ -671,7 +733,7 @@ def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
        Returns:
            Drift function at a given time.
        """
-        phi = -0.5 * self._beta_schedule(times).to(input.device)
+        phi = -0.5 * self.noise_schedule(times)

        while len(phi.shape) < len(input.shape):
            phi = phi.unsqueeze(-1)
@@ -690,7 +752,7 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
        """
        g = torch.sqrt(
            torch.abs(
-                self._beta_schedule(times)
+                self.noise_schedule(times)
                * (
                    1
                    - torch.exp(
```
```diff
@@ -788,16 +850,18 @@ def std_fn(self, times: Tensor) -> Tensor:
            std = std.unsqueeze(-1)
        return std

-    def _sigma_schedule(self, times: Tensor) -> Tensor:
-        """Geometric sigma schedule for variance exploding SDEs.
+    def noise_schedule(self, times: Tensor) -> Tensor:
+        """Noise schedule used in the SDE drift and diffusion coefficients.
+        Note that for VE, this method returns the same as std_fn().

        Args:
            times: SDE time variable in [0,1].

        Returns:
-            Sigma schedule at a given time.
+            Noise magnitude (sigma) at a given time.
        """
-        return self.sigma_min * (self.sigma_max / self.sigma_min) ** times
+        std = self.sigma_min * (self.sigma_max / self.sigma_min) ** times
+        return std

    def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
        """Drift function for variance exploding SDEs.
@@ -821,11 +885,10 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
        Returns:
            Diffusion function at a given time.
        """
-        g = self._sigma_schedule(times) * math.sqrt(
-            (2 * math.log(self.sigma_max / self.sigma_min))
-        )
+        sigma_scale = self.sigma_max / self.sigma_min
+        sigmas = self.noise_schedule(times)
+        g = sigmas * math.sqrt((2 * math.log(sigma_scale)))

        while len(g.shape) < len(input.shape):
            g = g.unsqueeze(-1)
-
        return g.to(input.device)
```
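For the VE estimator, `noise_schedule` is the geometric sigma schedule itself, so `std_fn` and `diffusion_fn` read from the same source: sigma(t) = sigma_min * (sigma_max / sigma_min) ** t and g(t) = sigma(t) * sqrt(2 * log(sigma_max / sigma_min)). A sketch with illustrative sigma bounds:

```python
import math

import torch

sigma_min, sigma_max = 0.01, 10.0  # illustrative bounds
times = torch.linspace(0.0, 1.0, 5)

sigmas = sigma_min * (sigma_max / sigma_min) ** times        # geometric schedule
g = sigmas * math.sqrt(2 * math.log(sigma_max / sigma_min))  # diffusion coefficient

# The schedule interpolates geometrically between sigma_min and sigma_max.
assert torch.isclose(sigmas[0], torch.tensor(sigma_min))
assert torch.isclose(sigmas[-1], torch.tensor(sigma_max))
```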

tests/linearGaussian_vector_field_test.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -69,7 +69,7 @@ def test_c2st_vector_field_on_linearGaussian(

    x_o = zeros(1, num_dim)
    num_samples = 1000
-    num_simulations = 2500
+    num_simulations = 2600 if vector_field_type == "ve" else 2500

    # likelihood_mean will be likelihood_shift+theta
    likelihood_shift = -1.0 * ones(num_dim)
```
