Merged
Changes from 37 commits
Commits
38 commits
abc6742
refactoring noise_schedule and time schedule into base class
psteinb Mar 18, 2025
b7d3be3
added noise schedule test
psteinb Mar 18, 2025
dca1939
implemented beta schedule for variance-preserving estimators
psteinb Mar 18, 2025
6ec5793
code cosmetics triggered by ruff
psteinb Mar 18, 2025
3e963a9
cloned VPScoreEstimator to yield improved version
psteinb Mar 19, 2025
4b62510
more realistic bounds for unit test
psteinb Mar 19, 2025
274212e
typo and refactoring
psteinb Mar 19, 2025
74065ae
fixed wrong setup of pmean and pstd
psteinb Mar 19, 2025
d61e198
code reformatting
psteinb Mar 19, 2025
d2c0100
use the time schedule for computing the validation scores
psteinb Mar 20, 2025
5b50673
propagate name change
psteinb Mar 20, 2025
2081ae0
fix unit tests to respect new schedules
psteinb Mar 20, 2025
4f58be6
comply with formatting
psteinb Mar 20, 2025
773a28b
attempted to implement EDM-like diffusion
psteinb Mar 20, 2025
6386fc8
removed "improved" denoising network
psteinb Mar 20, 2025
2bbd92f
consolidated tests
psteinb Mar 20, 2025
8f0c65a
removed occurrances of vp++
psteinb Mar 20, 2025
833c17b
removed all mentions of edm
psteinb Mar 20, 2025
74d60f2
Merge remote-tracking branch 'origin/main' into psteinb-explicit_nois…
janfb Jan 22, 2026
ec81ec0
ruff fixes
janfb Jan 22, 2026
bfb0f3a
WIP : use time schedule in loss function, address device issues
Jan 23, 2026
55ad6b3
call solve_schedule in validation step
Jan 26, 2026
5a60729
add solve_schedule method, call train_schedule in loss
Jan 26, 2026
6beb9ac
call the solve schedule during sampling with SDE
Jan 26, 2026
beb187e
add a solve_schedule function in the conditional vf estimator class t…
Jan 27, 2026
e750c67
make the solve schedule deterministic
Jan 29, 2026
ece8ede
corrections on solve schedule
Jan 29, 2026
a7078c9
WIP : create solve schedule in base class
Jan 29, 2026
aaff762
modify arguments of solve schedule
Jan 30, 2026
89d0c8d
change train_schedule + docstrings fixes + device handling
Jan 30, 2026
70f4b21
include validation times nugget to avoid instabilities during training
Jan 30, 2026
7383d9c
change the nb of simulations for ve option
Jan 30, 2026
f1ef67e
change device in solve schedule
Jan 30, 2026
32eec91
add noise schedule in VE subclass
Jan 30, 2026
1a032b7
reshape noise schedule output in VE class
Jan 30, 2026
b851fd7
reshape noise schedule output in VE class
Jan 30, 2026
f17126d
add tests on train and solve schedule shapes, devices, bounds
Jan 31, 2026
ecb295c
formatting and changing tests
Feb 2, 2026
9 changes: 3 additions & 6 deletions sbi/inference/posteriors/vector_field_posterior.py
@@ -306,8 +306,8 @@ def _sample_via_diffusion(
corrector: The corrector for the diffusion-based sampler. Either of
[None].
steps: Number of steps to take for the Euler-Maruyama method.
ts: Time points at which to evaluate the diffusion process. If None, a
linear grid between t_max and t_min is used.
ts: Time points at which to evaluate the diffusion process. If None, the
solve schedule of the score estimator is used.
max_sampling_batch_size: Maximum batch size for sampling.
sample_with: Deprecated - use `.build_posterior(sample_with=...)` prior to
`.sample()`.
@@ -333,11 +333,8 @@ def _sample_via_diffusion(
# Ensure we don't use larger batches than total samples needed
effective_batch_size = min(effective_batch_size, total_samples_needed)

# TODO: the time schedule should be provided by the estimator, see issue #1437
if ts is None:
t_max = self.vector_field_estimator.t_max
t_min = self.vector_field_estimator.t_min
ts = torch.linspace(t_max, t_min, steps)
ts = self.vector_field_estimator.solve_schedule(steps)
ts = ts.to(self.device)

# Initialize the diffusion sampler
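The behavioral contract of this change: with the base-class default, `solve_schedule` reproduces the linspace grid that was previously built inline, so `ts=None` gives unchanged results, while subclasses can swap in non-uniform grids. A minimal standalone sketch of that equivalence (the t_min/t_max values are illustrative):

```python
import torch

t_max, t_min, steps = 1.0, 1e-3, 500

# Previous inline construction in _sample_via_diffusion.
ts_old = torch.linspace(t_max, t_min, steps)

# Base-class default of solve_schedule (see sbi/neural_nets/estimators/base.py
# below); subclasses may override this with a non-uniform grid.
def default_solve_schedule(steps: int, t_min: float, t_max: float) -> torch.Tensor:
    return torch.linspace(t_max, t_min, steps)

assert torch.equal(ts_old, default_solve_schedule(steps, t_min, t_max))
```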
6 changes: 3 additions & 3 deletions sbi/inference/trainers/vfpe/base_vf_inference.py
@@ -313,10 +313,10 @@ def default_calibration_kernel(x):
)

if isinstance(validation_times, int):
validation_times = torch.linspace(
self._neural_net.t_min + validation_times_nugget,
self._neural_net.t_max - validation_times_nugget,
validation_times = self._neural_net.solve_schedule(
validation_times,
t_min=self._neural_net.t_min + validation_times_nugget,
t_max=self._neural_net.t_max - validation_times_nugget,
)

loss_args = LossArgsVF(
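To make the nugget's role concrete, a standalone sketch of what the integer branch above produces under the default (uniform) schedule; the nugget value below is illustrative:

```python
import torch

# The nugget shifts both endpoints inward so validation losses are never
# evaluated exactly at t_min or t_max, where the SDE terms can be
# numerically unstable. Values are illustrative.
t_min, t_max = 1e-3, 1.0
validation_times_nugget = 1e-2
num_validation_times = 10

validation_times = torch.linspace(
    t_max - validation_times_nugget,  # schedule runs from t_max down to t_min
    t_min + validation_times_nugget,
    num_validation_times,
)
assert validation_times.max() < t_max and validation_times.min() > t_min
```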
26 changes: 26 additions & 0 deletions sbi/neural_nets/estimators/base.py
@@ -479,6 +479,32 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
"""
raise NotImplementedError("Diffusion is not implemented for this estimator.")

def solve_schedule(
self,
steps: int,
t_min: Optional[float] = None,
t_max: Optional[float] = None,
) -> Tensor:
"""
Deterministic time schedule used during sampling.
Can be overriden by subclasses.
Return by default a uniform time stepping between t_max and t_min.

Args:
steps : number of discretization steps
t_min (float, optional): The minimum time value. Defaults to self.t_min.
t_max (float, optional): The maximum time value. Defaults to self.t_max.

Returns:
Tensor: A tensor of time steps within the range [t_max, t_min].
"""
t_min = self.t_min if isinstance(t_min, type(None)) else t_min
t_max = self.t_max if isinstance(t_max, type(None)) else t_max

times = torch.linspace(t_max, t_min, steps, device=self._mean_base.device)

return times
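Since this default is meant to be overridden, a hypothetical subclass (purely illustrative, not part of this PR) shows the extension point, e.g. a quadratic grid that clusters steps near t_min:

```python
import torch

# Hypothetical override: a quadratic grid that spends more steps near t_min,
# where the score typically changes fastest. Names are illustrative.
class QuadraticScheduleEstimator:
    t_min: float = 1e-3
    t_max: float = 1.0

    def solve_schedule(self, steps, t_min=None, t_max=None):
        t_min = self.t_min if t_min is None else t_min
        t_max = self.t_max if t_max is None else t_max
        u = torch.linspace(1.0, 0.0, steps)    # descending in [0, 1]
        return t_min + (t_max - t_min) * u**2  # dense near t_min

ts = QuadraticScheduleEstimator().solve_schedule(100)
assert torch.all(ts[:-1] >= ts[1:])  # still monotonically decreasing
assert torch.isclose(ts[0], torch.tensor(1.0))
assert torch.isclose(ts[-1], torch.tensor(1e-3))
```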


class UnconditionalEstimator(nn.Module, ABC):
r"""Base class for unconditional estimators that estimate properties of
150 changes: 104 additions & 46 deletions sbi/neural_nets/estimators/score_estimator.py
@@ -66,6 +66,8 @@ def __init__(
condition_shape: torch.Size,
embedding_net: Optional[nn.Module] = None,
weight_fn: Union[str, Callable] = "max_likelihood",
beta_min: float = 0.01,
beta_max: float = 10.0,
mean_0: Union[Tensor, float] = 0.0,
std_0: Union[Tensor, float] = 1.0,
t_min: float = 1e-3,
@@ -111,6 +113,10 @@ def __init__(
t_max=t_max,
)

# Min/max values for noise variance beta
self.beta_min = beta_min
self.beta_max = beta_max

# Set lambdas (variance weights) function.
self._set_weight_fn(weight_fn)
self.register_buffer("mean_0", mean_0.clone().detach())
Expand Down Expand Up @@ -228,7 +234,8 @@ def loss(
Args:
input: Input variable i.e. theta.
condition: Conditioning variable.
times: SDE time variable in [t_min, t_max]. Uniformly sampled if None.
times: SDE time variable in [t_min, t_max]. If None, times are sampled
by calling the `train_schedule` method.
control_variate: Whether to use a control variate to reduce the variance of
the stochastic loss estimator.
control_variate_threshold: Threshold for the control variate. If the std
@@ -240,13 +247,11 @@
MSE between target score and network output, scaled by the weight function.

"""
# Sample diffusion times.
# Sample diffusion times from the training schedule, one per batch element.
if times is None:
times = (
torch.rand(input.shape[0], device=input.device)
* (self.t_max - self.t_min)
+ self.t_min
)
times = self.train_schedule(input.shape[0])
times = times.to(input.device)

# Sample noise.
eps = torch.randn_like(input)

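A standalone sketch of this draw under the base-class default `train_schedule` (uniform in [t_min, t_max]); note that `Tensor.to` is not in-place, hence the reassignment above:

```python
import torch

theta = torch.randn(64, 3)  # stands in for `input`
t_min, t_max = 1e-3, 1.0

# Default train_schedule: one uniformly drawn time per batch element.
times = torch.rand(theta.shape[0]) * (t_max - t_min) + t_min
times = times.to(theta.device)  # must reassign: `times.to(...)` alone is a no-op

assert times.shape == (theta.shape[0],)
assert times.min() >= t_min and times.max() <= t_max
```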
@@ -390,6 +395,80 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
"""
raise NotImplementedError

def noise_schedule(self, times: Tensor) -> Tensor:
"""
Create a mapping from time to noise magnitude (beta/sigma) used in the SDE.

This method acts as a fallback in case derivative classes do not
implement it on their own. We implement here a linear beta schedule defined
by the input `times`, which represents the normalized time steps t ∈ [0, 1].

Args:
times (Tensor):
SDE times in [0, 1].

Returns:
Tensor: Generated beta schedule at a given time.

"""
return self.beta_min + (self.beta_max - self.beta_min) * times

def train_schedule(
self,
num_samples: int,
t_min: Optional[float] = None,
t_max: Optional[float] = None,
) -> Tensor:
"""
Return diffusion times used for training. Can be overriden by subclasses.
We implement here a uniform sampling of time within the range [t_min, t_max].
The `times` tensor will be put on the same device as the stored network.

Args:
num_samples (int): Number of samples to generate.
t_min (float, optional): The minimum time value. Defaults to self.t_min.
t_max (float, optional): The maximum time value. Defaults to self.t_max.

Returns:
Tensor: A tensor of sampled time variables scaled and shifted to the
range [t_min,t_max].

"""
t_min = self.t_min if isinstance(t_min, type(None)) else t_min
t_max = self.t_max if isinstance(t_max, type(None)) else t_max

return (
torch.rand(num_samples, device=self._mean_base.device) * (t_max - t_min)
+ t_min
)

def solve_schedule(
self,
num_steps: int,
t_min: Optional[float] = None,
t_max: Optional[float] = None,
) -> Tensor:
"""
Return a deterministic monotonic time grid used for evaluation/solving steps.
We implement here a uniform time stepping within the range [t_max, t_min].
The `times` tensor will be put on the same device as the stored network.

Args:
num_stesp (int): Number of time steps to generate.
t_min (float, optional): The minimum time value. Defaults to self.t_min.
t_max (float, optional): The maximum time value. Defaults to self.t_max.

Returns:
Tensor: A tensor of time steps within the range [t_max, t_min].

"""
t_min = self.t_min if isinstance(t_min, type(None)) else t_min
t_max = self.t_max if isinstance(t_max, type(None)) else t_max

times = torch.linspace(t_max, t_min, num_steps, device=self._mean_base.device)

return times

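Summarizing the division of labor between the three new methods, a minimal sketch using the documented defaults (beta_min=0.01, beta_max=10.0, t_min=1e-3, t_max=1.0):

```python
import torch

t_min, t_max = 1e-3, 1.0
beta_min, beta_max = 0.01, 10.0

train_t = torch.rand(8) * (t_max - t_min) + t_min   # train_schedule: random draws
solve_t = torch.linspace(t_max, t_min, 8)           # solve_schedule: fixed grid
betas = beta_min + (beta_max - beta_min) * solve_t  # noise_schedule: linear beta

assert torch.all((train_t >= t_min) & (train_t <= t_max))
assert torch.all(solve_t[:-1] >= solve_t[1:])       # deterministic and descending
assert betas.min() >= beta_min and betas.max() <= beta_max
```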
def _set_weight_fn(self, weight_fn: Union[str, Callable]):
"""Set the weight function.

@@ -480,8 +559,6 @@ def __init__(
t_min: float = 1e-3,
t_max: float = 1.0,
) -> None:
self.beta_min = beta_min
self.beta_max = beta_max
super().__init__(
net,
input_shape,
@@ -490,6 +567,8 @@
weight_fn=weight_fn,
mean_0=mean_0,
std_0=std_0,
beta_min=beta_min,
beta_max=beta_max,
t_min=t_min,
t_max=t_max,
)
@@ -525,17 +604,6 @@ def std_fn(self, times: Tensor) -> Tensor:
std = std.unsqueeze(-1)
return torch.sqrt(std)

def _beta_schedule(self, times: Tensor) -> Tensor:
"""Linear beta schedule for mean scaling in variance preserving SDEs.

Args:
times: SDE time variable in [0,1].

Returns:
Beta schedule at a given time.
"""
return self.beta_min + (self.beta_max - self.beta_min) * times

def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
"""Drift function for variance preserving SDEs.

@@ -546,7 +614,7 @@ def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
Returns:
Drift function at a given time.
"""
phi = -0.5 * self._beta_schedule(times)
phi = -0.5 * self.noise_schedule(times)
while len(phi.shape) < len(input.shape):
phi = phi.unsqueeze(-1)
return phi * input
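The unsqueeze loop above broadcasts the per-time coefficient over inputs with arbitrary trailing event dimensions; a minimal sketch of the pattern:

```python
import torch

beta = torch.tensor([0.5, 1.0])  # noise_schedule(times) for a batch of 2
x = torch.randn(2, 3, 4)         # input with extra event dimensions

phi = -0.5 * beta
while len(phi.shape) < len(x.shape):
    phi = phi.unsqueeze(-1)      # (2,) -> (2, 1, 1)

drift = phi * x                  # broadcasts over the event dims
assert drift.shape == x.shape
```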
Expand All @@ -561,7 +629,7 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
Returns:
Drift function at a given time.
"""
g = torch.sqrt(self._beta_schedule(times))
g = torch.sqrt(self.noise_schedule(times))
while len(g.shape) < len(input.shape):
g = g.unsqueeze(-1)
return g
@@ -604,14 +672,14 @@ def __init__(
t_min: float = 1e-2,
t_max: float = 1.0,
) -> None:
self.beta_min = beta_min
self.beta_max = beta_max
super().__init__(
net,
input_shape,
condition_shape,
embedding_net=embedding_net,
weight_fn=weight_fn,
beta_min=beta_min,
beta_max=beta_max,
mean_0=mean_0,
std_0=std_0,
t_min=t_min,
@@ -649,18 +717,6 @@ def std_fn(self, times: Tensor) -> Tensor:
std = std.unsqueeze(-1)
return std

def _beta_schedule(self, times: Tensor) -> Tensor:
"""Linear beta schedule for mean scaling in sub-variance preserving SDEs.
(Same as for variance preserving SDEs.)

Args:
times: SDE time variable in [0,1].

Returns:
Beta schedule at a given time.
"""
return self.beta_min + (self.beta_max - self.beta_min) * times

def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
"""Drift function for sub-variance preserving SDEs.

@@ -671,7 +727,7 @@ def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
Returns:
Drift function at a given time.
"""
phi = -0.5 * self._beta_schedule(times).to(input.device)
phi = -0.5 * self.noise_schedule(times)

while len(phi.shape) < len(input.shape):
phi = phi.unsqueeze(-1)
@@ -690,7 +746,7 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
"""
g = torch.sqrt(
torch.abs(
self._beta_schedule(times)
self.noise_schedule(times)
* (
1
- torch.exp(
@@ -788,16 +844,19 @@ def std_fn(self, times: Tensor) -> Tensor:
std = std.unsqueeze(-1)
return std

def _sigma_schedule(self, times: Tensor) -> Tensor:
"""Geometric sigma schedule for variance exploding SDEs.
def noise_schedule(self, times: Tensor) -> Tensor:
"""Define the noise schedule used in the SDE drift and
diffusion coefficients.
Note that for VE, this method returns the same as std_fn()

Args:
times: SDE time variable in [0,1].

Returns:
Sigma schedule at a given time.
Noise magnitude (sigma) at a given time.
"""
return self.sigma_min * (self.sigma_max / self.sigma_min) ** times
std = self.sigma_min * (self.sigma_max / self.sigma_min) ** times
return std

def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
"""Drift function for variance exploding SDEs.
@@ -821,11 +880,10 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
Returns:
Diffusion function at a given time.
"""
g = self._sigma_schedule(times) * math.sqrt(
(2 * math.log(self.sigma_max / self.sigma_min))
)
sigma_scale = self.sigma_max / self.sigma_min
sigmas = self.noise_schedule(times)
g = sigmas * math.sqrt(2 * math.log(sigma_scale))

while len(g.shape) < len(input.shape):
g = g.unsqueeze(-1)

return g.to(input.device)
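As a sanity check on the refactor, the VE diffusion coefficient should satisfy g(t)^2 = d sigma^2(t)/dt for the geometric schedule sigma(t) = sigma_min * (sigma_max / sigma_min)**t; a numerical sketch with illustrative sigma values:

```python
import math
import torch

sigma_min, sigma_max = 0.01, 10.0
t = torch.linspace(0.0, 1.0, 101, dtype=torch.float64)

sigma = sigma_min * (sigma_max / sigma_min) ** t            # noise_schedule
g = sigma * math.sqrt(2 * math.log(sigma_max / sigma_min))  # diffusion_fn

# Central finite differences of sigma^2 on the interior points.
dsigma2_dt = (sigma[2:] ** 2 - sigma[:-2] ** 2) / (t[2] - t[0])
assert torch.allclose(g[1:-1] ** 2, dsigma2_dt, rtol=1e-2)
```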
2 changes: 1 addition & 1 deletion tests/linearGaussian_vector_field_test.py
@@ -69,7 +69,7 @@ def test_c2st_vector_field_on_linearGaussian(

x_o = zeros(1, num_dim)
num_samples = 1000
num_simulations = 2500
num_simulations = 2600 if vector_field_type == "ve" else 2500

# likelihood_mean will be likelihood_shift+theta
likelihood_shift = -1.0 * ones(num_dim)