|
2 | 2 |
|
3 | 3 | import math |
4 | 4 | import warnings |
5 | | -from typing import Literal |
| 5 | +from typing import List, Literal, Tuple |
6 | 6 |
|
7 | 7 | import numpy as np |
8 | 8 | from numba import njit |
@@ -281,3 +281,49 @@ def chunk_and_fortranize(X: np.ndarray, chunk_lb: int, chunk_ub: int, indices: n |
281 | 281 | for j in range(0, chunk_ub - chunk_lb): |
282 | 282 | chunk[i, j] = X[i, chunk_lb + j] |
283 | 283 | return chunk |
| 284 | + |
| 285 | + |
| 286 | +def compute_batch_bounds(n_genes: int, batch_size: Literal["auto"] | int, n_threads: int) -> List[Tuple[int, int]]: |
| 287 | + """Computes ideal batch bounds for processing genes in batches. |
| 288 | + This function ensures no worker is starving. This could happen if we have 8 workers but 9 batches to allocate. |
| 289 | + In this case, because each batch takes the same time to be processed, all but one workers will be idle waiting for one worker to process the last batch. |
| 290 | +
|
| 291 | + Args: |
| 292 | + n_genes (int): Total number of genes |
| 293 | + batch_size (Literal["auto"] | int): Batch size, or "auto" to compute ideal batch size. |
| 294 | + n_threads (int): Number of threads to use. |
| 295 | + Returns: |
| 296 | + List[Tuple[int, int]]: List of (lower_bound, upper_bound) for each batch. Upper bound is excluding, following slicing conventions. |
| 297 | + """ |
| 298 | + # No batching nor multithreading for small inputs |
| 299 | + if n_genes < n_threads or n_genes < 256: |
| 300 | + batch_size = n_genes |
| 301 | + # n_threads = 1 |
| 302 | + batch_size = n_genes |
| 303 | + bounds_iterator = [[0, n_genes]] |
| 304 | + elif isinstance(batch_size, int): |
| 305 | + # batch_size = min(batch_size, math.ceil(n_genes / n_threads)) |
| 306 | + bounds = list(range(0, n_genes + 1, batch_size)) |
| 307 | + if bounds[-1] != n_genes: |
| 308 | + bounds.append(n_genes) |
| 309 | + bounds_iterator = list(zip(bounds[:-1], bounds[1:])) |
| 310 | + elif batch_size == "auto": |
| 311 | + target_batch_size = 256 |
| 312 | + min_batches = (n_genes + target_batch_size - 1) // target_batch_size |
| 313 | + num_batches = ((min_batches + n_threads - 1) // n_threads) * n_threads |
| 314 | + base_size = n_genes // num_batches |
| 315 | + remainder = n_genes % num_batches |
| 316 | + bounds_iterator = [] |
| 317 | + start = 0 |
| 318 | + for i in range(num_batches): |
| 319 | + end = start + base_size + (1 if i < remainder else 0) |
| 320 | + bounds_iterator.append((start, end)) |
| 321 | + start = end |
| 322 | + # Append the last gene as the right bound is excluding |
| 323 | + if bounds_iterator[-1][1] != n_genes: |
| 324 | + bounds_iterator[-1][1] = n_genes |
| 325 | + batch_size = base_size |
| 326 | + else: |
| 327 | + raise ValueError(f"Invalid batch_size value: {batch_size}. Must be 'auto' or an integer.") |
| 328 | + |
| 329 | + return bounds_iterator, batch_size |
0 commit comments