Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 122 additions & 4 deletions src/fast_array_utils/numba.py
Comment thread
JhonatanFelix marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@

from __future__ import annotations

import os
import platform
import subprocess
import sys
import warnings
from functools import cache, update_wrapper, wraps
Expand All @@ -29,6 +32,7 @@
"""Identifier for a threading layer category."""
type ThreadingLayer = Literal["tbb", "omp", "workqueue"]
"""Identifier for a concrete threading layer."""
# Cache key for the parallel-runtime probe. Fields, in order:
#   1. interpreter path (`sys.executable`) used to launch the probe,
#   2. the configured numba threading layer or category,
#   3. numba's configured threading-layer priority,
#   4. the whitelisted modules that are currently imported in this process.
type _ParallelRuntimeProbeKey = tuple[str, ThreadingLayer | TheadingCategory, tuple[ThreadingLayer, ...], tuple[str, ...]]


LAYERS: dict[TheadingCategory, set[ThreadingLayer]] = {
Expand All @@ -39,6 +43,11 @@
}


# Marker printed by the generated probe script as its last step; the parent
# process looks for it in the child's stdout to confirm the probe ran to the end.
_PARALLEL_RUNTIME_PROBE_SENTINEL = "FAST_ARRAY_UTILS_NUMBA_PROBE_OK"
# Passed as `timeout` to `subprocess.run` when launching the probe (seconds).
_PARALLEL_RUNTIME_PROBE_TIMEOUT = 20
# The only modules the probe script may import besides numba/numpy, and only
# when they are already present in `sys.modules` of the current process.
_PARALLEL_RUNTIME_PROBE_MODULE_WHITELIST = ("torch",)


def threading_layer(layer_or_category: ThreadingLayer | TheadingCategory | None = None, /, priority: Iterable[ThreadingLayer] | None = None) -> ThreadingLayer:
"""Get numba’s configured threading layer as specified in :ref:`numba-threading-layer`.

Expand Down Expand Up @@ -84,6 +93,108 @@ def _is_in_unsafe_thread_pool() -> bool:
return current_thread.name.startswith("ThreadPoolExecutor") and threading_layer() not in LAYERS["threadsafe"]


def _is_apple_silicon() -> bool:
    """Return whether this interpreter runs on macOS (``darwin``) on an ``arm64`` machine."""
    if sys.platform != "darwin":
        return False
    return platform.machine() == "arm64"


def _is_torch_loaded() -> bool:
    """Return whether ``torch`` has an entry in ``sys.modules`` (nothing is imported here)."""
    loaded_modules = sys.modules
    return "torch" in loaded_modules


def _configured_threading_layer_or_category_without_probing() -> ThreadingLayer | TheadingCategory:
    """Return the raw value of ``numba.config.THREADING_LAYER``.

    The value is returned as configured; it is not resolved to a concrete backend.
    """
    from numba import config as numba_config

    # Avoid `threading_layer()` here: resolving backends may import pool modules.
    return numba_config.THREADING_LAYER


def _configured_explicit_threading_layer_without_probing() -> ThreadingLayer | None:
    """Return the configured value if it is a member of ``LAYERS["default"]``, else ``None``."""
    configured = _configured_threading_layer_or_category_without_probing()
    if configured in LAYERS["default"]:
        return configured
    return None


def _is_explicit_safe_threading_layer() -> bool:
    """Return whether the configuration explicitly names ``workqueue`` or ``tbb``.

    ``_needs_parallel_runtime_probe`` skips the probe when this is true.
    """
    explicit_layer = _configured_explicit_threading_layer_without_probing()
    return explicit_layer in {"workqueue", "tbb"}


def _could_select_omp_from_threading_config_without_probing() -> bool:
    """Return whether the current configuration names, or could resolve to, ``omp``.

    With an explicit layer configured, the answer is simply whether that layer is ``omp``.
    Otherwise the configured value is looked up as a category in ``LAYERS`` and the answer
    is whether ``omp`` is one of that category's members.

    Raises ``KeyError`` if the configured value is neither an explicit layer nor a key of ``LAYERS``.
    """
    explicit_layer = _configured_explicit_threading_layer_without_probing()
    if explicit_layer is not None:
        return explicit_layer == "omp"
    category = cast("TheadingCategory", _configured_threading_layer_or_category_without_probing())
    return "omp" in LAYERS[category]


def _needs_parallel_runtime_probe() -> bool:
    """Decide whether the subprocess probe has to run before using the parallel code path.

    A probe is requested only when all of the following hold:
    the platform is Apple silicon, ``torch`` is already imported,
    the configured layer is not explicitly ``workqueue``/``tbb``,
    and the configuration could end up selecting ``omp``.
    """
    on_affected_setup = _is_apple_silicon() and _is_torch_loaded()
    if not on_affected_setup or _is_explicit_safe_threading_layer():
        return False
    return _could_select_omp_from_threading_config_without_probing()


def _loaded_relevant_parallel_runtime_probe_modules() -> tuple[str, ...]:
    """Return the whitelisted module names present in ``sys.modules``, in whitelist order."""
    return tuple(filter(sys.modules.__contains__, _PARALLEL_RUNTIME_PROBE_MODULE_WHITELIST))


def _parallel_runtime_probe_code(modules: tuple[str, ...]) -> str:
    """Generate the source of the script executed by the parallel-runtime probe.

    The script imports every name in ``modules`` (in the given order), then ``numba`` and
    ``numpy``, compiles a small ``parallel=True`` reduction over ``numba.prange``, asserts
    that its result equals ``np.sum`` of the same data, and finally prints
    ``_PARALLEL_RUNTIME_PROBE_SENTINEL``.

    ``modules`` holds the module names to import ahead of numba.
    The return value is the newline-joined script text, ending with a newline.
    """
    preamble = [f"import {name}" for name in modules]
    preamble += ["import numba", "import numpy as np", ""]
    # The nesting below must be real Python indentation: the function body is indented
    # one level and the `prange` loop body two levels, otherwise the child interpreter
    # rejects the script with an IndentationError and the probe can never succeed.
    kernel = [
        "@numba.njit(parallel=True, cache=False)",
        "def _probe(values):",
        "    total = 0.0",
        "    for i in numba.prange(values.shape[0]):",
        "        total += values[i]",
        "    return total",
        "",
        "values = np.arange(32, dtype=np.float64)",
        "assert _probe(values) == np.sum(values)",
        f"print({_PARALLEL_RUNTIME_PROBE_SENTINEL!r})",
        "",
    ]
    return "\n".join([*preamble, *kernel])


def _parallel_runtime_probe_key() -> _ParallelRuntimeProbeKey:
    """Assemble the key under which a probe result is cached.

    The key is built from the interpreter path, the configured threading layer or
    category, numba's configured layer priority (as a tuple) and the whitelisted
    modules that are currently loaded.
    """
    from numba import config as numba_config

    interpreter = sys.executable
    configured = _configured_threading_layer_or_category_without_probing()
    priority = tuple(cast("Iterable[ThreadingLayer]", numba_config.THREADING_LAYER_PRIORITY))
    probe_modules = _loaded_relevant_parallel_runtime_probe_modules()
    return interpreter, configured, priority, probe_modules


def _build_parallel_runtime_probe_env(key: _ParallelRuntimeProbeKey | None = None) -> dict[str, str]:
    """Build the environment mapping for the probe subprocess.

    The result is a copy of ``os.environ`` in which ``NUMBA_THREADING_LAYER`` is set to the
    key's layer/category and ``NUMBA_THREADING_LAYER_PRIORITY`` to the key's priority
    entries joined by single spaces. ``os.environ`` itself is not modified.

    If ``key`` is ``None``, the key is computed via ``_parallel_runtime_probe_key()``.
    """
    if key is None:
        key = _parallel_runtime_probe_key()
    _, layer_or_category, priority, _ = key
    return {
        **os.environ,
        "NUMBA_THREADING_LAYER": layer_or_category,
        "NUMBA_THREADING_LAYER_PRIORITY": " ".join(priority),
    }


@cache
def _parallel_numba_runtime_is_safe_cached(key: _ParallelRuntimeProbeKey) -> bool:
    """Run the probe script in a child interpreter and report whether it succeeded.

    The interpreter from ``key`` is started with ``-c`` and the generated probe script,
    using the environment from ``_build_parallel_runtime_probe_env(key)`` and a timeout of
    ``_PARALLEL_RUNTIME_PROBE_TIMEOUT``. The probe counts as successful only if the child
    exits with code 0 and its stdout contains ``_PARALLEL_RUNTIME_PROBE_SENTINEL``.

    Results are memoised per ``key`` by ``functools.cache``, so a given key is probed at
    most once per process, including when the probe failed.
    """
    try:
        interpreter, _, _, probe_modules = key
        # The probe command is built from `sys.executable` plus a generated script
        # that only imports modules from a fixed whitelist.
        command = [interpreter, "-c", _parallel_runtime_probe_code(probe_modules)]
        child_env = _build_parallel_runtime_probe_env(key)
        completed = subprocess.run(  # noqa: S603
            command,
            env=child_env,
            timeout=_PARALLEL_RUNTIME_PROBE_TIMEOUT,
            capture_output=True,
            text=True,
            check=False,
        )
    except Exception:  # noqa: BLE001
        # Any probe failure should conservatively disable the parallel fast-path.
        return False
    if completed.returncode != 0:
        return False
    return _PARALLEL_RUNTIME_PROBE_SENTINEL in completed.stdout


def _parallel_numba_runtime_is_safe() -> bool:
    """Return the (cached) probe verdict for the current interpreter and threading configuration."""
    probe_key = _parallel_runtime_probe_key()
    return _parallel_numba_runtime_is_safe_cached(probe_key)


@overload
def njit[**P, R](fn: Callable[P, R], /) -> Callable[P, R]: ...
@overload
Expand All @@ -92,7 +203,7 @@ def njit[**P, R](fn: Callable[P, R] | None = None, /) -> Callable[P, R] | Callab
"""Jit-compile a function using numba.

On call, this function dispatches to a parallel or serial numba function,
depending on if it has been called from a thread pool.
depending on the current threading environment.
"""
# See https://github.com/numbagg/numbagg/pull/201/files#r1409374809

Expand All @@ -109,11 +220,18 @@ def decorator(f: Callable[P, R], /) -> Callable[P, R]:

@wraps(f)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
parallel = not _is_in_unsafe_thread_pool()
if not parallel: # pragma: no cover
if _is_in_unsafe_thread_pool(): # pragma: no cover
msg = f"Detected unsupported threading environment. Trying to run {f.__name__} in serial mode. In case of problems, install `tbb`."
warnings.warn(msg, UserWarning, stacklevel=2)
return fns[parallel](*args, **kwargs)
return fns[False](*args, **kwargs)
if _needs_parallel_runtime_probe() and not _parallel_numba_runtime_is_safe():
msg = (
f"Detected an unsupported numba parallel runtime. Running {f.__name__} in serial mode as a workaround. "
"Set `NUMBA_THREADING_LAYER=workqueue` or install `tbb` to avoid this fallback."
)
warnings.warn(msg, UserWarning, stacklevel=2)
return fns[False](*args, **kwargs)
return fns[True](*args, **kwargs)

return wrapper

Expand Down
Loading
Loading