1- from typing import Sequence , Tuple , Self , Callable , Optional , Union
1+ from typing import Sequence , Tuple , Self , Callable , Optional
22import jax
33import jax .numpy as jnp
44import equinox as eqx
5- from jaxtyping import Key , Array , Float , jaxtyped
5+ from jaxtyping import Key , Array , Float , Scalar
66
7- TimeFn = Callable [[ float | Float [ Array , "" ]], Float [ Array , "" ]]
8- Time = Float [ Array , "" ] | float
7+ Time = Scalar | float
8+ TimeFn = Callable [[ Time ], Scalar ]
99
1010
1111def default_weight_fn (t , * , beta_integral = None , sigma_fn = None ):
@@ -15,35 +15,71 @@ def default_weight_fn(t, *, beta_integral=None, sigma_fn=None):
1515
1616class SDE (eqx .Module ):
1717 """
18- SDE abstract class.
18+ Abstract base class for Stochastic Differential Equations (SDEs) used in
19+ score-based generative modeling and related diffusion models.
20+
21+ This class defines the required interface and provides base functionality for
22+ forward and reverse-time SDEs, prior sampling, and log-probability computation.
23+ The user should subclass `SDE` and implement the following methods:
24+ - `sde()`: returns drift and diffusion coefficients
25+ - `marginal_prob()`: returns the parameters of the marginal distribution at time t
26+ - `prior_sample()`: returns samples from the terminal distribution p_T(x)
27+ - `prior_log_prob()`: returns log-probabilities under p_T(x)
28+ - `weight()`: returns weighting for loss functions
29+
30+ Attributes:
31+ dt (float): Time discretization step size.
32+ t0 (float): Start time of the diffusion process.
33+ t1 (float): End time of the diffusion process.
1934 """
2035 dt : float
2136 t0 : float
2237 t1 : float
2338
2439 def __init__ (self , dt : float = 0.01 , t0 : float = 0. , t1 : float = 1. ):
2540 """
26- Construct an SDE.
41+ Initialize the base SDE with time parameters.
42+
43+ Args:
44+ dt (float): Time step for the SDE solver.
45+ t0 (float): Initial time of the process.
46+ t1 (float): Terminal time of the process.
2747 """
2848 super ().__init__ ()
2949 self .t0 = t0
3050 self .t1 = t1
3151 self .dt = dt
3252
33- def sde (self , x : Array , t : Union [float , Array ]) -> Tuple [Array , Array ]:
34- pass
53+ def sde (self , x : Array , t : Time ) -> Tuple [Array , Array ]:
54+ """
55+ Return the drift and diffusion coefficients f(x, t), g(t) of the SDE.
56+
57+ Must be implemented by subclass.
58+ """
59+ ...
3560
36- def marginal_prob (self , x : Array , t : Union [ float , Array ] ) -> Tuple [Array , Array ]:
61+ def marginal_prob (self , x : Array , t : Time ) -> Tuple [Array , Array ]:
3762 """ Parameters to determine the marginal distribution of the SDE, $p_t(x)$. """
38- pass
63+ ...
3964
4065 def prior_sample (self , key : Key , shape : Sequence [int ]) -> Array :
41- """ Generate one sample from the prior distribution, $p_T(x)$. """
42- pass
66+ """
67+ Generate one sample from the prior distribution, $p_T(x)$.
68+ """
69+ ...
70+
71+ def weight (self , t : Time , likelihood_weight : bool = False ) -> Array :
72+ """
73+ Return the training loss weight at time t.
4374
44- def weight (self , t : Union [float , Array ], likelihood_weight : bool = False ) -> Array :
45- """ Weighting for loss """
46- pass
75+ Args:
76+ t (float or Array): Time value(s).
77+ likelihood_weight (bool): Whether to use likelihood weighting (optional).
78+
79+ Returns:
80+ Array: Scalar or array of weights.
81+ """
82+ ...
4783
4884 def prior_log_prob (self , z : Array ) -> Array :
4985 """
@@ -52,20 +88,25 @@ def prior_log_prob(self, z: Array) -> Array:
5288 Useful for computing the log-likelihood via probability flow ODE.
5389
5490 Args:
55- z: latent code
91+ z: latent code
92+
5693 Returns:
57- log probability density
94+ log probability density
5895 """
59- pass
96+ ...
6097
6198 def reverse (self , score_fn : eqx .Module , probability_flow : bool = False ) -> Self :
6299 """
63100 Create the reverse-time SDE/ODE.
64101
65102 Args:
66- score_fn: A time-dependent score-based model that takes x and t and returns the score.
67- probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.
103+ score_fn: A time-dependent score-based model that takes x and t and returns the score.
104+ probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.
105+
106+ Returns:
107+ SDE: A subclass implementing the reverse-time SDE.
68108 """
109+
69110 sde_fn = self .sde
70111
71112 if hasattr (self , "beta_integral_fn" ):
@@ -77,7 +118,7 @@ def reverse(self, score_fn: eqx.Module, probability_flow: bool = False) -> Self:
77118 _t0 = self .t0
78119 _t1 = self .t1
79120
80- # Build the class for reverse-time SDE.
121+ # Build the class for the reverse-time SDE.
81122 class RSDE (self .__class__ , SDE ):
82123 probability_flow : bool
83124
@@ -91,7 +132,7 @@ def sde(
91132 t : Time ,
92133 q : Optional [Float [Array , "..." ]] = None ,
93134 a : Optional [Float [Array , "..." ]] = None
94- ) -> Tuple [Float [Array , "..." ], Float [ Array , "" ] ]:
135+ ) -> Tuple [Float [Array , "..." ], Scalar ]:
95136 """
96137 Create the drift and diffusion functions for the reverse SDE/ODE.
97138 - forward time SDE:
@@ -102,11 +143,11 @@ def sde(
102143 dx = [f(x, t) - 0.5 * g^2(t) * score(x, t)] * dt (ODE => No dw)
103144 """
104145 t = jnp .asarray (t )
105- coeff = 0.5 if self .probability_flow else 1.
146+ c = 0.5 if self .probability_flow else 1.
106147 drift , diffusion = sde_fn (x , t )
107148 score = score_fn (t , x , q , a )
108149 # Drift coefficient of reverse SDE and probability flow only different by a factor
109- drift = drift - jnp .square (diffusion ) * score * coeff
150+ drift = drift - jnp .square (diffusion ) * score * c
110151 # Set the diffusion function to zero for ODEs (dw=0)
111152 diffusion = 0. if self .probability_flow else diffusion
112153 return drift , diffusion
@@ -117,49 +158,4 @@ def sde(
117158def _get_log_prob_fn (scale : float = 1. ) -> Callable :
118159 def _log_prob_fn (z : Array ) -> Array :
119160 return jax .scipy .stats .norm .logpdf (z , loc = 0. , scale = scale ).sum ()
120- return _log_prob_fn
121-
122-
123- # if __name__ == "__main__":
124- # import os
125- # import matplotlib.pyplot as plt
126- # import numpy as np
127-
128- # figs_dir = "/project/ls-gruen/users/jed.homer/1pt_pdf/little_studies/sgm_lib/sgm/figs/"
129-
130- # # Plot SDEs with time
131- # beta_integral_fn = lambda t: t
132- # beta_fn = get_beta_fn(beta_integral_fn)
133- # sigma_fn = lambda t: jnp.exp(t)
134-
135- # times = dict(t0=0., t1=4., dt=0.1)
136-
137- # vp_sde = VPSDE(beta_integral_fn, **times)
138- # ve_sde = VESDE(sigma_fn=sigma_fn)
139- # subvp_sde = SubVPSDE(beta_integral_fn, **times)
140-
141- # x = jnp.ones((1,))
142- # T = jnp.linspace(1e-5, times["t1"], 1000)
143-
144- # def get_sde_drift_and_diffusion_fn(sde):
145- # return jax.vmap(sde.sde, in_axes=(None, 0))
146-
147- # def get_sde_mean_and_std(sde):
148- # return jax.vmap(sde.marginal_prob, in_axes=(None, 0))
149-
150- # fig, axs = plt.subplots(1, 4, figsize=(21., 4.), dpi=200)
151- # ax = axs[0]
152- # ax.plot(T, jax.vmap(beta_fn)(T), linestyle=":", label=r"$\beta(t)$")
153- # ax_ = ax.twinx()
154- # ax.legend(frameon=False, loc="upper left")
155- # ax_.plot(T, jax.vmap(beta_integral_fn)(T), label=r"$\int_0^t\beta(s)ds$")
156- # ax_.legend(frameon=False, loc="lower right")
157- # plt.title("SDEs")
158- # for ax, _sde in zip(axs[1:], [ve_sde, vp_sde, subvp_sde]):
159- # mu, std = get_sde_mean_and_std(_sde)(x, T)
160- # ax.set_title(str(_sde.__class__.__name__))
161- # ax.plot(T, mu, label=r"$\mu(t)$")
162- # ax.plot(T, std, label=r"$\sigma(t)$")
163- # ax.legend(frameon=False)
164- # plt.savefig(os.path.join(figs_dir, "sdes.png"), bbox_inches="tight")
165- # plt.close()
161+ return _log_prob_fn
0 commit comments