Commit 19364f0
fix: fmpe singularity on sde sampling (#1661)
* Allow sampling at every timepoint for diagnostic purposes
* Fix FMPE SDE sampling
* Fix bug
* Add as argument with docstring
* Fix typo
* mean_t_fn also needs effective_t_max
* All tests pass now
1 parent bdf7d83 commit 19364f0
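
Background for the fix, sketched from the formulas visible in the diffs below (not an authoritative derivation): the flow-matching perturbation kernel has mean $E[\theta_t | \theta_0] = (1 - t)\,\theta_0$, and the corresponding SDE drift is $f(t, \theta_t) = -\theta_t / (1 - t)$, which diverges as $t \to 1$, i.e. at the noise end where SDE sampling starts (the "explosion of the SDE in the beginning" noted in the docstrings). The old code floored the denominator at 1e-6; this commit replaces that floor with `1 - effective_t_max` (0.01 by default) and clamps the times passed to `mean_t_fn`, so drift, diffusion, and mean coefficients stay finite near $t = 1$.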

File tree: 4 files changed (+33, −14 lines)

sbi/inference/posteriors/vector_field_posterior.py

Lines changed: 5 additions & 0 deletions

@@ -260,6 +260,7 @@ def _sample_via_diffusion(
         ts: Optional[Tensor] = None,
         max_sampling_batch_size: int = 10_000,
         show_progress_bars: bool = True,
+        save_intermediate: bool = False,
     ) -> Tensor:
         r"""Return samples from posterior distribution $p(\theta|x)$.

@@ -281,6 +282,9 @@ def _sample_via_diffusion(
             sample_with: Deprecated - use `.build_posterior(sample_with=...)` prior to
                 `.sample()`.
             show_progress_bars: Whether to show a progress bar during sampling.
+            save_intermediate: Whether to save intermediate results of the diffusion
+                process. If True, the returned tensor has shape
+                `(*sample_shape, steps, *input_shape)`.
         """

         if not self.vector_field_estimator.SCORE_DEFINED:
@@ -332,6 +336,7 @@ def _sample_via_diffusion(
                 num_samples=current_batch_size,
                 ts=ts,
                 show_progress_bars=show_progress_bars,
+                save_intermediate=save_intermediate,
             )

             all_samples.append(batch_samples)
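
A short, hypothetical post-processing sketch of what the documented return shape enables. The tensor below is only a stand-in for the output of `_sample_via_diffusion(..., save_intermediate=True)`, and the concrete sizes (100 samples, 500 steps, 3-dimensional parameter) are illustrative:

    import torch

    # Stand-in for the tensor returned with save_intermediate=True.
    # Per the new docstring, its shape is (*sample_shape, steps, *input_shape),
    # here sample_shape=(100,), steps=500, and a 3-dimensional parameter.
    intermediate = torch.randn(100, 500, 3)

    trajectory_of_first_sample = intermediate[0]   # (500, 3): one sample's diffusion path
    final_samples = intermediate[:, -1]            # (100, 3): samples at the last stored step
    per_step_mean = intermediate.mean(dim=0)       # (500, 3): average state at each step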

sbi/neural_nets/estimators/flowmatching_estimator.py

Lines changed: 27 additions & 5 deletions

@@ -243,7 +243,9 @@ def score(self, input: Tensor, condition: Tensor, t: Tensor) -> Tensor:
         score = (-(1 - t) * v - input) / (t + self.noise_scale)
         return score

-    def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
+    def drift_fn(
+        self, input: Tensor, times: Tensor, effective_t_max: float = 0.99
+    ) -> Tensor:
         r"""Drift function for the flow matching estimator.

         The drift function is calculated based on [3]_ (see Equation 7):
@@ -263,14 +265,22 @@ def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
         Args:
             input: Parameters :math:`\theta_t`.
             times: SDE time variable in [0,1].
+            effective_t_max: Upper bound on time to avoid numerical issues at t=1.
+                This effectively prevents an explosion of the SDE in the beginning.
+                Note that this does not affect the ODE sampling, which always uses
+                times in [0,1].

         Returns:
             Drift function at a given time.
         """
         # analytical f(t) does not depend on noise_scale and is undefined at t = 1.
-        return -input / torch.maximum(1 - times, torch.tensor(1e-6).to(input))
+        return -input / torch.maximum(
+            1 - times, torch.tensor(1 - effective_t_max).to(input)
+        )

-    def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
+    def diffusion_fn(
+        self, input: Tensor, times: Tensor, effective_t_max: float = 0.99
+    ) -> Tensor:
         r"""Diffusion function for the flow matching estimator.

         The diffusion function is calculated based on [3]_ (see Equation 7):
@@ -288,6 +298,10 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
         Args:
             input: Parameters :math:`\theta_t`.
             times: SDE time variable in [0,1].
+            effective_t_max: Upper bound on time to avoid numerical issues at t=1.
+                This effectively prevents an explosion of the SDE in the beginning.
+                Note that this does not affect the ODE sampling, which always uses
+                times in [0,1].

         Returns:
             Diffusion function at a given time.
@@ -296,10 +310,10 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
         return torch.sqrt(
             2
             * (times + self.noise_scale)
-            / torch.maximum(1 - times, torch.tensor(1e-6).to(times))
+            / torch.maximum(1 - times, torch.tensor(1 - effective_t_max).to(times))
         )

-    def mean_t_fn(self, times: Tensor) -> Tensor:
+    def mean_t_fn(self, times: Tensor, effective_t_max: float = 0.99) -> Tensor:
         r"""Linear coefficient of the perturbation kernel expectation
         :math:`\mu_t(t) = E[\theta_t | \theta_0]` for the flow matching estimator.

@@ -316,10 +330,18 @@ def mean_t_fn(self, times: Tensor) -> Tensor:

         Args:
             times: SDE time variable in [0,1].
+            effective_t_max: Upper bound on time to avoid numerical issues at t=1.
+                This prevents a singularity at t=1 in the mean function (mean_t=0).
+                NOTE: This did affect the IID sampling: with mean_t=0 the analytical
+                denoising moments run into issues, since the sample is effectively
+                pure noise and the equations are no longer well defined. Alternatively,
+                the analytical denoising equations in `utils/score_utils.py` could be
+                adapted to account for this case.

         Returns:
             Mean function at a given time.
         """
+        times = torch.clamp(times, max=effective_t_max)
         mean_t = 1 - times
         for _ in range(len(self.input_shape)):
             mean_t = mean_t.unsqueeze(-1)
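
A minimal, self-contained sketch of the clamping behavior these hunks introduce. These are standalone re-implementations of the formulas shown in the diff, not the estimator's actual methods, and `noise_scale=1e-3` is just an illustrative value:

    import torch

    def drift(theta, t, effective_t_max=0.99):
        # f(t) = -theta / (1 - t), denominator floored at 1 - effective_t_max.
        return -theta / torch.maximum(1 - t, torch.tensor(1 - effective_t_max))

    def diffusion(t, noise_scale=1e-3, effective_t_max=0.99):
        # g(t) = sqrt(2 * (t + noise_scale) / (1 - t)), with the same floor.
        return torch.sqrt(
            2 * (t + noise_scale)
            / torch.maximum(1 - t, torch.tensor(1 - effective_t_max))
        )

    def mean_coeff(t, effective_t_max=0.99):
        # mu_t coefficient (1 - t), with t clamped so it never reaches exactly 1.
        return 1 - torch.clamp(t, max=effective_t_max)

    theta = torch.tensor([1.0])
    t_end = torch.tensor(1.0)

    # With the old floor of 1e-6 the drift at t=1 is about -1e6 (the SDE explodes);
    # with the new default floor of 0.01 it stays at a moderate -100.
    print(drift(theta, t_end, effective_t_max=1 - 1e-6))  # ~ -1e6
    print(drift(theta, t_end))                            # ~ -100
    print(diffusion(t_end))                               # ~ sqrt(2 * 1.001 / 0.01) ≈ 14.15
    print(mean_coeff(t_end))                              # ≈ 0.01 instead of exactly 0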

sbi/samplers/score/diffuser.py

Lines changed: 1 addition & 1 deletion

@@ -175,6 +175,6 @@ def run(
             intermediate_samples.append(samples)

         if save_intermediate:
-            return torch.cat(intermediate_samples, dim=0)
+            return torch.cat(intermediate_samples, dim=1)
         else:
             return samples
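
A toy illustration of why the concatenation axis matters for the documented output layout. The tensors are stand-ins: it is assumed here that each stored step carries a singleton step axis, i.e. shape `(num_samples, 1, *input_shape)`; the real `run` method may store the steps slightly differently:

    import torch

    num_samples, theta_dim, steps = 100, 3, 5

    # Toy stand-ins for the per-step tensors appended to `intermediate_samples`.
    intermediate_samples = [torch.randn(num_samples, 1, theta_dim) for _ in range(steps)]

    # Old behavior: dim=0 mixed the steps into the sample axis.
    old = torch.cat(intermediate_samples, dim=0)
    print(old.shape)    # torch.Size([500, 1, 3]) -- samples and steps conflated

    # New behavior: dim=1 keeps one row per sample, matching
    # the documented (*sample_shape, steps, *input_shape) layout.
    new = torch.cat(intermediate_samples, dim=1)
    print(new.shape)    # torch.Size([100, 5, 3])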

tests/linearGaussian_vector_field_test.py

Lines changed: 0 additions & 8 deletions

@@ -373,14 +373,6 @@ def test_vector_field_iid_inference(
         num_trials: The number of trials to run.
     """

-    if (
-        vector_field_type == "fmpe"
-        and prior_type == "uniform"
-        and iid_method in ["gauss", "auto_gauss", "jac_gauss"]
-    ):
-        # TODO: Predictor produces NaNs for these cases, see #1656
-        pytest.skip("Known issue of IID methods with uniform priors, see #1656.")
-
     vector_field_trained_model = train_vector_field_model(vector_field_type, prior_type)

     # Extract data from the trained model
