Skip to content

Commit 1b1edfd

Browse files
Fix: Make NPSE picklable (#1679)
1 parent 89dc308 commit 1b1edfd

File tree

3 files changed

+30
-10
lines changed

3 files changed

+30
-10
lines changed

sbi/neural_nets/estimators/score_estimator.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -399,21 +399,28 @@ def _set_weight_fn(self, weight_fn: Union[str, Callable]):
399399
- a custom function that returns a Callable.
400400
"""
401401
if weight_fn == "identity":
402-
self.weight_fn = lambda times: 1
402+
self.weight_fn = self._identity_weight_fn
403403
elif weight_fn == "max_likelihood":
404-
self.weight_fn = (
405-
lambda times: self.diffusion_fn(
406-
torch.ones((1,), device=times.device), times
407-
)
408-
** 2
409-
)
404+
self.weight_fn = self._max_likelihood_weight_fn
410405
elif weight_fn == "variance":
411-
self.weight_fn = lambda times: self.std_fn(times) ** 2
406+
self.weight_fn = self._variance_weight_fn
412407
elif callable(weight_fn):
413408
self.weight_fn = weight_fn
414409
else:
415410
raise ValueError(f"Weight function {weight_fn} not recognized.")
416411

412+
def _identity_weight_fn(self, times):
413+
"""Return ones for any time t."""
414+
return 1
415+
416+
def _max_likelihood_weight_fn(self, times):
417+
"""Return weights proportional to the diffusion function."""
418+
return self.diffusion_fn(torch.ones((1,), device=times.device), times) ** 2
419+
420+
def _variance_weight_fn(self, times):
421+
"""Return weights as the variance."""
422+
return self.std_fn(times) ** 2
423+
417424
def ode_fn(self, input: Tensor, condition: Tensor, times: Tensor) -> Tensor:
418425
"""ODE flow function of the score estimator.
419426

sbi/samplers/ode_solvers/zuko_ode.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
Zuko ODE solver.
66
"""
77

8+
import functools
9+
810
import torch.nn as nn
911
from torch import Tensor
1012
from torch.distributions import Distribution
@@ -81,8 +83,9 @@ def get_distribution(self, condition: Tensor, **kwargs) -> Distribution:
8183
The distribution object with `log_prob` and
8284
`sample` methods that wraps the ODE solver.
8385
"""
86+
partial_f = functools.partial(self._f_condition_last, condition=condition)
8487
transform = FreeFormJacobianTransform(
85-
f=lambda t, input: self.f(input, condition, t),
88+
f=partial_f,
8689
t0=condition.new_tensor(self.t_min),
8790
t1=condition.new_tensor(self.t_max),
8891
phi=(condition, *self.net.parameters()),
@@ -93,3 +96,10 @@ def get_distribution(self, condition: Tensor, **kwargs) -> Distribution:
9396
transform=transform,
9497
base=DiagNormal(self.mean_base, self.std_base).expand(condition.shape[:-1]),
9598
)
99+
100+
def _f_condition_last(self, t, input, condition):
101+
"""Return the ODE function with argument-order changed: condition is last.
102+
103+
This is a helper to build a partial function in `get_distribution`, which
104+
requires that `condition` is passed as last element."""
105+
return self.f(input, condition, t)

tests/save_and_load_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77
import torch
88

99
from sbi import utils as utils
10-
from sbi.inference import NLE, NPE, NRE
10+
from sbi.inference import FMPE, NLE, NPE, NPSE, NRE
1111
from sbi.inference.posteriors.posterior_parameters import (
1212
DirectPosteriorParameters,
1313
MCMCPosteriorParameters,
1414
RejectionPosteriorParameters,
1515
VIPosteriorParameters,
16+
VectorFieldPosteriorParameters,
1617
)
1718
from sbi.inference.posteriors.vi_posterior import VIPosterior
1819

@@ -21,6 +22,8 @@
2122
"inference_method, posterior_parameters",
2223
(
2324
(NPE, DirectPosteriorParameters),
25+
(NPSE, VectorFieldPosteriorParameters),
26+
(FMPE, VectorFieldPosteriorParameters),
2427
pytest.param(NLE, MCMCPosteriorParameters, marks=pytest.mark.mcmc),
2528
pytest.param(NRE, MCMCPosteriorParameters, marks=pytest.mark.mcmc),
2629
pytest.param(NRE, VIPosteriorParameters, marks=pytest.mark.mcmc),

0 commit comments

Comments
 (0)