Commit 69b7f38

fix: add protocol for custom potential (#1409)
* fix: protocol and refactor for custom potential
* fix docstrings
* fix docstring
1 parent c2ff942 commit 69b7f38

File tree

3 files changed: +26 additions, -28 deletions

sbi/inference/posteriors/base_posterior.py
sbi/inference/potentials/base_potential.py
tests/potential_test.py

sbi/inference/posteriors/base_posterior.py

Lines changed: 6 additions & 8 deletions

@@ -2,7 +2,7 @@
 # under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
 
 from abc import abstractmethod
-from typing import Any, Callable, Dict, Optional, Union
+from typing import Any, Dict, Optional, Union
 from warnings import warn
 
 import torch
@@ -11,7 +11,8 @@
 
 from sbi.inference.potentials.base_potential import (
     BasePotential,
-    CallablePotentialWrapper,
+    CustomPotential,
+    CustomPotentialWrapper,
 )
 from sbi.sbi_types import Array, Shape, TorchTransform
 from sbi.utils.sbiutils import gradient_ascent
@@ -29,7 +30,7 @@ class NeuralPosterior:
 
     def __init__(
         self,
-        potential_fn: Union[Callable, BasePotential],
+        potential_fn: Union[BasePotential, CustomPotential],
         theta_transform: Optional[TorchTransform] = None,
         device: Optional[str] = None,
         x_shape: Optional[torch.Size] = None,
@@ -51,16 +52,13 @@ def __init__(
                 stacklevel=2,
             )
 
-        # Wrap as `CallablePotentialWrapper` if `potential_fn` is a Callable.
+        # Wrap custom potential functions to adhere to the `BasePotential` interface.
         if not isinstance(potential_fn, BasePotential):
-            # If the `potential_fn` is a Callable then we wrap it as a
-            # `CallablePotentialWrapper` which inherits from `BasePotential`.
             potential_device = "cpu" if device is None else device
-            potential_fn = CallablePotentialWrapper(
+            potential_fn = CustomPotentialWrapper(
                 potential_fn, prior=None, x_o=None, device=potential_device
            )
 
-        # Ensure device string.
         self._device = process_device(potential_fn.device if device is None else device)
 
         self.potential_fn = potential_fn
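With this change, a posterior can be constructed from a plain `(theta, x_o)` callable, which `__init__` above wraps into a `CustomPotentialWrapper`. A minimal sketch of that usage; it assumes sbi's `MCMCPosterior` and `BoxUniform` behave as in their documented APIs, and the quadratic potential and all numbers are made-up examples, not part of this commit:

import torch
from sbi.inference import MCMCPosterior
from sbi.utils import BoxUniform

# Hypothetical custom potential: any callable taking (theta, x_o) positionally.
def quadratic_potential(theta: torch.Tensor, x_o: torch.Tensor) -> torch.Tensor:
    # Unnormalized log-density, peaked at the observation x_o.
    return -((theta - x_o) ** 2).sum(dim=-1)

prior = BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))

# The posterior wraps the plain callable into a CustomPotentialWrapper internally.
posterior = MCMCPosterior(potential_fn=quadratic_potential, proposal=prior)
samples = posterior.sample((100,), x=torch.zeros(2))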

sbi/inference/potentials/base_potential.py

Lines changed: 18 additions & 18 deletions

@@ -1,9 +1,8 @@
 # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
 # under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
 
-import inspect
 from abc import ABCMeta, abstractmethod
-from typing import Callable, Optional
+from typing import Optional, Protocol
 
 import torch
 from torch import Tensor
@@ -83,42 +82,43 @@ def return_x_o(self) -> Optional[Tensor]:
         return self._x_o
 
 
-class CallablePotentialWrapper(BasePotential):
+class CustomPotential(Protocol):
+    """Protocol for custom potential functions."""
+
+    def __call__(self, theta: Tensor, x_o: Tensor) -> Tensor:
+        """Call the potential function on given theta and observed data."""
+        ...
+
+
+class CustomPotentialWrapper(BasePotential):
     """If `potential_fn` is a callable it gets wrapped as this."""
 
     def __init__(
         self,
-        potential_fn: Callable,
+        potential_fn: CustomPotential,
         prior: Optional[Distribution],
         x_o: Optional[Tensor] = None,
         device: str = "cpu",
     ):
         """Wraps a callable potential function.
 
         Args:
-            potential_fn: Callable potential function, must have `theta` and `x_o` as
-                arguments.
-            prior: Prior distribution.
-            x_o: Observed data.
+            potential_fn: Custom potential function following the CustomPotential
+                protocol, i.e., the function must have exactly two positional arguments
+                where the first is theta and the second is the x_o.
+            prior: Prior distribution, optional at init, but needed at inference time.
+            x_o: Observed data, optional at init, but needed at inference time.
             device: Device on which to evaluate the potential function.
 
         """
         super().__init__(prior, x_o, device)
 
-        kwargs_of_callable = list(inspect.signature(potential_fn).parameters.keys())
-        required_keys = ["theta", "x_o"]
-        for key in required_keys:
-            assert key in kwargs_of_callable, (
-                "If you pass a `Callable` as `potential_fn` then it must have "
-                "`theta` and `x_o` as inputs, even if some of these keyword "
-                "arguments are unused."
-            )
         self.potential_fn = potential_fn
 
     def __call__(self, theta, track_gradients: bool = True):
-        """Call the callable potential function on given theta.
+        """Calls the custom potential function on given theta.
 
         Note, x_o is re-used from the initialization of the potential function.
         """
         with torch.set_grad_enabled(track_gradients):
-            return self.potential_fn(theta=theta, x_o=self.x_o)
+            return self.potential_fn(theta, self.x_o)
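Because `CustomPotential` is a typing Protocol, any function or callable object with the positional signature `(theta, x_o) -> Tensor` conforms structurally; the old `inspect`-based check of parameter names is no longer needed. A minimal sketch of conforming to the protocol and wrapping the function directly; the Gaussian potential and the tensor shapes are hypothetical examples, not taken from the commit:

import torch
from torch import Tensor
from sbi.inference.potentials.base_potential import (
    CustomPotential,
    CustomPotentialWrapper,
)

def gaussian_potential(theta: Tensor, x_o: Tensor) -> Tensor:
    # Hypothetical potential: log-density of a standard normal centred on x_o.
    return -0.5 * ((theta - x_o) ** 2).sum(dim=-1)

# Static type checkers accept the plain function wherever CustomPotential is expected.
fn: CustomPotential = gaussian_potential

# Wrapping gives it the BasePotential interface; x_o passed here is re-used on call.
wrapped = CustomPotentialWrapper(gaussian_potential, prior=None, x_o=torch.zeros(2))
log_potential = wrapped(torch.randn(10, 2))  # one potential value per theta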

tests/potential_test.py

Lines changed: 2 additions & 2 deletions

@@ -14,7 +14,7 @@
     RejectionPosterior,
     VIPosterior,
 )
-from sbi.inference.potentials.base_potential import CallablePotentialWrapper
+from sbi.inference.potentials.base_potential import CustomPotentialWrapper
 from sbi.utils import BoxUniform
 from sbi.utils.conditional_density_utils import ConditionedPotential
 
@@ -83,7 +83,7 @@ def potential(theta, x_o):
     ],
 )
 def test_conditioned_potential(condition: Tensor):
-    potential_fn = CallablePotentialWrapper(
+    potential_fn = CustomPotentialWrapper(
         potential_fn=lambda theta, x_o: theta,
         prior=BoxUniform(low=zeros(2), high=ones(2)),
     )
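Since `CustomPotentialWrapper.__call__` now forwards `theta` and `x_o` positionally, the callable's parameters no longer have to be literally named `theta` and `x_o`. A small, hypothetical variation of the test's setup illustrating this (the lambda below is made up and not part of the test suite):

from torch import ones, zeros
from sbi.inference.potentials.base_potential import CustomPotentialWrapper
from sbi.utils import BoxUniform

# Parameter names are arbitrary now; only the positional order (theta, x_o) matters.
# This lambda would have failed the previous inspect-based name check.
potential_fn = CustomPotentialWrapper(
    potential_fn=lambda params, observation: -params.abs().sum(-1),
    prior=BoxUniform(low=zeros(2), high=ones(2)),
)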
