Skip to content

Commit 6b7ccf3

Browse files
committed
sdes typing, dataset test
1 parent a66d937 commit 6b7ccf3

File tree

5 files changed

+123
-96
lines changed

5 files changed

+123
-96
lines changed

sbgm/sde/_sde.py

Lines changed: 66 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
from typing import Sequence, Tuple, Self, Callable, Optional, Union
1+
from typing import Sequence, Tuple, Self, Callable, Optional
22
import jax
33
import jax.numpy as jnp
44
import equinox as eqx
5-
from jaxtyping import Key, Array, Float, jaxtyped
5+
from jaxtyping import Key, Array, Float, Scalar
66

7-
TimeFn = Callable[[float | Float[Array, ""]], Float[Array, ""]]
8-
Time = Float[Array, ""] | float
7+
Time = Scalar | float
8+
TimeFn = Callable[[Time], Scalar]
99

1010

1111
def default_weight_fn(t, *, beta_integral=None, sigma_fn=None):
@@ -15,35 +15,71 @@ def default_weight_fn(t, *, beta_integral=None, sigma_fn=None):
1515

1616
class SDE(eqx.Module):
1717
"""
18-
SDE abstract class.
18+
Abstract base class for Stochastic Differential Equations (SDEs) used in
19+
score-based generative modeling and related diffusion models.
20+
21+
This class defines the required interface and provides base functionality for
22+
forward and reverse-time SDEs, prior sampling, and log-probability computation.
23+
The user should subclass `SDE` and implement the following methods:
24+
- `sde()`: returns drift and diffusion coefficients
25+
- `marginal_prob()`: returns the parameters of the marginal distribution at time t
26+
- `prior_sample()`: returns samples from the terminal distribution p_T(x)
27+
- `prior_log_prob()`: returns log-probabilities under p_T(x)
28+
- `weight()`: returns weighting for loss functions
29+
30+
Attributes:
31+
dt (float): Time discretization step size.
32+
t0 (float): Start time of the diffusion process.
33+
t1 (float): End time of the diffusion process.
1934
"""
2035
dt: float
2136
t0: float
2237
t1: float
2338

2439
def __init__(self, dt: float = 0.01, t0: float = 0., t1: float = 1.):
2540
"""
26-
Construct an SDE.
41+
Initialize the base SDE with time parameters.
42+
43+
Args:
44+
dt (float): Time step for the SDE solver.
45+
t0 (float): Initial time of the process.
46+
t1 (float): Terminal time of the process.
2747
"""
2848
super().__init__()
2949
self.t0 = t0
3050
self.t1 = t1
3151
self.dt = dt
3252

33-
def sde(self, x: Array, t: Union[float, Array]) -> Tuple[Array, Array]:
34-
pass
53+
def sde(self, x: Array, t: Time) -> Tuple[Array, Array]:
54+
"""
55+
Return the drift and diffusion coefficients f(x, t), g(t) of the SDE.
56+
57+
Must be implemented by subclass.
58+
"""
59+
...
3560

36-
def marginal_prob(self, x: Array, t: Union[float, Array]) -> Tuple[Array, Array]:
61+
def marginal_prob(self, x: Array, t: Time) -> Tuple[Array, Array]:
3762
""" Parameters to determine the marginal distribution of the SDE, $p_t(x)$. """
38-
pass
63+
...
3964

4065
def prior_sample(self, key: Key, shape: Sequence[int]) -> Array:
41-
""" Generate one sample from the prior distribution, $p_T(x)$. """
42-
pass
66+
"""
67+
Generate one sample from the prior distribution, $p_T(x)$.
68+
"""
69+
...
70+
71+
def weight(self, t: Time, likelihood_weight: bool = False) -> Array:
72+
"""
73+
Return the training loss weight at time t.
4374
44-
def weight(self, t: Union[float, Array], likelihood_weight: bool = False) -> Array:
45-
""" Weighting for loss """
46-
pass
75+
Args:
76+
t (float or Array): Time value(s).
77+
likelihood_weight (bool): Whether to use likelihood weighting (optional).
78+
79+
Returns:
80+
Array: Scalar or array of weights.
81+
"""
82+
...
4783

4884
def prior_log_prob(self, z: Array) -> Array:
4985
"""
@@ -52,20 +88,25 @@ def prior_log_prob(self, z: Array) -> Array:
5288
Useful for computing the log-likelihood via probability flow ODE.
5389
5490
Args:
55-
z: latent code
91+
z: latent code
92+
5693
Returns:
57-
log probability density
94+
log probability density
5895
"""
59-
pass
96+
...
6097

6198
def reverse(self, score_fn: eqx.Module, probability_flow: bool = False) -> Self:
6299
"""
63100
Create the reverse-time SDE/ODE.
64101
65102
Args:
66-
score_fn: A time-dependent score-based model that takes x and t and returns the score.
67-
probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.
103+
score_fn: A time-dependent score-based model that takes x and t and returns the score.
104+
probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.
105+
106+
Returns:
107+
SDE: A subclass implementing the reverse-time SDE.
68108
"""
109+
69110
sde_fn = self.sde
70111

71112
if hasattr(self, "beta_integral_fn"):
@@ -77,7 +118,7 @@ def reverse(self, score_fn: eqx.Module, probability_flow: bool = False) -> Self:
77118
_t0 = self.t0
78119
_t1 = self.t1
79120

80-
# Build the class for reverse-time SDE.
121+
# Build the class for the reverse-time SDE.
81122
class RSDE(self.__class__, SDE):
82123
probability_flow: bool
83124

@@ -91,7 +132,7 @@ def sde(
91132
t: Time,
92133
q: Optional[Float[Array, "..."]] = None,
93134
a: Optional[Float[Array, "..."]] = None
94-
) -> Tuple[Float[Array, "..."], Float[Array, ""]]:
135+
) -> Tuple[Float[Array, "..."], Scalar]:
95136
"""
96137
Create the drift and diffusion functions for the reverse SDE/ODE.
97138
- forward time SDE:
@@ -102,11 +143,11 @@ def sde(
102143
dx = [f(x, t) - 0.5 * g^2(t) * score(x, t)] * dt (ODE => No dw)
103144
"""
104145
t = jnp.asarray(t)
105-
coeff = 0.5 if self.probability_flow else 1.
146+
c = 0.5 if self.probability_flow else 1.
106147
drift, diffusion = sde_fn(x, t)
107148
score = score_fn(t, x, q, a)
108149
# Drift coefficient of reverse SDE and probability flow only different by a factor
109-
drift = drift - jnp.square(diffusion) * score * coeff
150+
drift = drift - jnp.square(diffusion) * score * c
110151
# Set the diffusion function to zero for ODEs (dw=0)
111152
diffusion = 0. if self.probability_flow else diffusion
112153
return drift, diffusion
@@ -117,49 +158,4 @@ def sde(
117158
def _get_log_prob_fn(scale: float = 1.) -> Callable:
118159
def _log_prob_fn(z: Array) -> Array:
119160
return jax.scipy.stats.norm.logpdf(z, loc=0., scale=scale).sum()
120-
return _log_prob_fn
121-
122-
123-
# if __name__ == "__main__":
124-
# import os
125-
# import matplotlib.pyplot as plt
126-
# import numpy as np
127-
128-
# figs_dir = "/project/ls-gruen/users/jed.homer/1pt_pdf/little_studies/sgm_lib/sgm/figs/"
129-
130-
# # Plot SDEs with time
131-
# beta_integral_fn = lambda t: t
132-
# beta_fn = get_beta_fn(beta_integral_fn)
133-
# sigma_fn = lambda t: jnp.exp(t)
134-
135-
# times = dict(t0=0., t1=4., dt=0.1)
136-
137-
# vp_sde = VPSDE(beta_integral_fn, **times)
138-
# ve_sde = VESDE(sigma_fn=sigma_fn)
139-
# subvp_sde = SubVPSDE(beta_integral_fn, **times)
140-
141-
# x = jnp.ones((1,))
142-
# T = jnp.linspace(1e-5, times["t1"], 1000)
143-
144-
# def get_sde_drift_and_diffusion_fn(sde):
145-
# return jax.vmap(sde.sde, in_axes=(None, 0))
146-
147-
# def get_sde_mean_and_std(sde):
148-
# return jax.vmap(sde.marginal_prob, in_axes=(None, 0))
149-
150-
# fig, axs = plt.subplots(1, 4, figsize=(21., 4.), dpi=200)
151-
# ax = axs[0]
152-
# ax.plot(T, jax.vmap(beta_fn)(T), linestyle=":", label=r"$\beta(t)$")
153-
# ax_ = ax.twinx()
154-
# ax.legend(frameon=False, loc="upper left")
155-
# ax_.plot(T, jax.vmap(beta_integral_fn)(T), label=r"$\int_0^t\beta(s)ds$")
156-
# ax_.legend(frameon=False, loc="lower right")
157-
# plt.title("SDEs")
158-
# for ax, _sde in zip(axs[1:], [ve_sde, vp_sde, subvp_sde]):
159-
# mu, std = get_sde_mean_and_std(_sde)(x, T)
160-
# ax.set_title(str(_sde.__class__.__name__))
161-
# ax.plot(T, mu, label=r"$\mu(t)$")
162-
# ax.plot(T, std, label=r"$\sigma(t)$")
163-
# ax.legend(frameon=False)
164-
# plt.savefig(os.path.join(figs_dir, "sdes.png"), bbox_inches="tight")
165-
# plt.close()
161+
return _log_prob_fn

sbgm/sde/_subvp.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33
import jax.numpy as jnp
44
import jax.random as jr
55
import equinox as eqx
6-
from jaxtyping import Key, Array, Float, jaxtyped
6+
from jaxtyping import PRNGKeyArray, Array, Float, Scalar, jaxtyped
77
from beartype import beartype as typechecker
88

99
from ._sde import SDE, _get_log_prob_fn, Time, TimeFn
1010

1111

1212
def get_beta_fn(beta_integral_fn: TimeFn | eqx.Module) -> TimeFn:
1313
""" Obtain beta function from a beta integral. """
14-
def _beta_fn(t: Time) -> Float[Array, ""]:
14+
def _beta_fn(t: Time) -> Scalar:
1515
_, beta = jax.jvp(
1616
beta_integral_fn,
1717
primals=(t,),
@@ -47,7 +47,7 @@ def __init__(
4747
self.weight_fn = weight_fn
4848

4949
@jaxtyped(typechecker=typechecker)
50-
def sde(self, x: Float[Array, "..."], t: Time) -> Tuple[Float[Array, "..."], Float[Array, ""]]:
50+
def sde(self, x: Float[Array, "..."], t: Time) -> Tuple[Float[Array, "..."], Scalar]:
5151
"""
5252
dx = f(x, t) * dt + g(t) * dw
5353
dx = -0.5 * beta(t) * x * dt + sqrt(beta(t) * (1 - exp(-2 * int[beta(s)]))) * dw
@@ -59,7 +59,7 @@ def sde(self, x: Float[Array, "..."], t: Time) -> Tuple[Float[Array, "..."], Flo
5959
return drift, diffusion
6060

6161
@jaxtyped(typechecker=typechecker)
62-
def marginal_prob(self, x: Float[Array, "..."], t: Time) -> Tuple[Float[Array, "..."], Float[Array, ""]]:
62+
def marginal_prob(self, x: Float[Array, "..."], t: Time) -> Tuple[Float[Array, "..."], Scalar]:
6363
"""
6464
Sub-VP SDE p_t(x(t)|x(0)) is
6565
x(t) ~ G[x(t)|mu(x(0), t), sigma^2(t)]
@@ -73,7 +73,7 @@ def marginal_prob(self, x: Float[Array, "..."], t: Time) -> Tuple[Float[Array, "
7373
return mean, std
7474

7575
@jaxtyped(typechecker=typechecker)
76-
def weight(self, t: Time, likelihood_weight: bool = False) -> Float[Array, ""]:
76+
def weight(self, t: Time, likelihood_weight: bool = False) -> Scalar:
7777
# Likelihood weighting: above Eq 8 https://arxiv.org/pdf/2101.09258.pdf
7878
if self.weight_fn is not None and not likelihood_weight:
7979
weight = self.weight_fn(t)
@@ -84,8 +84,8 @@ def weight(self, t: Time, likelihood_weight: bool = False) -> Float[Array, ""]:
8484
weight = jnp.square(1. - jnp.exp(-self.beta_integral_fn(t)))
8585
return weight
8686

87-
def prior_sample(self, key: Key[jnp.ndarray, "..."], shape: Sequence[int]) -> Float[Array, "..."]:
87+
def prior_sample(self, key: PRNGKeyArray, shape: Sequence[int]) -> Float[Array, "..."]:
8888
return jr.normal(key, shape)
8989

90-
def prior_log_prob(self, z: Float[Array, "..."]) -> Float[Array, ""]:
90+
def prior_log_prob(self, z: Float[Array, "..."]) -> Scalar:
9191
return _get_log_prob_fn(scale=1.)(z)

sbgm/sde/_ve.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1-
from typing import Callable, Optional, Sequence, Tuple, Union
1+
from typing import Optional, Sequence, Tuple, Union
22
import jax
33
import jax.numpy as jnp
44
import jax.random as jr
55
import equinox as eqx
6-
from jaxtyping import Key, Array, Float, jaxtyped
6+
from jaxtyping import PRNGKeyArray, Array, Float, Scalar, jaxtyped
77
from beartype import beartype as typechecker
88

99
from ._sde import SDE, _get_log_prob_fn, Time, TimeFn
1010

1111

1212
def get_diffusion_fn(sigma_fn: Union[TimeFn, eqx.Module]) -> TimeFn:
1313
""" Get diffusion coefficient function for VE SDE: dx = sqrt(d[sigma^2(t)]/dt)dw """
14-
def _diffusion_fn(t: Time) -> Float[Array, ""]:
14+
def _diffusion_fn(t: Time) -> Scalar:
1515
_, dsigmadt = jax.jvp(
1616
lambda t: jnp.square(sigma_fn(t)),
1717
primals=(t,),
@@ -41,15 +41,15 @@ def __init__(
4141
dx = sqrt(d[sigma_fn(t) ** 2]/dt)
4242
4343
Args:
44-
sigma: default variance value
45-
dt: timestep width
44+
sigma: default variance value
45+
dt: timestep width
4646
"""
4747
super().__init__(dt=dt, t0=t0, t1=t1)
4848
self.sigma_fn = sigma_fn
4949
self.weight_fn = weight_fn
5050

5151
@jaxtyped(typechecker=typechecker)
52-
def sde(self, x: Float[Array, "..."], t: Time) -> Tuple[Float[Array, "..."], Float[Array, ""]]:
52+
def sde(self, x: Float[Array, "..."], t: Time) -> Tuple[Float[Array, "..."], Scalar]:
5353
drift = jnp.zeros_like(x)
5454
_, dsigma2dt = jax.jvp(
5555
lambda t: jnp.square(self.sigma_fn(t)),
@@ -61,7 +61,7 @@ def sde(self, x: Float[Array, "..."], t: Time) -> Tuple[Float[Array, "..."], Flo
6161
return drift, diffusion
6262

6363
@jaxtyped(typechecker=typechecker)
64-
def marginal_prob(self, x: Float[Array, "..."], t: Time) -> Tuple[Float[Array, "..."], Float[Array, ""]]:
64+
def marginal_prob(self, x: Float[Array, "..."], t: Time) -> Tuple[Float[Array, "..."], Scalar]:
6565
"""
6666
SDE:
6767
dx = sqrt(d[sigma^2(t)]/dt) * dw
@@ -74,7 +74,7 @@ def marginal_prob(self, x: Float[Array, "..."], t: Time) -> Tuple[Float[Array, "
7474
return x, std
7575

7676
@jaxtyped(typechecker=typechecker)
77-
def weight(self, t: Time, likelihood_weight: bool = False) -> Float[Array, ""]:
77+
def weight(self, t: Time, likelihood_weight: bool = False) -> Scalar:
7878
if self.weight_fn is not None and not likelihood_weight:
7979
weight = self.weight_fn(t)
8080
else:
@@ -84,8 +84,8 @@ def weight(self, t: Time, likelihood_weight: bool = False) -> Float[Array, ""]:
8484
weight = jnp.square(self.sigma_fn(t)) # Same for likelihood weighting
8585
return weight
8686

87-
def prior_sample(self, key: Key[jnp.ndarray, "..."], shape: Sequence[int]) -> Float[Array, "..."]:
87+
def prior_sample(self, key: PRNGKeyArray, shape: Sequence[int]) -> Float[Array, "..."]:
8888
return jr.normal(key, shape) * self.sigma_fn(self.t1)
8989

90-
def prior_log_prob(self, z: Float[Array, "..."]) -> Float[Array, ""]:
90+
def prior_log_prob(self, z: Float[Array, "..."]) -> Scalar:
9191
return _get_log_prob_fn(scale=self.sigma_fn(self.t1))(z)

0 commit comments

Comments
 (0)