|
1 | 1 | # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed |
2 | 2 | # under the Apache License Version 2.0, see <https://www.apache.org/licenses/> |
3 | 3 |
|
4 | | -from typing import Any, Callable, Optional, Tuple, Union |
| 4 | +import warnings |
| 5 | +from typing import Any, Callable, List, Optional, Tuple, Union |
5 | 6 |
|
6 | 7 | import numpy as np |
7 | 8 | import torch |
8 | | -from joblib import Parallel, delayed |
9 | | -from numpy import ndarray |
| 9 | +from joblib import Parallel, delayed, parallel_config |
10 | 10 | from torch import Tensor, float32 |
11 | 11 | from tqdm.auto import tqdm |
12 | 12 |
|
13 | 13 | from sbi.utils.sbiutils import seed_all_backends |
14 | 14 |
|
| 15 | +Data = Tensor | np.ndarray | List[str] |
| 16 | +Theta = Tensor | np.ndarray | List[Any] |
| 17 | + |
15 | 18 |
|
16 | 19 | # Refactoring following #1175. tl;dr: letting joblib iterate over numpy arrays
17 | 20 | # allows for a roughly 10x performance gain. The resulting casting necessity |
@@ -60,61 +63,206 @@ def simulate_for_sbi( |
60 | 63 | """ |
61 | 64 |
|
62 | 65 | if num_simulations == 0: |
63 | | - theta = torch.tensor([], dtype=float32) |
64 | | - x = torch.tensor([], dtype=float32) |
| 66 | + return torch.tensor([], dtype=float32), torch.tensor([], dtype=float32) |
| 67 | + |
| 68 | + seed_all_backends(seed) |
| 69 | + theta = proposal.sample((num_simulations,)) |
| 70 | +    # Cast to numpy for better joblib performance (see #1175).
| 71 | + theta_numpy = theta.cpu().numpy() |
65 | 72 |
|
| 73 | + if simulation_batch_size is None: |
| 74 | + simulation_batch_size = num_simulations |
66 | 75 | else: |
67 | | - # Cast theta to numpy for better joblib performance (seee #1175) |
68 | | - seed_all_backends(seed) |
69 | | - theta = proposal.sample((num_simulations,)) |
70 | | - |
71 | | - # Parse the simulation_batch_size logic |
72 | | - if simulation_batch_size is None: |
73 | | - simulation_batch_size = num_simulations |
74 | | - else: |
75 | | - simulation_batch_size = min(simulation_batch_size, num_simulations) |
76 | | - |
77 | | - if num_workers != 1: |
78 | | - # For multiprocessing, we want to switch to numpy arrays. |
79 | | - # The batch size will be an approximation, since np.array_split does |
80 | | - # not take as argument the size of the batch but their total. |
81 | | - num_batches = num_simulations // simulation_batch_size |
82 | | - batches = np.array_split(theta.cpu().numpy(), num_batches, axis=0) |
83 | | - batch_seeds = np.random.randint(low=0, high=1_000_000, size=(len(batches),)) |
| 76 | + simulation_batch_size = min(simulation_batch_size, num_simulations) |
84 | 77 |
|
85 | | - # define seeded simulator. |
86 | | - def simulator_seeded(theta: ndarray, seed: int) -> Tensor: |
87 | | - seed_all_backends(seed) |
88 | | - return simulator(theta) |
89 | | - |
90 | | - try: # catch TypeError to give more informative error message |
91 | | - simulation_outputs: list[Tensor] = [ # pyright: ignore |
92 | | - xx |
93 | | - for xx in tqdm( |
94 | | - Parallel(return_as="generator", n_jobs=num_workers)( |
95 | | - delayed(simulator_seeded)(batch, seed) |
96 | | - for batch, seed in zip(batches, batch_seeds, strict=False) |
97 | | - ), |
98 | | - total=num_simulations, |
99 | | - disable=not show_progress_bar, |
100 | | - ) |
101 | | - ] |
102 | | - except TypeError as err: |
| 78 | +    # joblib's `parallel_config` (>= 1.3) sets the default `n_jobs` for
| | +    # all `Parallel` calls made inside this context.
| 79 | + context = parallel_config(n_jobs=num_workers) |
| 80 | + |
| 81 | + with context: |
| 82 | +        # We enforce simulator_is_batched=True because the semantics of
| 83 | +        # `simulate_for_sbi` imply that the simulator receives batches (even if of size 1).
| 84 | + try: |
| 85 | + theta, x = simulate_from_thetas( |
| 86 | + simulator, |
| 87 | + theta_numpy, |
| 88 | + simulation_batch_size=simulation_batch_size, |
| 89 | + simulator_is_batched=True, |
| 90 | + show_progress_bar=show_progress_bar, |
| 91 | + seed=seed, |
| 92 | + ) |
| 93 | + except TypeError as err: |
| 94 | + if num_workers > 1: |
103 | 95 | raise TypeError( |
104 | 96 |                 "There is a TypeError in your simulator function. Note: For"
105 | 97 | " multiprocessing, we switch to numpy arrays. Besides confirming" |
106 | 98 | " your simulator works correctly, make sure to preprocess your" |
107 | 99 | " simulator with `process_simulator` to handle numpy arrays." |
108 | 100 | ) from err |
| 101 | + else: |
| 102 | +                raise
109 | 103 |
|
110 | | - else: |
111 | | - simulation_outputs: list[Tensor] = [] |
112 | | - batches = torch.split(theta, simulation_batch_size) |
113 | | - for batch in tqdm(batches, disable=not show_progress_bar): |
114 | | - simulation_outputs.append(simulator(batch)) |
| 104 | + # Correctly format the output to Tensor |
| 105 | + theta = torch.as_tensor(theta, dtype=float32) |
115 | 106 |
|
116 | | - # Correctly format the output |
117 | | - x = torch.cat(simulation_outputs, dim=0) |
118 | | - theta = torch.as_tensor(theta, dtype=float32) |
| 107 | + if isinstance(x, np.ndarray): |
| 108 | + x = torch.from_numpy(x) |
119 | 109 |
|
120 | 110 | return theta, x |
| 111 | + |
| 112 | + |
| 113 | +def parallelize_simulator( |
| 114 | + simulator: Optional[Callable[[Theta], Data]] = None, |
| 115 | + simulator_is_batched: bool = False, |
| 116 | + simulation_batch_size: int = 10, |
| 117 | + show_progress_bar: bool = True, |
| 118 | + seed: Optional[int] = None, |
| 119 | +) -> Union[ |
| 120 | + Callable[[Theta], Data], |
| 121 | + Callable[[Callable[[Theta], Data]], Callable[[Theta], Data]], |
| 122 | +]: |
| 123 | + r""" |
| 124 | + Returns a function that executes simulations in parallel for a given set of |
| 125 | + parameters. Can be used as a function or a decorator. |
| 126 | +
|
| 127 | + Args: |
| 128 | + simulator: Function to run simulations. |
| 129 | + simulator_is_batched: Whether the simulator can handle batches directly. |
| 130 | + simulation_batch_size: Number of simulations to run in each batch. |
| 131 | + show_progress_bar: Whether to show tqdm progress bar. |
| 132 | + seed: Random seed. |
| 133 | +
|
| 134 | + Returns: |
| 135 | + Callable that takes a set of :math:`\theta` and returns simulation outputs. |
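| | +
| | +    Example:
| | +        A minimal sketch; ``toy_simulator`` is a hypothetical stand-in for
| | +        any user-provided simulator::
| | +
| | +            def toy_simulator(theta: np.ndarray) -> np.ndarray:
| | +                return theta + np.random.randn(*theta.shape)
| | +
| | +            # Used as a plain function:
| | +            sim = parallelize_simulator(toy_simulator, simulator_is_batched=True)
| | +            x = sim(np.zeros((100, 2)))
| | +
| | +            # Used as a decorator:
| | +            @parallelize_simulator(simulator_is_batched=True)
| | +            def toy_simulator(theta: np.ndarray) -> np.ndarray:
| | +                return theta + np.random.randn(*theta.shape)
| | +
| | +            # The degree of parallelism is set by an enclosing joblib context:
| | +            with parallel_config(n_jobs=4):
| | +                x = sim(np.zeros((100, 2)))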
| 136 | + """ |
| 137 | + |
| 138 | + def decorator(simulator_func: Callable[[Theta], Data]) -> Callable[[Theta], Data]: |
| 139 | + warnings.warn( |
| 140 | + "Joblib is used for parallelization. It is recommended to use numpy arrays " |
| 141 | + "for the simulator input and output to avoid serialization overhead with " |
| 142 | + "torch tensors.", |
| 143 | + UserWarning, |
| 144 | + stacklevel=2, |
| 145 | + ) |
| 146 | + |
| 147 | + def parallel_simulator(thetas: Theta) -> Data: |
| 148 | + seed_all_backends(seed) |
| 149 | + |
| 150 | + num_simulations = len(thetas) |
| 151 | + |
| 152 | + if num_simulations == 0: |
| 153 | + return torch.tensor([], dtype=float32) |
| 154 | + |
| 155 | + # Create batches |
| 156 | + if simulator_is_batched: |
| 157 | + num_batches = ( |
| 158 | + num_simulations + simulation_batch_size - 1 |
| 159 | + ) // simulation_batch_size |
| 160 | + batches = [ |
| 161 | + thetas[i * simulation_batch_size : (i + 1) * simulation_batch_size] |
| 162 | + for i in range(num_batches) |
| 163 | + ] |
| 164 | +            else:
| 165 | +                if simulation_batch_size > 1:
| 166 | +                    warnings.warn(
| 167 | +                        "Simulation batch size is greater than 1, but "
| 168 | +                        "simulator_is_batched is False. Simulations will be "
| 169 | +                        "run sequentially (batch size 1).",
| 170 | +                        UserWarning,
| 171 | +                        stacklevel=2,
| 172 | +                    )
| 173 | +                batches = list(thetas)
| 174 | + |
| 175 | +            # Draw an independent seed for each batch so results are
| 176 | +            # reproducible regardless of how work is split across workers.
| 177 | + batch_seeds = np.random.randint(low=0, high=1_000_000, size=(len(batches),)) |
| 178 | + |
| 179 | + def run_simulation(batch, seed): |
| 180 | + seed_all_backends(seed) |
| 181 | + return simulator_func(batch) |
| 182 | + |
| 183 | +            # Execute with joblib; `n_jobs` is taken from an enclosing
| | +            # `parallel_config` context, if any (default: sequential).
| 184 | + results = Parallel(return_as="generator")( |
| 185 | + delayed(run_simulation)(batch, seed) |
| 186 | + for batch, seed in zip(batches, batch_seeds, strict=False) |
| 187 | + ) |
| 188 | + |
| 189 | +            # Progress bar: advance by the number of simulations per batch
| 190 | +            # (one per batch when the simulator is not batched).
| 191 | +            simulation_outputs = []
| 192 | +            with tqdm(total=num_simulations, disable=not show_progress_bar) as pbar:
| 193 | +                for i, res in enumerate(results):
| 194 | +                    simulation_outputs.append(res)
| 195 | +                    pbar.update(len(batches[i]) if simulator_is_batched else 1)
| 201 | + |
| 202 | + # Flatten results |
| 203 | + output_data = [] |
| 204 | + if simulator_is_batched: |
| 205 | + for batch_out in simulation_outputs: |
| 206 | + if isinstance(batch_out, (list, tuple)): |
| 207 | + output_data.extend(batch_out) |
| 208 | + elif isinstance(batch_out, (torch.Tensor, np.ndarray)): |
| 209 | +                    output_data.extend(list(batch_out))
| 210 | + else: |
| 211 | + output_data.append(batch_out) |
| 212 | + else: |
| 213 | + output_data = simulation_outputs |
| 214 | + |
| 215 | + if not output_data: |
| 216 | + return torch.tensor([], dtype=float32) |
| 217 | + |
| 218 | +            # Handle file paths: decode bytes explicitly, since `str()` on
| 219 | +            # bytes would produce "b'...'" rather than the path itself.
| 220 | +            if isinstance(output_data[0], (str, np.str_, np.bytes_)):
| 221 | +                return [
| | +                    f.decode() if isinstance(f, (bytes, np.bytes_)) else str(f)
| | +                    for f in output_data
| | +                ]
| 222 | + |
| 223 | + if isinstance(output_data[0], torch.Tensor): |
| 224 | + return torch.stack(output_data) |
| 225 | + elif isinstance(output_data[0], np.ndarray): |
| 226 | + return np.stack(output_data) |
| 227 | + |
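| | +            # Heterogeneous or unrecognized output types fall through and are
| | +            # returned as a plain list.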
| 228 | + return output_data |
| 229 | + |
| 230 | + return parallel_simulator |
| 231 | + |
| 232 | + if simulator is None: |
| 233 | + return decorator |
| 234 | + |
| 235 | + return decorator(simulator) |
| 236 | + |
| 237 | + |
| 238 | +def simulate_from_thetas( |
| 239 | + simulator: Callable[[Theta], Data], |
| 240 | + thetas: Theta, |
| 241 | + simulator_is_batched: bool = False, |
| 242 | + simulation_batch_size: int = 10, |
| 243 | + show_progress_bar: bool = True, |
| 244 | + seed: Optional[int] = None, |
| 245 | +) -> Tuple[Theta, Data]:
| 246 | + r""" |
| 247 | + Execute simulations for a given set of parameters. |
| 248 | +
|
| 249 | + Args: |
| 250 | + simulator: Function to run simulations. |
| 251 | +        thetas: Parameters to simulate (Tensor, NumPy array, or list).
| 252 | + simulator_is_batched: Whether the simulator can handle batches directly. |
| 253 | + simulation_batch_size: Number of simulations to run in each batch. |
| 254 | + show_progress_bar: Whether to show tqdm progress bar. |
| 255 | + seed: Random seed. |
| 256 | +
|
| 257 | + Returns: |
| 258 | + Tuple of (:math:`\theta`, simulation_outputs). |
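| | +
| | +    Example:
| | +        A minimal sketch, reusing the hypothetical ``toy_simulator`` from
| | +        the ``parallelize_simulator`` example::
| | +
| | +            thetas = np.random.randn(1_000, 2)
| | +            thetas, x = simulate_from_thetas(
| | +                toy_simulator, thetas, simulator_is_batched=True
| | +            )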
| 259 | + """ |
| 260 | + parallel_sim = parallelize_simulator( |
| 261 | + simulator, |
| 262 | + simulator_is_batched=simulator_is_batched, |
| 263 | + simulation_batch_size=simulation_batch_size, |
| 264 | + show_progress_bar=show_progress_bar, |
| 265 | + seed=seed, |
| 266 | + ) |
| 267 | + |
| 268 | + return thetas, parallel_sim(thetas) |