@@ -243,7 +243,9 @@ def score(self, input: Tensor, condition: Tensor, t: Tensor) -> Tensor:
243243 score = (- (1 - t ) * v - input ) / (t + self .noise_scale )
244244 return score
245245
246- def drift_fn (self , input : Tensor , times : Tensor ) -> Tensor :
246+ def drift_fn (
247+ self , input : Tensor , times : Tensor , effective_t_max : float = 0.99
248+ ) -> Tensor :
247249 r"""Drift function for the flow matching estimator.
248250
249251 The drift function is calculated based on [3]_ (see Equation 7):
@@ -263,14 +265,22 @@ def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
263265 Args:
264266 input: Parameters :math:`\theta_t`.
265267 times: SDE time variable in [0,1].
268+ effective_t_max: Upper bound on time to avoid numerical issues at t=1.
269+ This effectively prevents an explosion of the SDE in the beginning.
270+ Note that this does not affect the ODE sampling, which always uses
271+ times in [0,1].
266272
267273 Returns:
268274 Drift function at a given time.
269275 """
270276 # analytical f(t) does not depend on noise_scale and is undefined at t = 1.
271- return - input / torch .maximum (1 - times , torch .tensor (1e-6 ).to (input ))
277+ return - input / torch .maximum (
278+ 1 - times , torch .tensor (1 - effective_t_max ).to (input )
279+ )
272280
273- def diffusion_fn (self , input : Tensor , times : Tensor ) -> Tensor :
281+ def diffusion_fn (
282+ self , input : Tensor , times : Tensor , effective_t_max : float = 0.99
283+ ) -> Tensor :
274284 r"""Diffusion function for the flow matching estimator.
275285
276286 The diffusion function is calculated based on [3]_ (see Equation 7):
@@ -288,6 +298,10 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
288298 Args:
289299 input: Parameters :math:`\theta_t`.
290300 times: SDE time variable in [0,1].
301+ effective_t_max: Upper bound on time to avoid numerical issues at t=1.
302+ This effectively prevents an explosion of the SDE in the beginning.
303+ Note that this does not affect the ODE sampling, which always uses
304+ times in [0,1].
291305
292306 Returns:
293307 Diffusion function at a given time.
@@ -296,10 +310,10 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
296310 return torch .sqrt (
297311 2
298312 * (times + self .noise_scale )
299- / torch .maximum (1 - times , torch .tensor (1e-6 ).to (times ))
313+ / torch .maximum (1 - times , torch .tensor (1 - effective_t_max ).to (times ))
300314 )
301315
302- def mean_t_fn (self , times : Tensor ) -> Tensor :
316+ def mean_t_fn (self , times : Tensor , effective_t_max : float = 0.99 ) -> Tensor :
303317 r"""Linear coefficient of the perturbation kernel expectation
304318 :math:`\mu_t(t) = E[\theta_t | \theta_0]` for the flow matching estimator.
305319
@@ -316,10 +330,18 @@ def mean_t_fn(self, times: Tensor) -> Tensor:
316330
317331 Args:
318332 times: SDE time variable in [0,1].
333+ effective_t_max: Upper bound on time to avoid numerical issues at t=1.
334+ This prevents singularity at t=1 in the mean function (mean_t=0.).
335+ NOTE: This did affect the IID sampling as the analytical denoising
336+ moments run into issues (as mean_t=0) effectively makes it pure
337+ noise and equations are not well defined anymore. Alternatively
338+ we could also adapt the analytical denoising equations in
339+ `utils/score_utils.py` to account for this case.
319340
320341 Returns:
321342 Mean function at a given time.
322343 """
344+ times = torch .clamp (times , max = effective_t_max )
323345 mean_t = 1 - times
324346 for _ in range (len (self .input_shape )):
325347 mean_t = mean_t .unsqueeze (- 1 )
0 commit comments