|
1 | 1 | # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed |
2 | 2 | # under the Apache License Version 2.0, see <https://www.apache.org/licenses/> |
3 | 3 |
|
4 | | -import inspect |
5 | 4 | from abc import ABCMeta, abstractmethod |
6 | | -from typing import Callable, Optional |
| 5 | +from typing import Optional, Protocol |
7 | 6 |
|
8 | 7 | import torch |
9 | 8 | from torch import Tensor |
@@ -83,42 +82,43 @@ def return_x_o(self) -> Optional[Tensor]: |
83 | 82 | return self._x_o |
84 | 83 |
|
85 | 84 |
|
86 | | -class CallablePotentialWrapper(BasePotential): |
class CustomPotential(Protocol):
    """Structural type for user-supplied potential functions.

    Any callable that accepts ``theta`` and the observed data ``x_o`` as its
    two positional arguments and returns a ``Tensor`` conforms to this
    protocol; no explicit subclassing is required.
    """

    def __call__(self, theta: Tensor, x_o: Tensor) -> Tensor:
        """Evaluate the potential at ``theta`` given observation ``x_o``."""
        ...
class CustomPotentialWrapper(BasePotential):
    """Adapter that wraps a plain callable potential as a `BasePotential`."""

    def __init__(
        self,
        potential_fn: CustomPotential,
        prior: Optional[Distribution],
        x_o: Optional[Tensor] = None,
        device: str = "cpu",
    ):
        """Wraps a callable potential function.

        Args:
            potential_fn: Custom potential function following the
                CustomPotential protocol, i.e., the function must have exactly
                two positional arguments where the first is theta and the
                second is the x_o.
            prior: Prior distribution, optional at init, but needed at
                inference time.
            x_o: Observed data, optional at init, but needed at inference
                time.
            device: Device on which to evaluate the potential function.
        """
        super().__init__(prior, x_o, device)
        # Kept as a public attribute; evaluated lazily in __call__.
        self.potential_fn = potential_fn

    def __call__(self, theta, track_gradients: bool = True):
        """Calls the custom potential function on given theta.

        Note, x_o is re-used from the initialization of the potential function.
        """
        # Toggle autograd according to the caller's request, then evaluate.
        with torch.set_grad_enabled(track_gradients):
            result = self.potential_fn(theta, self.x_o)
        return result
0 commit comments