Commit eb5ca16

refactor: improve abc implementation and tests. (#1615)
* refactor: improve implementation and consistency of inference/abc module
* refactor: add type hints to abc module
* refactor: add type hints to smcabc
* refactor method visibility and types
* docs(abc): Add docstrings to distances.py
* feat(abc): Refactor ABC module for consistent privatization and type hints
* fix abc tests, fix mmd metric
* consolidate distances and metrics
* fix string output
* refactor(abc): lra and sass speed and side effects - use super() in MCABC
* ignore all venvs
* add lra and sass tests
* typing and reduced comments
* refactor smcabc test for speed
* small fixes
* refactor abc tests for speed and coverage
* fix gitignore
* fix abc tests typing
* remove empty file
* small fixes
1 parent 029c1bf commit eb5ca16

File tree

8 files changed: +656 additions, -376 deletions

.gitignore

Lines changed: 4 additions & 1 deletion

@@ -9,7 +9,7 @@
 /*.egg
 /*.egg-info
 .venv*
-/.serena
+
 # Notebook checkpoints
 .ipynb_checkpoints
 
@@ -102,3 +102,6 @@ target/
 # uv
 uv.lock
 .python-version
+
+# Serena cache
+.serena/

sbi/inference/abc/abc_base.py

Lines changed: 38 additions & 35 deletions

@@ -6,13 +6,13 @@
 import logging
 from typing import Callable, Dict, Optional, Union
 
-import numpy as np
 import torch
 from sklearn.linear_model import LinearRegression
 from sklearn.preprocessing import PolynomialFeatures
+from torch import Tensor
 
-from sbi.inference.abc.distances import Distance
 from sbi.simulators.simutils import simulate_in_batches
+from sbi.utils.metrics import Distance
 
 
 class ABCBASE:
@@ -84,12 +84,12 @@ def __init__(
         self.logger = logging.getLogger(__name__)
 
     @staticmethod
-    def get_sass_transform(
-        theta: torch.Tensor,
-        x: torch.Tensor,
+    def _get_sass_transform(
+        theta: Tensor,
+        x: Tensor,
         expansion_degree: int = 1,
-        sample_weight=None,
-    ) -> Callable:
+        sample_weight: Optional[Tensor] = None,
+    ) -> Callable[[Tensor], Tensor]:
         """Return semi-automatic summary statistics function.
 
         Running weighted linear regression as in
@@ -102,17 +102,18 @@ def get_sass_transform(
         """
         expansion = PolynomialFeatures(degree=expansion_degree, include_bias=False)
         # Transform x, remove intercept.
-        x_expanded = expansion.fit_transform(x)
-        sumstats_map = np.zeros((x_expanded.shape[1], theta.shape[1]))
-
-        for parameter_idx in range(theta.shape[1]):
-            regression_model = LinearRegression(fit_intercept=True)
-            regression_model.fit(
-                X=x_expanded, y=theta[:, parameter_idx], sample_weight=sample_weight
-            )
-            sumstats_map[:, parameter_idx] = regression_model.coef_
+        x_expanded = expansion.fit_transform(x.numpy())
+
+        # Fit a single multi-output regression model for all parameters at once
+        regression_model = LinearRegression(fit_intercept=True)
+        regression_model.fit(
+            X=x_expanded,
+            y=theta.numpy(),  # All parameters at once
+            sample_weight=sample_weight.numpy() if sample_weight is not None else None,
+        )
 
-        sumstats_map = torch.tensor(sumstats_map, dtype=torch.float32)
+        # Get coefficients for all parameters (shape: [n_features, n_parameters])
+        sumstats_map = torch.tensor(regression_model.coef_.T, dtype=torch.float32)
 
         def sumstats_transform(x):
             x_expanded = torch.tensor(expansion.fit_transform(x), dtype=torch.float32)
@@ -121,28 +122,30 @@ def sumstats_transform(x):
         return sumstats_transform
 
     @staticmethod
-    def run_lra(
-        theta: torch.Tensor,
-        x: torch.Tensor,
-        observation: torch.Tensor,
-        sample_weight=None,
-    ) -> torch.Tensor:
+    def _run_lra(
+        theta: Tensor,
+        x: Tensor,
+        observation: Tensor,
+        sample_weight: Optional[Tensor] = None,
+    ) -> Tensor:
         """Return parameters adjusted with linear regression adjustment.
 
         Implementation as in Beaumont et al. 2002: https://arxiv.org/abs/1707.01254
         """
 
-        theta_adjusted = theta
-        for parameter_idx in range(theta.shape[1]):
-            regression_model = LinearRegression(fit_intercept=True)
-            regression_model.fit(
-                X=x,
-                y=theta[:, parameter_idx],
-                sample_weight=sample_weight,
-            )
-            theta_adjusted[:, parameter_idx] += regression_model.predict(
-                observation.reshape(1, -1)
-            )
-            theta_adjusted[:, parameter_idx] -= regression_model.predict(x)
+        # Fit a single multi-output regression model
+        regression_model = LinearRegression(fit_intercept=True)
+        regression_model.fit(
+            X=x.numpy(),
+            y=theta.numpy(),  # All parameters at once
+            sample_weight=sample_weight.numpy() if sample_weight is not None else None,
+        )
+
+        # Predict for observation and simulated data
+        pred_obs = regression_model.predict(observation.reshape(1, -1).numpy())
+        pred_sim = regression_model.predict(x.numpy())
+
+        # Apply adjustment: theta + m(x_o) - m(x)
+        theta_adjusted = theta + torch.from_numpy(pred_obs - pred_sim)
 
         return theta_adjusted
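
For context, the vectorized pattern behind the new _get_sass_transform, as a minimal standalone sketch: one multi-output LinearRegression is fit on polynomially expanded simulations, and its transposed coefficient matrix serves as the summary-statistics map. The theta/x data below are synthetic, for illustration only; the sklearn/torch calls mirror the diff.

    import torch
    from sklearn.linear_model import LinearRegression
    from sklearn.preprocessing import PolynomialFeatures

    # Illustrative data: 1000 simulations, 3 parameters, 10 data dimensions.
    theta = torch.randn(1000, 3)
    x = torch.randn(1000, 10)

    # Expand x (degree 1 keeps it linear; no bias column, i.e., no intercept term).
    expansion = PolynomialFeatures(degree=1, include_bias=False)
    x_expanded = expansion.fit_transform(x.numpy())

    # One multi-output fit replaces the previous per-parameter loop; coef_ has
    # shape (n_parameters, n_features), so its transpose maps data to parameters.
    regression_model = LinearRegression(fit_intercept=True)
    regression_model.fit(X=x_expanded, y=theta.numpy())
    sumstats_map = torch.tensor(regression_model.coef_.T, dtype=torch.float32)

    def sumstats_transform(x_new: torch.Tensor) -> torch.Tensor:
        expanded = torch.tensor(
            expansion.fit_transform(x_new.numpy()), dtype=torch.float32
        )
        return expanded.mm(sumstats_map)

    summaries = sumstats_transform(torch.randn(5, 10))  # shape (5, 3)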

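Similarly, a self-contained sketch of the adjustment _run_lra now computes in one shot, theta + m(x_o) - m(x). Unlike the old per-parameter loop, this avoids mutating theta in place (the old `theta_adjusted = theta` aliased, rather than copied, the input), one of the side effects the commit message mentions. Data are again synthetic:

    import torch
    from sklearn.linear_model import LinearRegression

    # Illustrative accepted simulations: 500 samples, 2 parameters, 8 data dims.
    theta = torch.randn(500, 2)
    x = torch.randn(500, 8)
    observation = torch.randn(8)

    # Single multi-output fit of parameters on data.
    regression_model = LinearRegression(fit_intercept=True)
    regression_model.fit(X=x.numpy(), y=theta.numpy())

    # Regression predictions at the observation and at each simulation.
    pred_obs = regression_model.predict(observation.reshape(1, -1).numpy())  # (1, 2)
    pred_sim = regression_model.predict(x.numpy())                           # (500, 2)

    # theta + m(x_o) - m(x); broadcasting aligns (1, 2) with (500, 2).
    theta_adjusted = theta + torch.from_numpy(pred_obs - pred_sim)
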
sbi/inference/abc/distances.py

Lines changed: 0 additions & 136 deletions
This file was deleted; per the commit message, its distance helpers were consolidated into sbi.utils.metrics (see the Distance import change in abc_base.py above).

sbi/inference/abc/mcabc.py

Lines changed: 13 additions & 11 deletions

@@ -8,6 +8,7 @@
 import torch
 from numpy import ndarray
 from torch import Tensor
+from torch.distributions import Distribution
 
 from sbi.inference.abc.abc_base import ABCBASE
 from sbi.utils.kde import KDEWrapper, get_kde
@@ -20,9 +21,9 @@ class MCABC(ABCBASE):
     def __init__(
         self,
         simulator: Callable,
-        prior,
+        prior: Distribution,
         distance: Union[str, Callable] = "l2",
-        requires_iid_data: Optional[None] = None,
+        requires_iid_data: Optional[bool] = None,
         distance_kwargs: Optional[Dict] = None,
         num_workers: int = 1,
         simulation_batch_size: int = 1,
@@ -147,15 +148,15 @@ def __call__(
             pilot_theta = self.prior.sample((num_pilot_simulations,))
             pilot_x = self._batched_simulator(pilot_theta)
 
-            sass_transform = self.get_sass_transform(
+            sass_transform = super()._get_sass_transform(
                 pilot_theta, pilot_x, sass_expansion_degree
             )
 
             # Add sass transform to simulator and x_o.
             def simulator(theta):
                 return sass_transform(self._batched_simulator(theta))
 
-            x_o = sass_transform(x_o)
+            x_o = sass_transform(process_x(x_o))
         else:
             simulator = self._batched_simulator
 

@@ -173,11 +174,10 @@ def simulator(theta):
         if not self.distance.requires_iid_data:
             x = x.squeeze(1)
             self.x_shape = x[0].shape
-            self.x_o = process_x(x_o, self.x_shape)
         else:
             self.x_shape = x[0, 0].shape
-            self.x_o = process_x(x_o, self.x_shape)
 
+        self.x_o = process_x(x_o, self.x_shape)
         distances = self.distance(self.x_o, x)
 
         # Select based on acceptance threshold epsilon.
@@ -204,16 +204,18 @@ def simulator(theta):
         # Maybe adjust theta with LRA.
         if lra:
             self.logger.info("Running Linear regression adjustment.")
-            final_theta = self.run_lra(theta_accepted, x_accepted, observation=self.x_o)
+            final_theta = super()._run_lra(
+                theta_accepted, x_accepted, observation=self.x_o
+            )
         else:
             final_theta = theta_accepted
 
         if kde:
             self.logger.info(
-                """KDE on %s samples with bandwidth option
-                {kde_kwargs["bandwidth"] if "bandwidth" in kde_kwargs else "cv"}.
-                Beware that KDE can give unreliable results when used with too few
-                samples and in high dimensions.""",
+                "KDE on %s samples with bandwidth option "
+                f"{kde_kwargs.get('bandwidth', 'cv')}. "
+                "Beware that KDE can give unreliable results when used with too few"
+                " samples and in high dimensions.",
                 final_theta.shape[0],
             )
 
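
For orientation, a hypothetical end-to-end sketch of how these pieces are reached. The lra and sass flags and the kde/kde_kwargs options appear in the diff above; num_simulations and quantile are assumptions based on sbi's documented MCABC interface, and the toy simulator and prior are made up:

    import torch
    from torch.distributions import MultivariateNormal

    from sbi.inference.abc.mcabc import MCABC

    # Toy Gaussian problem: the simulator adds noise to the parameters.
    prior = MultivariateNormal(torch.zeros(2), torch.eye(2))

    def simulator(theta: torch.Tensor) -> torch.Tensor:
        return theta + 0.1 * torch.randn_like(theta)

    x_o = torch.zeros(1, 2)

    inferer = MCABC(simulator, prior, distance="l2")

    # Keep the closest 1% of simulations; lra and sass trigger the regression
    # steps refactored in this commit (assumed call signature, see above).
    posterior_samples = inferer(
        x_o, num_simulations=10_000, quantile=0.01, lra=True, sass=True
    )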
