Skip to content
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,18 @@
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"scanpydoc.elegant_typehints",
"sphinx_autofixture",
]

# API documentation when building
nitpicky = True
autosummary_generate = True
autodoc_member_order = "bysource"
autodoc_default_options = {
"special-members": True,
# everything except __call__ really, to avoid having to write autosummary templates
"exclude-members": "__setattr__,__delattr__,__repr__,__eq__,__hash__,__weakref__,__init__",
}
napoleon_google_docstring = False
napoleon_numpy_docstring = True
todo_include_todos = False
Expand All @@ -55,9 +61,11 @@
"np.dtype": "numpy.dtype",
"np.number": "numpy.number",
"np.integer": "numpy.integer",
"np.random.Generator": "numpy.random.Generator",
"ArrayLike": "numpy.typing.ArrayLike",
"DTypeLike": "numpy.typing.DTypeLike",
"NDArray": "numpy.typing.NDArray",
"_pytest.fixtures.FixtureRequest": "pytest.FixtureRequest",
**{
k: v
for k_plain, v in {
Expand All @@ -74,10 +82,17 @@
# If that doesn’t work, ignore them
nitpick_ignore = {
("py:class", "fast_array_utils.types.T_co"),
("py:class", "Arr"),
("py:class", "testing.fast_array_utils._array_type.Arr"),
("py:class", "testing.fast_array_utils._array_type.Inner"),
("py:class", "_DTypeLikeFloat32"),
("py:class", "_DTypeLikeFloat64"),
# sphinx bugs, should be covered by `autodoc_type_aliases` above
("py:class", "Array"),
("py:class", "ArrayLike"),
("py:class", "DTypeLike"),
("py:class", "NDArray"),
("py:class", "_pytest.fixtures.FixtureRequest"),
}

# Options for HTML output
Expand Down
7 changes: 6 additions & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
``fast_array_utils``
====================

.. toctree::
:hidden:

fast-array-utils <self>
testing

.. automodule:: fast_array_utils
:members:


``fast_array_utils.conv``
-------------------------

Expand Down
11 changes: 11 additions & 0 deletions docs/testing.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
``testing.fast_array_utils``
============================

.. automodule:: testing.fast_array_utils
:members:

``testing.fast_array_utils.pytest``
-----------------------------------

.. automodule:: testing.fast_array_utils.pytest
:members:
19 changes: 17 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,33 @@ classifiers = [
]
dynamic = [ "description", "version" ]
dependencies = [ "numba", "numpy" ]
optional-dependencies.doc = [ "furo", "scanpydoc>=0.15.2", "sphinx>=8", "sphinx-autodoc-typehints" ]
optional-dependencies.doc = [
"furo",
"pytest",
"scanpydoc>=0.15.2",
"sphinx>=8",
"sphinx-autodoc-typehints",
"sphinx-autofixture",
]
optional-dependencies.full = [ "dask", "fast-array-utils[sparse]", "h5py", "zarr" ]
optional-dependencies.sparse = [ "scipy>=1.8" ]
optional-dependencies.test = [ "coverage[toml]", "pytest", "pytest-codspeed" ]
urls.'Documentation' = "https://icb-fast-array-utils.readthedocs-hosted.com/"
urls.'Issue Tracker' = "https://github.com/scverse/fast-array-utils/issues"
urls.'Source Code' = "https://github.com/scverse/fast-array-utils"

[tool.hatch.metadata.hooks.docstring-description]
entry_points.pytest11.fast_array_utils = "testing.fast_array_utils.pytest"

[tool.hatch.version]
source = "vcs"
raw-options = { local_scheme = "no-local-version" } # be able to publish dev version

# TODO: support setting main package in the plugin
# [tool.hatch.metadata.hooks.docstring-description]

[tool.hatch.build.targets.wheel]
packages = [ "src/testing", "src/fast_array_utils" ]

[tool.hatch.envs.default]
installer = "uv"

Expand Down Expand Up @@ -85,6 +98,8 @@ lint.per-file-ignores."tests/**/test_*.py" = [
"S101", # tests use `assert`
]
lint.allowed-confusables = [ "×", "’" ]
lint.flake8-bugbear.extend-immutable-calls = [ "testing.fast_array_utils.Flags" ]

lint.flake8-copyright.notice-rgx = "SPDX-License-Identifier: MPL-2\\.0"
lint.flake8-type-checking.exempt-modules = [ ]
lint.flake8-type-checking.strict = true
Expand Down
161 changes: 34 additions & 127 deletions src/testing/fast_array_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,135 +3,42 @@

from __future__ import annotations

import re
from typing import TYPE_CHECKING

import numpy as np
from ._array_type import ArrayType, ConversionContext, Flags, random_mat


if TYPE_CHECKING:
from typing import Any, Literal, Protocol, SupportsFloat, TypeAlias

from numpy.typing import ArrayLike, DTypeLike, NDArray

from fast_array_utils import types
from fast_array_utils.types import CSBase

Array: TypeAlias = (
NDArray[Any]
| types.CSBase
| types.CupyArray
| types.DaskArray
| types.H5Dataset
| types.ZarrArray
)

class ToArray(Protocol):
"""Convert to a supported array."""

def __call__( # noqa: D102
self, data: ArrayLike, /, *, dtype: DTypeLike | None = None
) -> Array: ...

_DTypeLikeFloat32 = np.dtype[np.float32] | type[np.float32]
_DTypeLikeFloat64 = np.dtype[np.float64] | type[np.float64]


RE_ARRAY_QUAL = re.compile(r"(?P<mod>(?:\w+\.)*\w+)\.(?P<name>[^\[]+)(?:\[(?P<inner>[\w.]+)\])?")


def get_array_cls(qualname: str) -> type[Array]:  # noqa: PLR0911
    """Resolve a qualified array type name to the class it refers to.

    Parameters
    ----------
    qualname
        Dotted name such as ``"numpy.ndarray"`` or ``"scipy.sparse.csr_array"``,
        optionally followed by a bracketed inner type (``"mod.Name[inner.Type]"``),
        as accepted by ``RE_ARRAY_QUAL``.

    Returns
    -------
    The array class. Optional dependencies (scipy, cupy, dask, h5py, zarr)
    are imported only when the corresponding class is requested.

    Raises
    ------
    AssertionError
        If ``qualname`` does not match ``RE_ARRAY_QUAL`` at all.
    ValueError
        If ``qualname`` is well-formed but names no supported array class.
    """
    parsed = RE_ARRAY_QUAL.fullmatch(qualname)
    assert parsed
    module, name, inner = parsed["mod"], parsed["name"], parsed["inner"]
    target = (module, name)

    # These classes are only recognised without a bracketed inner type.
    if inner is None:
        if target == ("numpy", "ndarray"):
            return np.ndarray
        if module == "scipy.sparse" and name in (
            "csr_array",
            "csc_array",
            "csr_matrix",
            "csc_matrix",
        ):
            import scipy.sparse

            return getattr(scipy.sparse, name)  # type: ignore[no-any-return]
        if target == ("cupy", "ndarray"):
            import cupy as cp

            return cp.ndarray  # type: ignore[no-any-return]
        if module == "cupyx.scipy.sparse" and name in ("csr_matrix", "csc_matrix"):
            import cupyx.scipy.sparse as cu_sparse

            return getattr(cu_sparse, name)  # type: ignore[no-any-return]

    # The remaining classes are accepted with or without an inner type.
    if module == "dask.array":
        # Any class name inside ``dask.array`` resolves to the dask array class.
        if TYPE_CHECKING:
            from dask.array.core import Array as DaskArray
        else:
            from dask.array import Array as DaskArray

        return DaskArray
    if target == ("h5py", "Dataset"):
        import h5py

        return h5py.Dataset  # type: ignore[no-any-return]
    if target == ("zarr", "Array"):
        import zarr

        return zarr.Array

    msg = f"Unknown array class: {qualname}"
    raise ValueError(msg)


def random_mat(
shape: tuple[int, int],
*,
density: SupportsFloat = 0.01,
format: Literal["csr", "csc"] = "csr", # noqa: A002
dtype: DTypeLike | None = None,
container: Literal["array", "matrix"] = "array",
gen: np.random.Generator | None = None,
) -> CSBase:
"""Create a random matrix."""
from scipy.sparse import random as random_spmat
from scipy.sparse import random_array as random_sparr

m, n = shape
return (
random_spmat(m, n, density=density, format=format, dtype=dtype, random_state=gen)
if container == "matrix"
else random_sparr(shape, density=density, format=format, dtype=dtype, random_state=gen)
)


def random_array(
qualname: str,
shape: tuple[int, int],
*,
dtype: _DTypeLikeFloat32 | _DTypeLikeFloat64 | None,
gen: np.random.Generator | None = None,
) -> Array:
"""Create a random array."""
gen = np.random.default_rng(gen)

m = RE_ARRAY_QUAL.fullmatch(qualname)
assert m
match m["mod"], m["name"], m["inner"]:
case "numpy", "ndarray", None:
return gen.random(shape, dtype=dtype or np.float64)
case "scipy.sparse", (
"csr_array" | "csc_array" | "csr_matrix" | "csc_matrix"
) as cls_name, None:
fmt, container = cls_name.split("_")
return random_mat(shape, format=fmt, container=container, dtype=dtype) # type: ignore[arg-type]
case "cupy", "ndarray", None:
raise NotImplementedError
case "cupyx.scipy.sparse", ("csr_matrix" | "csc_matrix") as cls_name, None:
raise NotImplementedError
case "dask.array", cls_name, _:
raise NotImplementedError
case "h5py", "Dataset", _:
raise NotImplementedError
case "zarr", "Array", _:
raise NotImplementedError
case _:
msg = f"Unknown array class: {qualname}"
raise ValueError(msg)
from ._array_type import Array, ToArray # noqa: TC004


__all__ = [
"SUPPORTED_TYPES",
"Array",
"ArrayType",
"ConversionContext",
"Flags",
"ToArray",
"random_mat",
]


# Array types from numpy, cupy, scipy.sparse and cupyx.scipy.sparse.
_TP_MEM = (
    ArrayType("numpy", "ndarray", Flags.Any),
    ArrayType("cupy", "ndarray", Flags.Any | Flags.Gpu),
    ArrayType("scipy.sparse", "csr_array", Flags.Any | Flags.Sparse),
    ArrayType("scipy.sparse", "csc_array", Flags.Any | Flags.Sparse),
    ArrayType("scipy.sparse", "csr_matrix", Flags.Any | Flags.Sparse),
    ArrayType("scipy.sparse", "csc_matrix", Flags.Any | Flags.Sparse),
    ArrayType("cupyx.scipy.sparse", "csr_matrix", Flags.Any | Flags.Gpu | Flags.Sparse),
    ArrayType("cupyx.scipy.sparse", "csc_matrix", Flags.Any | Flags.Gpu | Flags.Sparse),
)
# One dask array type per entry of ``_TP_MEM``: the entry becomes the ``inner``
# type and its flags are combined with ``Flags.Dask``.
_TP_DASK = tuple(
    ArrayType("dask.array", "Array", Flags.Dask | mem_type.flags, inner=mem_type)
    for mem_type in _TP_MEM
)
# h5py datasets and zarr arrays, flagged with ``Flags.Disk``.
_TP_DISK = (
    ArrayType("h5py", "Dataset", Flags.Any | Flags.Disk),
    ArrayType("zarr", "Array", Flags.Any | Flags.Disk),
)

SUPPORTED_TYPES: tuple[ArrayType, ...] = _TP_MEM + _TP_DASK + _TP_DISK
"""All supported array types."""
Loading