diff --git a/sbi/inference/posteriors/vector_field_posterior.py b/sbi/inference/posteriors/vector_field_posterior.py
index 681cc58cb..01d7d6999 100644
--- a/sbi/inference/posteriors/vector_field_posterior.py
+++ b/sbi/inference/posteriors/vector_field_posterior.py
@@ -313,8 +313,8 @@ def _sample_via_diffusion(
             corrector: The corrector for the diffusion-based sampler. Either of
                 [None].
             steps: Number of steps to take for the Euler-Maruyama method.
-            ts: Time points at which to evaluate the diffusion process. If None, call
-                the solve schedule specific to the score estimator.
+            ts: Time points at which to evaluate the diffusion process. If None,
+                uses the solve_schedule() specific to the estimator.
             max_sampling_batch_size: Maximum batch size for sampling.
             sample_with: Deprecated - use `.build_posterior(sample_with=...)` prior
                 to `.sample()`.
diff --git a/sbi/inference/trainers/vfpe/base_vf_inference.py b/sbi/inference/trainers/vfpe/base_vf_inference.py
index d7aad91cd..4e2707958 100644
--- a/sbi/inference/trainers/vfpe/base_vf_inference.py
+++ b/sbi/inference/trainers/vfpe/base_vf_inference.py
@@ -313,6 +313,7 @@ def default_calibration_kernel(x):
         )

         if isinstance(validation_times, int):
+            # Use nugget to offset from boundaries for numerical stability
            validation_times = self._neural_net.solve_schedule(
                validation_times,
                t_min=self._neural_net.t_min + validation_times_nugget,
diff --git a/sbi/neural_nets/estimators/base.py b/sbi/neural_nets/estimators/base.py
index f8e116ccb..22e1544d6 100644
--- a/sbi/neural_nets/estimators/base.py
+++ b/sbi/neural_nets/estimators/base.py
@@ -485,28 +485,21 @@ def solve_schedule(
         t_min: Optional[float] = None,
         t_max: Optional[float] = None,
     ) -> Tensor:
-        """
-        Time grid used during sampling (solving steps) and loss evaluation steps.
+        """Time schedule used during sampling. Can be overridden by subclasses.

-        This grid is deterministic and decreasing. Can be overriden by subclasses.
-        Return by default a uniform time stepping between t_max and t_min.
+        Returns a uniform time stepping between t_max and t_min by default.

         Args:
-            steps: number of discretization steps
-            t_min: The minimum time value. Defaults to self.t_min.
-            t_max: The maximum time value. Defaults to self.t_max.
+            steps: Number of discretization steps.
+            t_min: Minimum time value. Defaults to self.t_min.
+            t_max: Maximum time value. Defaults to self.t_max.

         Returns:
-            Tensor: A tensor of time steps within the range [t_max, t_min].
+            Tensor of time steps from t_max to t_min.
         """
-        if t_min is None:
-            t_min = self.t_min
-        if t_max is None:
-            t_max = self.t_max
-
-        times = torch.linspace(t_max, t_min, steps, device=self._mean_base.device)
-
-        return times
+        t_min = self.t_min if t_min is None else t_min
+        t_max = self.t_max if t_max is None else t_max
+        return torch.linspace(t_max, t_min, steps, device=self._mean_base.device)


 class UnconditionalEstimator(nn.Module, ABC):
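(Illustrative note, not part of the patch.) The base-class default above is simply a uniform, strictly decreasing grid from t_max down to t_min; keeping it in the base class is what lets estimator subclasses swap in non-uniform grids without touching the samplers. A minimal standalone sketch of the same behaviour in plain torch, with an illustrative function name:

    import torch

    def uniform_solve_schedule(steps: int, t_min: float = 1e-3, t_max: float = 1.0) -> torch.Tensor:
        # Deterministic, decreasing grid: first entry is t_max, last entry is t_min.
        return torch.linspace(t_max, t_min, steps)

    ts = uniform_solve_schedule(5)
    assert torch.all(ts[:-1] > ts[1:])  # strictly decreasing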
diff --git a/sbi/neural_nets/estimators/score_estimator.py b/sbi/neural_nets/estimators/score_estimator.py
index f9c24ed44..b986fc40a 100644
--- a/sbi/neural_nets/estimators/score_estimator.py
+++ b/sbi/neural_nets/estimators/score_estimator.py
@@ -2,7 +2,8 @@
 # under the Apache License Version 2.0, see

 import math
-from typing import Callable, Optional, Union
+import warnings
+from typing import Callable, Literal, Optional, Union

 import torch
 from torch import Tensor, nn
@@ -234,8 +235,8 @@ def loss(
         Args:
            input: Input variable i.e. theta.
            condition: Conditioning variable.
-            times: SDE time variable in [t_min, t_max]. If None, call the train_schedule
-                method specific to the score estimator.
+            times: SDE time variable in [t_min, t_max]. If None, sampled via
+                train_schedule() (which may be overridden by subclasses).
            control_variate: Whether to use a control variate to reduce the variance
                of the stochastic loss estimator.
            control_variate_threshold: Threshold for the control variate. If the std
@@ -397,18 +398,17 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:

     def noise_schedule(self, times: Tensor) -> Tensor:
         """
-        Create a mapping from time to noise magnitude (beta/sigma) used in the SDE.
+        Map time to noise magnitude (beta for VP/SubVP, sigma for VE).

-        This method acts as a fallback in case derivative classes do not
-        implement it on their own. We implement here a linear beta schedule defined
-        by the input `times`, which represents the normalized time steps t ∈ [0, 1].
+        This base implementation returns a linear beta schedule suitable for
+        VP/SubVP SDEs where times are expected in [0, 1]. Subclasses
+        (e.g., VEScoreEstimator) override for different schedules.

         Args:
-            times: SDE times in [0, 1].
+            times: SDE times in [0, 1] for VP/SubVP, or [t_min, t_max] for VE.

         Returns:
-            Tensor: Generated beta schedule at a given time.
-
+            Beta (or sigma) schedule values at the given times.
         """
         return self.beta_min + (self.beta_max - self.beta_min) * times
@@ -419,25 +419,20 @@ def train_schedule(
         num_samples: int,
         t_min: Optional[float] = None,
         t_max: Optional[float] = None,
     ) -> Tensor:
         """
-        Sample diffusion times used during training.
+        Return diffusion times for training. Samples uniformly in [t_min, t_max].

-        Can be overriden by subclasses. We implement here a uniform sampling of
-        time within the range [t_min, t_max]. The `times` tensor will be put on
-        the same device as the stored network.
+        Can be overridden by subclasses (e.g., VEScoreEstimator supports lognormal).

         Args:
-            num_samples: Number of samples to generate.
-            t_min: The minimum time value. Defaults to self.t_min.
-            t_max: The maximum time value. Defaults to self.t_max.
+            num_samples: Number of time samples (typically batch size).
+            t_min: Minimum time value. Defaults to self.t_min.
+            t_max: Maximum time value. Defaults to self.t_max.

         Returns:
-            Tensor: A tensor of time variables sampled in the range [t_min,t_max].
-
+            Tensor of random times in [t_min, t_max].
         """
-        if t_min is None:
-            t_min = self.t_min
-        if t_max is None:
-            t_max = self.t_max
+        t_min = self.t_min if t_min is None else t_min
+        t_max = self.t_max if t_max is None else t_max
         return (
             torch.rand(num_samples, device=self._mean_base.device)
             * (t_max - t_min)
             + t_min
@@ -451,29 +446,22 @@ def solve_schedule(
         num_steps: int,
         t_min: Optional[float] = None,
         t_max: Optional[float] = None,
     ) -> Tensor:
         """
-        Time grid used to solve the reverse SDE and evaluate the loss function.
+        Return a deterministic monotonic time grid for evaluation/solving.

-        This grid is deterministic and decreasing. We implement here a uniform time
-        stepping within the range [t_max, t_min]. The `times` tensor will be put on
-        the same device as the stored network.
+        Can be overridden by subclasses (e.g., VEScoreEstimator supports power_law).

         Args:
-            num_steps: Number of time steps to generate.
-            t_min: The minimum time value. Defaults to self.t_min.
-            t_max: The maximum time value. Defaults to self.t_max.
+            num_steps: Number of discretization steps.
+            t_min: Minimum time value. Defaults to self.t_min.
+            t_max: Maximum time value. Defaults to self.t_max.

         Returns:
-            Tensor: A tensor of time steps within the range [t_max, t_min].
-
+            Tensor of shape (num_steps,) with times from t_max to t_min.
         """
-        if t_min is None:
-            t_min = self.t_min
-        if t_max is None:
-            t_max = self.t_max
+        t_min = self.t_min if t_min is None else t_min
+        t_max = self.t_max if t_max is None else t_max

-        times = torch.linspace(t_max, t_min, num_steps, device=self._mean_base.device)
-
-        return times
+        return torch.linspace(t_max, t_min, num_steps, device=self._mean_base.device)

     def _set_weight_fn(self, weight_fn: Union[str, Callable]):
         """Set the weight function.
@@ -775,22 +763,33 @@ class VEScoreEstimator(ConditionalScoreEstimator):
     The SDE defining the diffusion process is characterized by the following hyper-
     parameters.

-    Args:
-        sigma_min: This defines the smallest "noise" level in the diffusion process. This
-            ideally would be 0., but denoising score matching losses as employed in most
-            diffusion models are ill-suited for this case as their variance explodes to
-            infinity. Hence often a "small" value is chosen (the larger, the easier to
-            learn but the "noisier" the end result if not addressed post-hoc).
-        sigma_max: This is the final standard deviation after running the full diffusion
-            process. Ideally this would approach ∞ such that x0 and xT are truly
-            independent; it should be at least chosen such that x_T ~ N(0, sigma_max) at
-            least approximately. If p(x0) for example has itself a very large variance,
-            then you might have to increase this.
-
-    NOTE: Together with t_min and t_max they ultimatively define the loss function.
-    Changing these might also require changing t_min and t_max to find a good
-    tradeoff between bias and variance.
+    sigma_min: Smallest noise level in the diffusion process. Ideally 0, but
+        denoising score matching losses have exploding variance at 0, so a small
+        positive value is used.
+    sigma_max: Final standard deviation after full diffusion. Should be large
+        enough that x_T ~ N(0, sigma_max) approximately.
+    train_schedule: Time sampling strategy for training. "uniform" samples
+        uniformly in [t_min, t_max]. "lognormal" uses log-normal sigma sampling
+        per Karras et al. (2022), concentrating on intermediate noise levels.
+    solve_schedule: Time discretization for ODE/SDE integration. "uniform" uses
+        uniform linspace. "power_law" uses power-law spacing per Karras et al.
+        (2022) Eq. 5, concentrating steps near low noise levels.
+    lognormal_mean: Mean of log-normal distribution for train_schedule="lognormal".
+        Default -1.2 from Karras et al. (2022).
+    lognormal_std: Std of log-normal distribution for train_schedule="lognormal".
+        Default 1.2 from Karras et al. (2022).
+    power_law_exponent: Exponent (rho) for solve_schedule="power_law". Larger
+        values concentrate more steps near low noise. Default 7 from Karras et al.
+
+    Note:
+        Together with t_min and t_max, these parameters define the loss function.
+        Changing them might require adjusting t_min/t_max for optimal bias-variance
+        tradeoff.
+
+    References:
+        Karras et al. (2022) "Elucidating the Design Space of Diffusion-Based
+        Generative Models" https://arxiv.org/abs/2206.00364
     """
@@ -807,9 +806,53 @@ def __init__(
         std_0: float = 1.0,
         t_min: float = 1e-3,
         t_max: float = 1.0,
+        train_schedule: Literal["uniform", "lognormal"] = "uniform",
+        solve_schedule: Literal["uniform", "power_law"] = "uniform",
+        lognormal_mean: float = -1.2,
+        lognormal_std: float = 1.2,
+        power_law_exponent: float = 7.0,
     ) -> None:
+        # Validate sigma bounds (required for VE SDE math and log computations).
+        if sigma_min <= 0:
+            raise ValueError(f"sigma_min must be positive, got {sigma_min}")
+        if sigma_max <= sigma_min:
+            raise ValueError(
+                f"sigma_max ({sigma_max}) must be greater than sigma_min ({sigma_min})"
+            )
+
+        # Validate schedule type strings at runtime.
+        valid_train_schedules = ("uniform", "lognormal")
+        if train_schedule not in valid_train_schedules:
+            raise ValueError(
+                f"train_schedule must be one of {valid_train_schedules}, "
+                f"got '{train_schedule}'"
+            )
+        valid_solve_schedules = ("uniform", "power_law")
+        if solve_schedule not in valid_solve_schedules:
+            raise ValueError(
+                f"solve_schedule must be one of {valid_solve_schedules}, "
+                f"got '{solve_schedule}'"
+            )
+
+        # Validate lognormal parameters (only when schedule is used).
+        if train_schedule == "lognormal" and lognormal_std <= 0:
+            raise ValueError(f"lognormal_std must be positive, got {lognormal_std}")
+
+        # Validate power-law exponent (only when schedule is used).
+        if solve_schedule == "power_law" and power_law_exponent <= 0:
+            raise ValueError(
+                f"power_law_exponent must be positive, got {power_law_exponent}"
+            )
+
         self.sigma_min = sigma_min
         self.sigma_max = sigma_max
+        self._train_schedule_type = train_schedule
+        self._solve_schedule_type = solve_schedule
+        # Log-normal distribution parameters from Karras et al. (2022).
+        self.lognormal_mean = lognormal_mean
+        self.lognormal_std = lognormal_std
+        # Power-law exponent controls step concentration near low noise.
+        self.power_law_exponent = power_law_exponent
         super().__init__(
             net,
             input_shape,
@@ -851,17 +894,18 @@ def std_fn(self, times: Tensor) -> Tensor:
         return std

     def noise_schedule(self, times: Tensor) -> Tensor:
-        """Noise schedule used in the SDE drift and diffusion coefficients.
-        Note that for VE, this method returns the same as std_fn().
+        """Geometric sigma schedule for variance exploding SDEs.
+
+        For VE SDEs, the noise schedule is σ(t) = σ_min * (σ_max / σ_min)^t,
+        which differs from the linear beta schedule used by VP/SubVP.

         Args:
-            times: SDE time variable in [0,1].
+            times: SDE times, typically in [t_min, t_max].

         Returns:
-            Noise magnitude (sigma) at a given time.
+            Sigma schedule at given times.
         """
-        std = self.sigma_min * (self.sigma_max / self.sigma_min) ** times
-        return std
+        return self.sigma_min * (self.sigma_max / self.sigma_min) ** times

     def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
         """Drift function for variance exploding SDEs.
@@ -885,10 +929,144 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
         Returns:
             Diffusion function at a given time.
         """
-        sigma_scale = self.sigma_max / self.sigma_min
+        sigma_ratio = self.sigma_max / self.sigma_min
         sigmas = self.noise_schedule(times)
-        g = sigmas * math.sqrt((2 * math.log(sigma_scale)))
+        g = sigmas * math.sqrt(2 * math.log(sigma_ratio))

         while len(g.shape) < len(input.shape):
             g = g.unsqueeze(-1)

         return g.to(input.device)
+
+    def train_schedule(
+        self,
+        num_samples: int,
+        t_min: Optional[float] = None,
+        t_max: Optional[float] = None,
+    ) -> Tensor:
+        """
+        Return diffusion times for training.
+
+        When train_schedule="uniform" (default): samples uniformly in [t_min, t_max].
+        When train_schedule="lognormal": uses log-normal sigma sampling per
+        Karras et al. (2022) "Elucidating the Design Space of Diffusion-Based
+        Generative Models", which concentrates training on intermediate noise levels.
+
+        Args:
+            num_samples: Number of time samples (typically batch size).
+            t_min: Minimum time value. Defaults to self.t_min.
+            t_max: Maximum time value. Defaults to self.t_max.
+
+        Returns:
+            Tensor of times in [t_min, t_max].
+
+        Raises:
+            ValueError: If t_min >= t_max.
+        """
+        t_min = self.t_min if t_min is None else t_min
+        t_max = self.t_max if t_max is None else t_max
+
+        if t_min >= t_max:
+            raise ValueError(f"t_min ({t_min}) must be less than t_max ({t_max}).")
+
+        if self._train_schedule_type == "uniform":
+            # Uniform sampling (same as base class)
+            return (
+                torch.rand(num_samples, device=self._mean_base.device)
+                * (t_max - t_min)
+                + t_min
+            )
+        else:  # lognormal
+            # Sample sigma from log-normal distribution
+            # sigma = exp(P_mean + P_std * z) where z ~ N(0,1)
+            log_sigma = self.lognormal_mean + self.lognormal_std * torch.randn(
+                num_samples, device=self._mean_base.device
+            )
+
+            # Clamp in log-space BEFORE exponentiation to prevent NaN from
+            # log(negative) when converting back to time. This is more numerically
+            # stable than clamping sigma directly.
+            log_sigma_min = math.log(self.sigma_min)
+            log_sigma_max = math.log(self.sigma_max)
+
+            # Check if excessive clamping needed (warn if >5% out of bounds).
+            out_of_bounds = (
+                ((log_sigma < log_sigma_min) | (log_sigma > log_sigma_max)).sum().item()
+            )
+            if out_of_bounds > num_samples * 0.05:
+                warnings.warn(
+                    f"Lognormal schedule: {out_of_bounds}/{num_samples} samples "
+                    f"({100 * out_of_bounds / num_samples:.1f}%) clamped to "
+                    f"[{self.sigma_min}, {self.sigma_max}]. Consider adjusting "
+                    f"lognormal_mean={self.lognormal_mean} or "
+                    f"lognormal_std={self.lognormal_std}.",
+                    UserWarning,
+                    stacklevel=2,
+                )
+
+            log_sigma_clamped = torch.clamp(log_sigma, log_sigma_min, log_sigma_max)
+
+            # Convert log_sigma to time using VE's geometric relationship
+            log_ratio = log_sigma_max - log_sigma_min  # = log(sigma_max / sigma_min)
+            times = (log_sigma_clamped - log_sigma_min) / log_ratio
+
+            # Final clamp to handle any remaining edge cases from t_min/t_max bounds
+            return torch.clamp(times, t_min, t_max)
+
+    def solve_schedule(
+        self,
+        num_steps: int,
+        t_min: Optional[float] = None,
+        t_max: Optional[float] = None,
+    ) -> Tensor:
+        """
+        Return a deterministic time grid for ODE/SDE integration.
+
+        When solve_schedule="uniform" (default): uniform linspace from t_max to t_min.
+        When solve_schedule="power_law": power-law discretization per Karras et al.
+        (2022), Eq. 5, which concentrates steps near low noise levels where fine
+        details are resolved.
+
+        Args:
+            num_steps: Number of discretization steps.
+            t_min: Minimum time value. Defaults to self.t_min.
+            t_max: Maximum time value. Defaults to self.t_max.
+
+        Returns:
+            Tensor of shape (num_steps,) with times from t_max to t_min.
+
+        Raises:
+            ValueError: If t_min >= t_max.
+        """
+        t_min = self.t_min if t_min is None else t_min
+        t_max = self.t_max if t_max is None else t_max
+
+        if t_min >= t_max:
+            raise ValueError(f"t_min ({t_min}) must be less than t_max ({t_max}).")
+
+        if self._solve_schedule_type == "uniform":
+            # Uniform spacing (same as base class)
+            return torch.linspace(
+                t_max, t_min, num_steps, device=self._mean_base.device
+            )
+        else:  # power_law
+            # Power-law sigma schedule (Karras et al. 2022, Eq. 5):
+            # σ_i = (σ_max^(1/ρ) + i/(N-1) * (σ_min^(1/ρ) - σ_max^(1/ρ)))^ρ
+            rho = self.power_law_exponent
+            rho_inv = 1.0 / rho
+
+            # Compute sigma values using power-law interpolation
+            steps = torch.linspace(0, 1, num_steps, device=self._mean_base.device)
+            sigma_max_inv_rho = self.sigma_max**rho_inv
+            sigma_min_inv_rho = self.sigma_min**rho_inv
+            sigmas = (
+                sigma_max_inv_rho + steps * (sigma_min_inv_rho - sigma_max_inv_rho)
+            ) ** rho
+
+            # Convert sigma to time using VE's geometric relationship
+            log_ratio = math.log(self.sigma_max / self.sigma_min)
+            times = torch.log(sigmas / self.sigma_min) / log_ratio
+
+            # Ensure exact boundary values (avoid floating-point imprecision)
+            times[0] = t_max
+            times[-1] = t_min
+
+            return times
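(Illustrative note, not part of the patch.) To make the two new VE schedules concrete, here is a self-contained sketch of the same sigma-to-time mappings in plain torch. The sigma_min, sigma_max, and rho values are placeholders rather than the estimator's attributes, and the patched methods additionally clamp to [t_min, t_max], warn on excessive clamping, and pin the grid endpoints exactly:

    import math
    import torch

    # Illustrative constants; in the patch these come from the estimator instance.
    sigma_min, sigma_max = 1e-2, 10.0
    rho = 7.0

    def lognormal_times(n: int, p_mean: float = -1.2, p_std: float = 1.2) -> torch.Tensor:
        # sigma = exp(p_mean + p_std * z), clamped in log-space, then mapped to time via
        # t = log(sigma / sigma_min) / log(sigma_max / sigma_min).
        log_sigma = p_mean + p_std * torch.randn(n)
        log_sigma = log_sigma.clamp(math.log(sigma_min), math.log(sigma_max))
        return (log_sigma - math.log(sigma_min)) / math.log(sigma_max / sigma_min)

    def power_law_times(num_steps: int) -> torch.Tensor:
        # Karras et al. (2022), Eq. 5: interpolate sigma^(1/rho) linearly, then map to time.
        s = torch.linspace(0, 1, num_steps)
        sigmas = (
            sigma_max ** (1 / rho) + s * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
        ) ** rho
        return torch.log(sigmas / sigma_min) / math.log(sigma_max / sigma_min)

    t_train = lognormal_times(1000)   # random times in [0, 1], denser at mid noise levels
    t_solve = power_law_times(100)    # decreasing grid from 1.0 towards 0.0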
diff --git a/sbi/neural_nets/net_builders/vector_field_nets.py b/sbi/neural_nets/net_builders/vector_field_nets.py
index d1f135386..0173df1c5 100644
--- a/sbi/neural_nets/net_builders/vector_field_nets.py
+++ b/sbi/neural_nets/net_builders/vector_field_nets.py
@@ -65,7 +65,10 @@ def build_vector_field_estimator(
         net: Type of architecture to use, either "mlp", "ada_mlp", "transformer",
             "transformer_cross_attention" or a custom network following the
             VectorFieldNet protocol.
-        **kwargs: Additional arguments for the network.
+        **kwargs: Additional arguments for the network. For score estimators:
+            - VP/SubVP: beta_min, beta_max (control noise schedule strength)
+            - VE: train_schedule, solve_schedule, lognormal_mean, lognormal_std,
+              power_law_exponent (control training and sampling discretization)

     Returns:
         A vector field estimator (either FlowMatchingEstimator or
@@ -161,6 +164,25 @@ def build_vector_field_estimator(
         else:
             raise ValueError(f"Unknown SDE type: {sde_type}")

+        # Extract estimator-specific kwargs based on SDE type
+        estimator_kwargs = {}
+        if sde_type == "ve":
+            # VE-specific parameters: sigma bounds and EDM-style schedules
+            ve_keys = [
+                "sigma_min",
+                "sigma_max",
+                "train_schedule",
+                "solve_schedule",
+                "lognormal_mean",
+                "lognormal_std",
+                "power_law_exponent",
+            ]
+            estimator_kwargs = {k: kwargs[k] for k in ve_keys if k in kwargs}
+        elif sde_type in ("vp", "subvp"):
+            # VP/SubVP-specific beta parameters
+            vp_keys = ["beta_min", "beta_max"]
+            estimator_kwargs = {k: kwargs[k] for k in vp_keys if k in kwargs}
+
         return estimator_cls(
             net=vectorfield_net,
             input_shape=batch_x[0].shape,
@@ -168,6 +190,7 @@ def build_vector_field_estimator(
             embedding_net=embedding_net_y,
             mean_0=mean_0,
             std_0=std_0,
+            **estimator_kwargs,
         )
     else:
         raise ValueError(f"Unknown estimator type: {estimator_type}")
diff --git a/tests/linearGaussian_vector_field_test.py b/tests/linearGaussian_vector_field_test.py
index 40d685499..5422bb3ba 100644
--- a/tests/linearGaussian_vector_field_test.py
+++ b/tests/linearGaussian_vector_field_test.py
@@ -69,6 +69,7 @@ def test_c2st_vector_field_on_linearGaussian(
     x_o = zeros(1, num_dim)
     num_samples = 1000
+    # VE with uniform schedule needs slightly more simulations for stability
     num_simulations = 2600 if vector_field_type == "ve" else 2500

     # likelihood_mean will be likelihood_shift+theta
diff --git a/tests/vf_estimator_test.py b/tests/vf_estimator_test.py
index b9de41dc7..47773c7d8 100644
--- a/tests/vf_estimator_test.py
+++ b/tests/vf_estimator_test.py
@@ -248,3 +248,68 @@ def test_train_schedule(estimator_type, sde_type):
     assert torch.allclose(solve_schedule[-1], t_min)
     assert solve_schedule_default.shape == torch.Size((300,))
     assert torch.all(solve_schedule[:-1] - solve_schedule[1:] >= 0)
+
+
+@pytest.mark.parametrize(
"train_schedule,solve_schedule", + [ + ("uniform", "uniform"), + ("lognormal", "uniform"), + ("uniform", "power_law"), + ("lognormal", "power_law"), + ], +) +def test_ve_edm_schedules(train_schedule, solve_schedule): + """Test EDM-style schedules for VE estimator (Karras et al. 2022).""" + estimator = build_score_matching_estimator( + torch.randn(100, 1), + torch.randn(100, 1), + sde_type="ve", + train_schedule=train_schedule, + solve_schedule=solve_schedule, + ) + + # Test train schedule returns valid times without NaN. + times_train = estimator.train_schedule(500) + assert times_train.shape == (500,) + assert torch.all(times_train >= estimator.t_min), "Train times below t_min" + assert torch.all(times_train <= estimator.t_max), "Train times above t_max" + assert not torch.any(torch.isnan(times_train)), "NaN in train schedule" + + # Test solve schedule returns monotonically decreasing times without NaN. + times_solve = estimator.solve_schedule(100) + assert times_solve.shape == (100,) + assert torch.allclose(times_solve[0], torch.tensor(estimator.t_max)), ( + "First solve time != t_max" + ) + assert torch.allclose(times_solve[-1], torch.tensor(estimator.t_min)), ( + "Last solve time != t_min" + ) + assert torch.all(times_solve[:-1] >= times_solve[1:]), ( + "Solve schedule not monotonically decreasing" + ) + assert not torch.any(torch.isnan(times_solve)), "NaN in solve schedule" + + +def test_ve_lognormal_no_nan_with_extreme_params(): + """Test that lognormal schedule doesn't produce NaN even with extreme params.""" + # Use parameters that could cause extreme sigma values. + estimator = build_score_matching_estimator( + torch.randn(100, 1), + torch.randn(100, 1), + sde_type="ve", + train_schedule="lognormal", + lognormal_mean=-3.0, # Very low mean -> small sigmas + lognormal_std=2.0, # High variance -> some extreme samples + ) + + # Generate many samples to test edge cases. + times = estimator.train_schedule(10000) + assert not torch.any(torch.isnan(times)), ( + "NaN produced with extreme lognormal params" + ) + assert not torch.any(torch.isinf(times)), ( + "Inf produced with extreme lognormal params" + ) + assert torch.all(times >= estimator.t_min), "Times below t_min" + assert torch.all(times <= estimator.t_max), "Times above t_max"