refactoring noise_schedule and time schedule into base class #1736

Merged
manuelgloeckler merged 38 commits into main from psteinb-explicit_noise_schedules-1437 on Feb 4, 2026
Conversation

@janfb
Contributor

@janfb janfb commented Jan 22, 2026

This is the conflict-free version of #1481 by @psteinb. I merged main in, resolved the conflicts, and moved the work onto this new branch so that others can continue it.

@psteinb I hope this is fine for you - all your commits are still here and attributed to you.

  • created noise_schedule method to be overridden by derived classes
  • created times_schedule method to be overridden by derived classes
  • created test on times_schedule
  • improved docstrings
  • created test on noise_schedule

This addresses #1437

psteinb and others added 20 commits March 20, 2025 18:38
- created noise_schedule method to be overwritten by derivatives
- created times_schedules method to be overwritten by derivatives
- created test on times_schedules
- improved docstrings
in addition:
- added improved version to benchmarks (for later comparison)
- created new class ImprovedVPScoreEstimator to understand how the VE
  estimator is implemented
- without touching the forward function of ConditionalScoreEstimator
- benchmarks show that this leads to very long training time without any
  performance improvements
…e_schedules-1437

Resolved the merge conflicts by aligning the score estimator with the vector-field API
changes from main and accepting the deletions of legacy score/NPSE paths.

Details:
- Cleaned and reconciled ConditionalScoreEstimator imports/init and typing in sbi/neural_nets/estimators/score_estimator.py, keeping beta_min/beta_max and device tracking consistent with the new base.
- Dropped deleted legacy files to match main: sbi/inference/trainers/npse/npse.py, sbi/neural_nets/net_builders/score_nets.py, and tests/score_estimator_test.py.
@janfb
Contributor Author

janfb commented Jan 22, 2026

And here's a quick review by opencode using Codex 5.2:

Summary

The PR centralizes time/noise schedule logic in ConditionalScoreEstimator and routes
VP/SubVP drift/diffusion through the shared schedule. This aligns with the requested
feature. However, there are two correctness issues and one scope mismatch with the
issue requirements.

Blocking issues

  1. times_schedule returns torch.Tensor(sorted(times)), which recreates the tensor
    on CPU and drops device/dtype. This will silently move schedules off-device.

    • Fix: use torch.sort(times).values and return the tensor as-is.
  2. times_schedule relies on self.device, but self.device is not updated when the
    module moves devices via .to(). Schedules can be created on the wrong device.

    • Fix: use a buffer/device source such as self._mean_base.device or
      next(self.parameters()).device when constructing times.
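
For illustration, a minimal sketch combining both fixes (assuming the t_min/t_max attributes and the _mean_base buffer exist on the estimator, as suggested above):

import torch
from torch import Tensor

def times_schedule(self, num_samples: int) -> Tensor:
    # Sample on the same device as the estimator's buffers so that .to()
    # moves the schedule together with the module.
    device = self._mean_base.device
    times = (
        torch.rand(num_samples, device=device) * (self.t_max - self.t_min)
        + self.t_min
    )
    # Sort the tensor directly instead of rebuilding it on CPU.
    return torch.sort(times).values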

Non-blocking but important

  • The loss docstring says it uses times_schedule when times is None, but the
    implementation still samples directly with torch.rand. Either update the code to
    call self.times_schedule(...) or update the docstring. The issue explicitly asked
    for using the default training schedule, so the implementation should likely use the
    new method.
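
For example, the sampling branch could delegate to the new method (a sketch only; the exact loss signature is not shown in this thread):

# Inside loss(), replacing the direct torch.rand sampling:
if times is None:
    times = self.times_schedule(input.shape[0])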

Tests

The issue requests tests for both schedules, but this PR only changes
score_estimator.py. Please add tests for:

  • times_schedule: correct shape, monotonicity, and device
  • noise_schedule: correct shape and range [beta_min, beta_max]
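
For instance, a hedged pytest sketch (the estimator fixture and the attribute names are assumptions based on this thread, not the final test code):

import torch

def test_times_schedule(estimator):
    times = estimator.times_schedule(100)
    assert times.shape == (100,)
    assert torch.all(times[1:] >= times[:-1])  # monotonically non-decreasing
    assert times.device == estimator._mean_base.device

def test_noise_schedule(estimator):
    times = torch.linspace(estimator.t_min, estimator.t_max, 50)
    betas = estimator.noise_schedule(times)
    assert betas.shape == times.shape
    assert torch.all(betas >= estimator.beta_min)
    assert torch.all(betas <= estimator.beta_max)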

@codecov

codecov bot commented Jan 22, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 88.39%. Comparing base (649f5d3) to head (ecb295c).
⚠️ Report is 6 commits behind head on main.
✅ All tests successful. No failed tests found.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1736      +/-   ##
==========================================
- Coverage   88.51%   88.39%   -0.13%     
==========================================
  Files         137      137              
  Lines       11527    12188     +661     
==========================================
+ Hits        10203    10773     +570     
- Misses       1324     1415      +91     
Flag    Coverage Δ
fast    84.73% <100.00%> (?)

Flags with carried forward coverage won't be shown.

Files with missing lines Coverage Δ
sbi/inference/posteriors/vector_field_posterior.py 77.47% <ø> (+0.29%) ⬆️
sbi/inference/trainers/vfpe/base_vf_inference.py 92.81% <100.00%> (ø)
sbi/neural_nets/estimators/base.py 75.19% <100.00%> (+1.42%) ⬆️
sbi/neural_nets/estimators/score_estimator.py 93.96% <100.00%> (+0.42%) ⬆️

... and 9 files with indirect coverage changes

@janfb
Contributor Author

janfb commented Jan 22, 2026

When comparing to the plan in the original issue #1437, I noticed that there are still a couple of open todos, both for the general API ideas and for the schedule from the EDM paper in particular. Accordingly, here is a potential plan for finishing this PR:

  1. Add schedule APIs to ConditionalScoreEstimator

    • Add train_schedule(num_samples, t_min=None, t_max=None) that returns diffusion
      times used for training. Default behavior should be a simple uniform schedule so
      current results remain unchanged unless overridden by subclasses.
    • Add solve_schedule(num_steps, t_min=None, t_max=None) that returns a deterministic
      monotonic time grid (e.g., evenly spaced). This is the default time discretization
      for evaluation/sampling steps.
    • Keep noise_schedule(times) as the mapping from time to noise magnitude (beta/sigma)
      so schedules stay decoupled from the SDE implementation.
  2. Use schedule APIs in the loss

    • Update loss(..., times=None) to call self.train_schedule(...) when times is
      not provided, so the new schedule is actually used in training.
    • Ensure schedules are created on the estimator’s current device (e.g., via a buffer
      like self._mean_base.device) to avoid CPU/GPU mismatches.
  3. Wire solve schedule into vector-field trainer validation

    • In sbi/inference/trainers/vfpe/base_vf_inference.py, when validation_times is an
      integer, replace the uniform torch.linspace(...) with
      self._neural_net.solve_schedule(num_steps).
    • Preserve existing behavior when validation_times is already a tensor so users can
      pass custom schedules explicitly.
  4. Add tests

    • train_schedule: correct shape, correct bounds (t_min/t_max), device, and that
      the default is uniform.
    • solve_schedule: monotonic increasing, includes endpoints, correct device, and
      correct length.
    • Regression: verify loss uses train_schedule when times is None.
  5. Optional EDM schedule for VE (paper details)

    • Add optional VE-specific parameters (e.g., edm_sigma_min, edm_sigma_max,
      edm_rho, edm_p_mean, edm_p_std) and a switch such as schedule="edm" to opt in.
    • Training noise distribution: EDM samples noise levels from a log-normal
      distribution (Table 1):
      • ln(σ) ~ N(P_mean, P_std^2).
      • Example defaults used in the paper: P_mean = -1.2, P_std = 1.2 (CIFAR-10).
      • In code: sample σ = exp(P_mean + P_std * N(0,1)), then clamp to
        [σ_min, σ_max] if needed.
    • Solve schedule (deterministic grid): EDM uses a power-law schedule for
      discretizing σ (Eq. 5 in the paper):
      • Define σ_i for i = 0..N-1 as
        σ_i = (σ_max^(1/ρ) + i/(N-1) * (σ_min^(1/ρ) - σ_max^(1/ρ)))^ρ.
      • Set σ_N = 0 for the final step.
      • Larger ρ concentrates steps near low noise; the paper uses ρ = 7.
    • Mapping to time: with the EDM choice σ(t) = t and s(t) = 1, time and
      sigma are interchangeable, so t_i = σ_i.
    • Keep uniform schedules as the default so existing behavior is stable.
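
To make item 5 concrete, here is a standalone sketch of the two EDM pieces (the function names and the sigma_min/sigma_max defaults are illustrative, taken from the paper's setup, not from this PR):

import torch

def edm_train_sigmas(num_samples, p_mean=-1.2, p_std=1.2,
                     sigma_min=0.002, sigma_max=80.0):
    # Training noise levels: ln(sigma) ~ N(P_mean, P_std^2), clamped to
    # [sigma_min, sigma_max].
    log_sigma = p_mean + p_std * torch.randn(num_samples)
    return log_sigma.exp().clamp(sigma_min, sigma_max)

def edm_solve_sigmas(num_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    # Deterministic power-law grid (Eq. 5): interpolate between sigma_max and
    # sigma_min in sigma^(1/rho) space; larger rho concentrates steps near
    # low noise.
    i = torch.arange(num_steps)
    sigmas = (
        sigma_max ** (1 / rho)
        + i / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
    ) ** rho
    # Append sigma_N = 0 for the final step; with sigma(t) = t and s(t) = 1,
    # the solve times are simply t_i = sigma_i.
    return torch.cat([sigmas, sigmas.new_zeros(1)])

With a schedule="edm" switch, train_schedule could return edm_train_sigmas(num_samples) and solve_schedule could return edm_solve_sigmas(num_steps), while the uniform defaults stay in place.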

@touronc touronc self-assigned this Jan 26, 2026
@janfb
Contributor Author

janfb commented Jan 30, 2026

Hi @touronc

Thank you for the updates here - looks very good. I tested this locally, found some issues, and suggest fixes below. I just saw that you pushed some fixes recently, so some of the comments below may already be obsolete.


1. solve_schedule must be deterministic

The implementation uses torch.rand() but the docstring says "deterministic monotonic time grid". Even with sorting, it's random on each call, which breaks ODE/SDE integration reproducibility. Use torch.linspace() instead:

def solve_schedule(self, num_steps, t_min=None, t_max=None):
    t_min = self.t_min if t_min is None else t_min
    t_max = self.t_max if t_max is None else t_max
    return torch.linspace(t_max, t_min, num_steps, device=self._mean_base.device)

2. Restore validation_times_nugget

The trainer change lost the nugget offset that avoids boundary instability at t=0 and t=t_max.

base_vf_inference.py: Pass nugget to solve_schedule():

if isinstance(validation_times, int):
    validation_times = self._neural_net.solve_schedule(
        validation_times,
        t_min=self._neural_net.t_min + validation_times_nugget,
        t_max=self._neural_net.t_max - validation_times_nugget,
    )

This in turn requires the base class ConditionalVectorFieldEstimator.solve_schedule() to accept t_min/t_max parameters:

def solve_schedule(
    self,
    steps: int,
    t_min: Optional[float] = None,
    t_max: Optional[float] = None,
) -> Tensor:
    t_min = self.t_min if t_min is None else t_min
    t_max = self.t_max if t_max is None else t_max
    return torch.linspace(t_max, t_min, steps, device=self._mean_base.device)

3. Device handling in loss()

Ensure the times tensor is on the correct device after calling train_schedule():

if times is None:
    times = self.train_schedule(input.shape[0])
times = times.to(input.device)

4. Remove self.device tracking

The manual device tracking in __init__ and loss() is fragile (doesn't follow .to() calls) and never actually read. Remove these lines:

# In __init__:
self.device = net.device if hasattr(net, "device") else torch.device("cpu")

# In loss():
self.device = input.device if self.device != input.device else self.device

The schedule methods already use self._mean_base.device which is a proper buffer.


5. train_schedule should use plain uniform sampling (no sorting?)

The current implementation sorts times and forces the endpoints; is there a specific reason for the sorting?

times[0, ...] = t_min
times[-1, ...] = t_max
return torch.sort(times).values

This differs from the main branch behavior, and I believe uniform random sampling would be fine here (or am I missing something?):

return (
    torch.rand(num_samples, device=self._mean_base.device) * (t_max - t_min)
    + t_min
)

6. VE test needs slightly more simulations

Locally, I noticed that with these changes, the VE test fails with 2500 simulations, probably because the overall random state changed. Increasing the simulation budget slightly fixes it for me:

# In tests/linearGaussian_vector_field_test.py
num_simulations = 2600 if vector_field_type == "ve" else 2500

7. Minor docstring fixes

  • train_schedule: the Returns section says "range [0,1]" but should say "[t_min, t_max]"
  • noise_schedule: References self.times_schedule which doesn't exist
  • vector_field_posterior.py (~line 309): ts parameter has broken indentation

Otherwise, this looks very good!

Regarding the EDM paper related changes, I suggest we move this into a follow-up PR.

Thanks @touronc

Contributor Author

@janfb janfb left a comment


Thanks @touronc for implementing this! 🙏 Very well done!

I added a couple of comments, mostly formatting.

Contributor Author

@janfb janfb left a comment


All fixes addressed, thank you @touronc! 🙏

I approve this PR (I cannot officially approve it because I created it).

Contributor

@manuelgloeckler manuelgloeckler left a comment


Great, love it. Thanks for finishing this up.

I think, e.g. for VE variants, it would be beneficial to use an EDM-like train and solve schedule, but that is another project which would need some benchmarking. So I will also approve this to get it merged.

@manuelgloeckler manuelgloeckler merged commit 1888205 into main on Feb 4, 2026
9 checks passed