
Add EDM-style noise schedules for VEScoreEstimator#1754

Open
janfb wants to merge 3 commits into main from add-edm-noise-schedules

Conversation

@janfb (Contributor) commented Feb 5, 2026

Extends VEScoreEstimator with configurable noise schedules from Karras et al. (2022) "Elucidating the Design Space of Diffusion-Based Generative Models". This follows up on #1736, which refactored schedule methods into the base class but deferred EDM-specific schedules pending benchmarking.

This applies only to VE score estimators. The EDM paper suggests log-normal / geometric time schedules for training and power-law time schedules for the ODE / SDE solver. So, instead of a uniform schedule, we use:

  • Lognormal: samples σ = exp(P_mean + P_std · z) with z ~ N(0, 1), then converts σ to time via the VE estimator's geometric relationship
  • Power-law: σ_i = (σ_max^(1/ρ) + i/(N-1) · (σ_min^(1/ρ) - σ_max^(1/ρ)))^ρ, which concentrates steps at low noise levels
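The two formulas above can be sketched in plain NumPy (a minimal illustration under the stated formulas, not the PR's actual implementation; the function names `lognormal_sigmas` and `power_law_sigmas` are hypothetical):

```python
import numpy as np

def lognormal_sigmas(n, p_mean=-1.2, p_std=1.2, rng=None):
    # Sample noise levels sigma = exp(P_mean + P_std * z), z ~ N(0, 1).
    rng = np.random.default_rng() if rng is None else rng
    return np.exp(p_mean + p_std * rng.standard_normal(n))

def power_law_sigmas(n, sigma_min=0.01, sigma_max=10.0, rho=7.0):
    # sigma_i = (sigma_max^(1/rho) + i/(N-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho
    i = np.arange(n)
    sigmas = (sigma_max ** (1 / rho)
              + i / (n - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    # Pin the endpoints exactly so floating-point drift cannot move them.
    sigmas[0], sigmas[-1] = sigma_max, sigma_min
    return sigmas
```

With ρ = 7 the resulting σ grid decreases monotonically from σ_max to σ_min, spending most of its steps near σ_min.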

Because these schedules can lead to numerical instabilities in edge cases, this PR also adds:

  • Parameter validation in constructor (sigma bounds, schedule strings, exponent positivity)
  • Log-space clamping in lognormal schedule to prevent NaN from log(negative)
  • Warning when >1% of lognormal samples require clamping
  • Explicit boundary value assignment in power-law to avoid floating-point drift
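A minimal sketch of the validation and log-space clamping described above (illustrative only; the function names, messages, and exact structure are assumptions, not the PR's code):

```python
import warnings
import numpy as np

def validate_ve_params(sigma_min, sigma_max, power_law_exponent):
    # Constructor-style checks: sigma bounds and exponent positivity.
    if not 0 < sigma_min < sigma_max:
        raise ValueError("Require 0 < sigma_min < sigma_max.")
    if power_law_exponent <= 0:
        raise ValueError("power_law_exponent must be positive.")

def clamped_lognormal_sigmas(n, p_mean, p_std, sigma_min, sigma_max, rng=None):
    # Clamping in log-space means np.clip never feeds a negative value
    # to log, so no NaN can arise even for extreme P_mean / P_std.
    rng = np.random.default_rng() if rng is None else rng
    log_sigma = p_mean + p_std * rng.standard_normal(n)
    lo, hi = np.log(sigma_min), np.log(sigma_max)
    frac_clamped = np.mean((log_sigma < lo) | (log_sigma > hi))
    if frac_clamped > 0.01:
        warnings.warn(
            f"{100 * frac_clamped:.1f}% of lognormal samples were clamped "
            "to the [sigma_min, sigma_max] range."
        )
    return np.exp(np.clip(log_sigma, lo, hi))
```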

The hyperparameters are exposed to the user via the build function and posterior_score_nn:

  • build_vector_field_estimator now passes sigma_min, sigma_max, and all schedule parameters to VE estimators
  • posterior_score_nn accepts these parameters when sde_type="ve"

All new parameters have defaults matching previous uniform behavior. Existing code is unaffected.

Example

from sbi.inference import NPSE
from sbi.neural_nets import posterior_score_nn

# prior, theta, and x are assumed to be defined as usual.
density_estimator = posterior_score_nn(
    model="mlp",
    sde_type="ve",
    train_schedule="lognormal",
    solve_schedule="power_law",
    # Optional; these are the defaults from the paper.
    lognormal_mean=-1.2,
    lognormal_std=1.2,
    power_law_exponent=7.0,
)

trainer = NPSE(prior=prior, density_estimator=density_estimator)
trainer.append_simulations(theta, x).train()
posterior = trainer.build_posterior()

Benchmarking

I am currently running benchmarks to compare against uniform schedules and will post the results here once done.

References

  • Karras, Aittala, Aila, Laine (2022). Elucidating the Design Space of Diffusion-Based Generative Models. https://arxiv.org/abs/2206.00364

@codecov codecov bot commented Feb 5, 2026

❌ 1 Tests Failed:

Tests completed Failed Passed Skipped
5828 1 5827 146
View the full list of 1 ❄️ flaky test(s)
tests/torchutils_test.py::TorchUtilsTest::test_searchsorted

Flake rate in main: 35.06% (Passed 50 times, Failed 27 times)

Stack Traces | 0.008s run time
.venv/lib/python3.10....../site-packages/xdist/remote.py:289: in pytest_runtest_logreport
    self.sendevent("testreport", data=data)
.venv/lib/python3.10....../site-packages/xdist/remote.py:126: in sendevent
    self.channel.send((name, kwargs))
.venv/lib/python3.10....................................................../site-packages/execnet/gateway_base.py:912: in send
    self.gateway._send(Message.CHANNEL_DATA, self.id, dumps_internal(item))
.venv/lib/python3.10....................................................../site-packages/execnet/gateway_base.py:1629: in dumps_internal
    return _Serializer().save(obj)  # type: ignore[return-value]
.venv/lib/python3.10....................................................../site-packages/execnet/gateway_base.py:1647: in save
    self._save(obj)
.venv/lib/python3.10....................................................../site-packages/execnet/gateway_base.py:1667: in _save
    dispatch(self, obj)
.venv/lib/python3.10....................................................../site-packages/execnet/gateway_base.py:1744: in save_tuple
    self._save(item)
.venv/lib/python3.10....................................................../site-packages/execnet/gateway_base.py:1667: in _save
    dispatch(self, obj)
.venv/lib/python3.10....................................................../site-packages/execnet/gateway_base.py:1740: in save_dict
    self._write_setitem(key, value)
.venv/lib/python3.10....................................................../site-packages/execnet/gateway_base.py:1734: in _write_setitem
    self._save(value)
.venv/lib/python3.10....................................................../site-packages/execnet/gateway_base.py:1667: in _save
    dispatch(self, obj)
.venv/lib/python3.10....................................................../site-packages/execnet/gateway_base.py:1740: in save_dict
    self._write_setitem(key, value)
.venv/lib/python3.10....................................................../site-packages/execnet/gateway_base.py:1734: in _write_setitem
    self._save(value)
.venv/lib/python3.10....................................................../site-packages/execnet/gateway_base.py:1667: in _save
    dispatch(self, obj)
.venv/lib/python3.10....................................................../site-packages/execnet/gateway_base.py:1740: in save_dict
    self._write_setitem(key, value)
.venv/lib/python3.10....................................................../site-packages/execnet/gateway_base.py:1734: in _write_setitem
    self._save(value)
.venv/lib/python3.10....................................................../site-packages/execnet/gateway_base.py:1667: in _save
    dispatch(self, obj)
.venv/lib/python3.10....................................................../site-packages/execnet/gateway_base.py:1740: in save_dict
    self._write_setitem(key, value)
.venv/lib/python3.10....................................................../site-packages/execnet/gateway_base.py:1734: in _write_setitem
    self._save(value)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <execnet.gateway_base._Serializer object at 0x7fc464470eb0>
obj = tensor([0.0000, 0.1111, 0.2222, 0.3333, 0.4444, 0.5556, 0.6667, 0.7778, 0.8889])

    def _save(self, obj: object) -> None:
        tp = type(obj)
        try:
            dispatch = self._dispatch[tp]
        except KeyError:
            methodname = "save_" + tp.__name__
            meth: Callable[[_Serializer, object], None] | None = getattr(
                self.__class__, methodname, None
            )
            if meth is None:
>               raise DumpError(f"can't serialize {tp}") from None
E               execnet.gateway_base.DumpError: can't serialize <class 'torch.Tensor'>

.venv/lib/python3.10....................................................../site-packages/execnet/gateway_base.py:1665: DumpError


janfb added 3 commits February 5, 2026 20:13
Implements configurable noise schedules from Karras et al. (2022)
"Elucidating the Design Space of Diffusion-Based Generative Models":

- train_schedule: "uniform" (default) | "lognormal"
- solve_schedule: "uniform" (default) | "power_law"
- Parameters: lognormal_mean, lognormal_std, power_law_exponent

Includes input validation and numerical stability:
- Validates sigma bounds (sigma_min > 0, sigma_max > sigma_min)
- Log-space clamping in lognormal schedule to prevent NaN
- Warning when >5% of samples require clamping
- Explicit boundary assignment in power-law schedule

References: https://arxiv.org/abs/2206.00364
- Extract sigma_min, sigma_max, and schedule params for VE estimators
- Update docstrings for solve_schedule usage
- Add comment clarifying VE simulation budget in tests
- test_ve_edm_schedules: all schedule combinations
- test_ve_lognormal_no_nan_with_extreme_params: numerical stability
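A plain-NumPy sketch of what the extreme-parameter stability test checks (hypothetical stand-in; the actual test exercises the sbi estimator itself):

```python
import numpy as np

def test_lognormal_no_nan_with_extreme_params():
    # Extreme P_mean / P_std push nearly all samples outside the sigma
    # bounds; log-space clamping must still yield finite, in-range values.
    rng = np.random.default_rng(0)
    log_sigma = 50.0 + 100.0 * rng.standard_normal(10_000)
    sigma_min, sigma_max = 1e-3, 10.0
    sigma = np.exp(np.clip(log_sigma, np.log(sigma_min), np.log(sigma_max)))
    assert np.isfinite(sigma).all()
    assert sigma.min() >= sigma_min * (1 - 1e-12)
    assert sigma.max() <= sigma_max * (1 + 1e-12)

test_lognormal_no_nan_with_extreme_params()
```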
@janfb janfb force-pushed the add-edm-noise-schedules branch from d1e2f68 to f625a41 on February 5, 2026 at 14:48
std_0: float = 1.0,
t_min: float = 1e-3,
t_max: float = 1.0,
train_schedule: Literal["uniform", "lognormal"] = "uniform",
@manuelgloeckler (Contributor) commented Feb 6, 2026
Depending on how the benchmark goes, I think we should even make lognormal and power_law the default.

@manuelgloeckler (Contributor) left a comment
Looks great!


2 participants