Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ classifiers = [
dynamic = [ "version" ]
dependencies = [
"anndata>=0.10.8",
"fast-array-utils[accel,sparse]>=1.2.1",
"fast-array-utils[accel,sparse]>=1.4",
"h5py>=3.11",
"joblib",
"matplotlib>=3.9",
Expand Down Expand Up @@ -214,8 +214,8 @@ lint.pylint.max-args = 10
lint.pylint.max-positional-args = 5

[tool.ruff.lint.flake8-tidy-imports.banned-api]
"numba.jit".msg = "Use `scanpy._compat.njit` instead"
"numba.njit".msg = "Use `scanpy._compat.njit` instead"
"numba.jit".msg = "Use `fast_array_utils.numba.njit` instead"
"numba.njit".msg = "Use `fast_array_utils.numba.njit` instead"
"numpy.bool_".msg = "Use `np.bool` instead"
"pandas.api.types.is_categorical_dtype".msg = "Use isinstance(s.dtype, CategoricalDtype) instead"
"pandas.value_counts".msg = "Use pd.Series(a).value_counts() instead"
Expand Down Expand Up @@ -292,7 +292,7 @@ run.source_pkgs = [ "scanpy" ]
paths.source = [ "src", "**/site-packages" ]
report.exclude_also = [
# https://github.com/numba/numba/issues/4268
"@(numba\\.|nb\\.)?njit.*",
"@([\\w.]+.)?njit.*",
"@deprecated.*",
"if __name__ == .__main__.:",
"if TYPE_CHECKING:",
Expand Down
103 changes: 2 additions & 101 deletions src/scanpy/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,15 @@

import sys
import warnings
from functools import cache, partial, wraps
from functools import cache, partial
from importlib.util import find_spec
from pathlib import Path
from typing import TYPE_CHECKING, Literal, cast, overload
from typing import TYPE_CHECKING

from packaging.version import Version
from scipy import sparse

if TYPE_CHECKING:
from collections.abc import Callable
from importlib.metadata import PackageMetadata


Expand All @@ -21,10 +20,8 @@
"CSRBase",
"DaskArray",
"SpBase",
"_numba_threading_layer",
"deprecated",
"fullname",
"njit",
"pkg_metadata",
"pkg_version",
"warn",
Expand Down Expand Up @@ -111,99 +108,3 @@ def warn(
warnings.warn( # noqa: TID251
message, category, source=source, skip_file_prefixes=skip_file_prefixes
)


@overload
def njit[**P, R](fn: Callable[P, R], /) -> Callable[P, R]: ...
@overload
def njit[**P, R]() -> Callable[[Callable[P, R]], Callable[P, R]]: ...
def njit[**P, R](
fn: Callable[P, R] | None = None, /
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
"""Jit-compile a function using numba.

On call, this function dispatches to a parallel or sequential numba function,
depending on if it has been called from a thread pool.

See <https://github.com/numbagg/numbagg/pull/201/files#r1409374809>
"""

def decorator(f: Callable[P, R], /) -> Callable[P, R]:
import numba

fns: dict[bool, Callable[P, R]] = {
parallel: numba.njit(f, cache=True, parallel=parallel) # noqa: TID251
for parallel in (True, False)
}

@wraps(f)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
parallel = not _is_in_unsafe_thread_pool()
if not parallel:
msg = (
"Detected unsupported threading environment. "
f"Trying to run {f.__name__} in serial mode. "
"In case of problems, install `tbb`."
)
warn(msg, UserWarning)
return fns[parallel](*args, **kwargs)

return wrapper

return decorator if fn is None else decorator(fn)


type LayerType = Literal["default", "safe", "threadsafe", "forksafe"]
type Layer = Literal["tbb", "omp", "workqueue"]


LAYERS: dict[LayerType, set[Layer]] = {
"default": {"tbb", "omp", "workqueue"},
"safe": {"tbb"},
"threadsafe": {"tbb", "omp"},
"forksafe": {"tbb", "workqueue", *(() if sys.platform == "linux" else {"omp"})},
}


def _is_in_unsafe_thread_pool() -> bool:
import threading

current_thread = threading.current_thread()
# ThreadPoolExecutor threads typically have names like 'ThreadPoolExecutor-0_1'
return (
current_thread.name.startswith("ThreadPoolExecutor")
and _numba_threading_layer() not in LAYERS["threadsafe"]
)


@cache
def _numba_threading_layer() -> Layer:
"""Get numba’s threading layer.

This function implements the algorithm as described in
<https://numba.readthedocs.io/en/stable/user/threading-layer.html>
"""
import importlib

import numba

if (available := LAYERS.get(numba.config.THREADING_LAYER)) is None:
# given by direct name
return numba.config.THREADING_LAYER

# given by layer type (safe, …)
for layer in cast("list[Layer]", numba.config.THREADING_LAYER_PRIORITY):
if layer not in available:
continue
if layer != "workqueue":
try: # `importlib.util.find_spec` doesn’t work here
importlib.import_module(f"numba.np.ufunc.{layer}pool")
except ImportError:
continue
# the layer has been found
return layer
msg = (
f"No loadable threading layer: {numba.config.THREADING_LAYER=} "
f" ({available=}, {numba.config.THREADING_LAYER_PRIORITY=})"
)
raise ValueError(msg)
3 changes: 2 additions & 1 deletion src/scanpy/experimental/pp/_highly_variable_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
import numpy as np
import pandas as pd
from anndata import AnnData
from fast_array_utils.numba import njit
from fast_array_utils.stats import mean_var

from ... import logging as logg
from ..._compat import CSBase, njit, warn
from ..._compat import CSBase, warn
from ..._settings import Verbosity, settings
from ..._utils import _doc_params, check_nonnegative_integers, view_to_actual
from ...experimental._docs import (
Expand Down
3 changes: 2 additions & 1 deletion src/scanpy/metrics/_gearys_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@

import numba
import numpy as np
from fast_array_utils.numba import njit

from .._compat import CSRBase, njit
from .._compat import CSRBase
from .._utils import _doc_params
from ..get import _get_obs_rep
from ..neighbors._doc import doc_neighbors_key
Expand Down
3 changes: 2 additions & 1 deletion src/scanpy/metrics/_morans_i.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@

import numba
import numpy as np
from fast_array_utils.numba import njit

from .._compat import CSRBase, njit
from .._compat import CSRBase
from .._utils import _doc_params
from ..get import _get_obs_rep
from ..neighbors._doc import doc_neighbors_key
Expand Down
3 changes: 2 additions & 1 deletion src/scanpy/preprocessing/_qc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
import numpy as np
import pandas as pd
from fast_array_utils import stats
from fast_array_utils.numba import njit
from scipy import sparse

from scanpy.get import _get_obs_rep
from scanpy.preprocessing._distributed import materialize_as_ndarray

from .._compat import CSBase, CSRBase, DaskArray, njit, warn
from .._compat import CSBase, CSRBase, DaskArray, warn
from .._utils import _doc_params, axis_nnz
from ._docs import (
doc_adata_basic,
Expand Down
3 changes: 2 additions & 1 deletion src/scanpy/preprocessing/_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
import numba
import numpy as np
from anndata import AnnData
from fast_array_utils.numba import njit
from fast_array_utils.stats import mean_var

from .. import logging as logg
from .._compat import CSBase, CSCBase, CSRBase, DaskArray, njit, warn
from .._compat import CSBase, CSCBase, CSRBase, DaskArray, warn
from .._settings import Default, settings
from .._utils import (
axis_mul_or_truediv,
Expand Down
3 changes: 2 additions & 1 deletion src/scanpy/preprocessing/_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@
from anndata import AnnData
from fast_array_utils import stats
from fast_array_utils.conv import to_dense
from fast_array_utils.numba import njit
from numpy._typing._array_like import NDArray
from pandas.api.types import CategoricalDtype
from sklearn.utils import check_array

from .. import logging as logg
from .._compat import CSBase, CSRBase, DaskArray, njit
from .._compat import CSBase, CSRBase, DaskArray
from .._docs import doc_rng
from .._settings import settings
from .._utils import (
Expand Down
3 changes: 2 additions & 1 deletion src/scanpy/tools/_rank_genes_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
import numba
import numpy as np
import pandas as pd
from fast_array_utils.numba import njit
from fast_array_utils.stats import mean_var
from scipy import sparse

from .. import _utils
from .. import logging as logg
from .._compat import CSBase, njit
from .._compat import CSBase
from .._settings import Default
from .._settings.presets import DETest
from .._utils import (
Expand Down
Loading