
Commit a4a585f

WIP: parallel simulation refactoring

1 parent 649f5d3

File tree

2 files changed: +391 −48 lines

sbi/utils/simulation_utils.py

Lines changed: 196 additions & 48 deletions

@@ -1,17 +1,20 @@
 # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
 # under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
 
-from typing import Any, Callable, Optional, Tuple, Union
+import warnings
+from typing import Any, Callable, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
-from joblib import Parallel, delayed
-from numpy import ndarray
+from joblib import Parallel, delayed, parallel_config
 from torch import Tensor, float32
 from tqdm.auto import tqdm
 
 from sbi.utils.sbiutils import seed_all_backends
 
+Data = Tensor | np.ndarray | List[str]
+Theta = Tensor | np.ndarray | List[Any]
+
 
 # Refactoring following #1175. tl:dr: letting joblib iterate over numpy arrays
 # allows for a roughly 10x performance gain. The resulting casting necessity
@@ -60,61 +63,206 @@ def simulate_for_sbi(
     """
 
     if num_simulations == 0:
-        theta = torch.tensor([], dtype=float32)
-        x = torch.tensor([], dtype=float32)
+        return torch.tensor([], dtype=float32), torch.tensor([], dtype=float32)
+
+    seed_all_backends(seed)
+    theta = proposal.sample((num_simulations,))
+    # Cast to numpy for joblib efficiency
+    theta_numpy = theta.cpu().numpy()
 
+    if simulation_batch_size is None:
+        simulation_batch_size = num_simulations
     else:
-        # Cast theta to numpy for better joblib performance (seee #1175)
-        seed_all_backends(seed)
-        theta = proposal.sample((num_simulations,))
-
-        # Parse the simulation_batch_size logic
-        if simulation_batch_size is None:
-            simulation_batch_size = num_simulations
-        else:
-            simulation_batch_size = min(simulation_batch_size, num_simulations)
-
-        if num_workers != 1:
-            # For multiprocessing, we want to switch to numpy arrays.
-            # The batch size will be an approximation, since np.array_split does
-            # not take as argument the size of the batch but their total.
-            num_batches = num_simulations // simulation_batch_size
-            batches = np.array_split(theta.cpu().numpy(), num_batches, axis=0)
-            batch_seeds = np.random.randint(low=0, high=1_000_000, size=(len(batches),))
+        simulation_batch_size = min(simulation_batch_size, num_simulations)
 
-            # define seeded simulator.
-            def simulator_seeded(theta: ndarray, seed: int) -> Tensor:
-                seed_all_backends(seed)
-                return simulator(theta)
-
-            try: # catch TypeError to give more informative error message
-                simulation_outputs: list[Tensor] = [ # pyright: ignore
-                    xx
-                    for xx in tqdm(
-                        Parallel(return_as="generator", n_jobs=num_workers)(
-                            delayed(simulator_seeded)(batch, seed)
-                            for batch, seed in zip(batches, batch_seeds, strict=False)
-                        ),
-                        total=num_simulations,
-                        disable=not show_progress_bar,
-                    )
-                ]
-            except TypeError as err:
+    # Handle parallel context
+    context = parallel_config(n_jobs=num_workers)
+
+    with context:
+        # We enforce simulator_is_batched=True because simulate_for_sbi semantics
+        # implies that the simulator receives batches (even if size 1).
+        try:
+            theta, x = simulate_from_thetas(
+                simulator,
+                theta_numpy,
+                simulation_batch_size=simulation_batch_size,
+                simulator_is_batched=True,
+                show_progress_bar=show_progress_bar,
+                seed=seed,
+            )
+        except TypeError as err:
+            if num_workers > 1:
                 raise TypeError(
                     "There is a TypeError error in your simulator function. Note: For"
                     " multiprocessing, we switch to numpy arrays. Besides confirming"
                     " your simulator works correctly, make sure to preprocess your"
                     " simulator with `process_simulator` to handle numpy arrays."
                 ) from err
+            else:
+                raise err
 
-    else:
-        simulation_outputs: list[Tensor] = []
-        batches = torch.split(theta, simulation_batch_size)
-        for batch in tqdm(batches, disable=not show_progress_bar):
-            simulation_outputs.append(simulator(batch))
+    # Correctly format the output to Tensor
+    theta = torch.as_tensor(theta, dtype=float32)
 
-    # Correctly format the output
-    x = torch.cat(simulation_outputs, dim=0)
-    theta = torch.as_tensor(theta, dtype=float32)
+    if isinstance(x, np.ndarray):
+        x = torch.from_numpy(x)
 
     return theta, x
+
+
+def parallelize_simulator(
+    simulator: Optional[Callable[[Theta], Data]] = None,
+    simulator_is_batched: bool = False,
+    simulation_batch_size: int = 10,
+    show_progress_bar: bool = True,
+    seed: Optional[int] = None,
+) -> Union[
+    Callable[[Theta], Data],
+    Callable[[Callable[[Theta], Data]], Callable[[Theta], Data]],
+]:
+    r"""
+    Returns a function that executes simulations in parallel for a given set of
+    parameters. Can be used as a function or a decorator.
+
+    Args:
+        simulator: Function to run simulations.
+        simulator_is_batched: Whether the simulator can handle batches directly.
+        simulation_batch_size: Number of simulations to run in each batch.
+        show_progress_bar: Whether to show tqdm progress bar.
+        seed: Random seed.
+
+    Returns:
+        Callable that takes a set of :math:`\theta` and returns simulation outputs.
+    """
+
+    def decorator(simulator_func: Callable[[Theta], Data]) -> Callable[[Theta], Data]:
+        warnings.warn(
+            "Joblib is used for parallelization. It is recommended to use numpy arrays "
+            "for the simulator input and output to avoid serialization overhead with "
+            "torch tensors.",
+            UserWarning,
+            stacklevel=2,
+        )
+
+        def parallel_simulator(thetas: Theta) -> Data:
+            seed_all_backends(seed)
+
+            num_simulations = len(thetas)
+
+            if num_simulations == 0:
+                return torch.tensor([], dtype=float32)
+
+            # Create batches
+            if simulator_is_batched:
+                num_batches = (
+                    num_simulations + simulation_batch_size - 1
+                ) // simulation_batch_size
+                batches = [
+                    thetas[i * simulation_batch_size : (i + 1) * simulation_batch_size]
+                    for i in range(num_batches)
+                ]
+            elif simulation_batch_size > 1:
+                warnings.warn(
+                    "Simulation batch size is greater than 1, but simulator_is_batched "
+                    "is False. Simulations will be run sequentially (batch size 1).",
+                    UserWarning,
+                    stacklevel=2,
+                )
+                batches = [theta for theta in thetas]
+            else:
+                batches = [theta for theta in thetas]
+
+            # Run in parallel
+            # Generate seeds
+            batch_seeds = np.random.randint(low=0, high=1_000_000, size=(len(batches),))
+
+            def run_simulation(batch, seed):
+                seed_all_backends(seed)
+                return simulator_func(batch)
+
+            # Execute in parallel with joblib
+            results = Parallel(return_as="generator")(
+                delayed(run_simulation)(batch, seed)
+                for batch, seed in zip(batches, batch_seeds, strict=False)
+            )
+
+            # Progress bar
+            simulation_outputs = []
+            if show_progress_bar:
+                pbar = tqdm(total=num_simulations)
+
+            for i, res in enumerate(results):
+                simulation_outputs.append(res)
+                if show_progress_bar:
+                    pbar.update(len(batches[i]))
+
+            if show_progress_bar:
+                pbar.close()
+
+            # Flatten results
+            output_data = []
+            if simulator_is_batched:
+                for batch_out in simulation_outputs:
+                    if isinstance(batch_out, (list, tuple)):
+                        output_data.extend(batch_out)
+                    elif isinstance(batch_out, (torch.Tensor, np.ndarray)):
+                        output_data.extend([x for x in batch_out])
+                    else:
+                        output_data.append(batch_out)
+            else:
+                output_data = simulation_outputs
+
+            if not output_data:
+                return torch.tensor([], dtype=float32)
+
+            # Handle file paths (strings)
+            if isinstance(output_data[0], (str, np.str_, np.bytes_)):
+                output_data = [str(f) for f in output_data]
+                return output_data
+
+            if isinstance(output_data[0], torch.Tensor):
+                return torch.stack(output_data)
+            elif isinstance(output_data[0], np.ndarray):
+                return np.stack(output_data)
+
+            return output_data
+
+        return parallel_simulator
+
+    if simulator is None:
+        return decorator
+
+    return decorator(simulator)
+
+
+def simulate_from_thetas(
+    simulator: Callable[[Theta], Data],
+    thetas: Theta,
+    simulator_is_batched: bool = False,
+    simulation_batch_size: int = 10,
+    show_progress_bar: bool = True,
+    seed: Optional[int] = None,
+) -> Tuple[Tensor | np.ndarray, Data]:
+    r"""
+    Execute simulations for a given set of parameters.
+
+    Args:
+        simulator: Function to run simulations.
+        thetas: Parameters to simulate (Tensor, Numpy array, or list).
+        simulator_is_batched: Whether the simulator can handle batches directly.
+        simulation_batch_size: Number of simulations to run in each batch.
+        show_progress_bar: Whether to show tqdm progress bar.
+        seed: Random seed.
+
+    Returns:
+        Tuple of (:math:`\theta`, simulation_outputs).
+    """
+    parallel_sim = parallelize_simulator(
+        simulator,
+        simulator_is_batched=simulator_is_batched,
+        simulation_batch_size=simulation_batch_size,
+        show_progress_bar=show_progress_bar,
+        seed=seed,
+    )
+
+    return thetas, parallel_sim(thetas)
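
A minimal usage sketch of the decorator form introduced above, assuming the WIP API lands as shown in this diff. Everything here is hypothetical: linear_gaussian is an illustrative simulator, n_jobs=4 is an arbitrary choice, and joblib >= 1.3 is assumed for parallel_config and return_as="generator".

# Hypothetical sketch against the WIP decorator API above; not part of sbi.
import numpy as np
from joblib import parallel_config

from sbi.utils.simulation_utils import parallelize_simulator


@parallelize_simulator(simulator_is_batched=True, simulation_batch_size=100, seed=42)
def linear_gaussian(theta: np.ndarray) -> np.ndarray:
    # Batched simulator: theta has shape (batch_size, dim).
    return theta + np.random.randn(*theta.shape)


# The wrapped simulator calls joblib.Parallel without an explicit n_jobs, so the
# worker count is inherited from an enclosing parallel_config context, which is
# the same mechanism simulate_for_sbi uses in the diff above.
with parallel_config(n_jobs=4):
    x = linear_gaussian(np.random.randn(1_000, 3))  # x has shape (1000, 3)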

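The function form goes through simulate_from_thetas. Again a sketch against the WIP API, with a hypothetical per-sample simulator; simulation_batch_size is pinned to 1 here because simulator_is_batched=False would otherwise trigger the batching warning.

# Hypothetical sketch of the function form; my_simulator is illustrative.
import numpy as np

from sbi.utils.simulation_utils import simulate_from_thetas


def my_simulator(theta: np.ndarray) -> np.ndarray:
    # Unbatched simulator: theta is a single parameter vector of shape (dim,).
    return theta**2 + np.random.randn(*theta.shape)


thetas = np.random.randn(256, 2)
# With simulator_is_batched=False, each row of thetas is simulated separately
# and the per-sample outputs are stacked back into an array of shape (256, 2).
thetas, x = simulate_from_thetas(
    my_simulator,
    thetas,
    simulator_is_batched=False,
    simulation_batch_size=1,
    show_progress_bar=False,
    seed=0,
)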