|
| 1 | +"""Add `array-support` directive.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +from itertools import groupby |
| 6 | +from typing import TYPE_CHECKING |
| 7 | + |
| 8 | +from docutils import nodes |
| 9 | +from sphinx.util.docutils import SphinxDirective |
| 10 | + |
| 11 | +from scanpy._utils import _docs |
| 12 | + |
| 13 | +if TYPE_CHECKING: |
| 14 | + from collections.abc import Collection, Generator, Iterable, Sequence |
| 15 | + from typing import ClassVar |
| 16 | + |
| 17 | + from sphinx.application import Sphinx |
| 18 | + |
| 19 | + |
| 20 | +ALL_INNER = list(_docs.parse(["np", "sp"], inner=True)) |
| 21 | + |
| 22 | + |
| 23 | +class ArraySupport(SphinxDirective): |
| 24 | + """Document array support.""" |
| 25 | + |
| 26 | + required_arguments: ClassVar = 1 |
| 27 | + |
| 28 | + @property |
| 29 | + def _array_support(self) -> dict[str, tuple[list[str], list[str]]]: |
| 30 | + return self.config.array_support |
| 31 | + |
| 32 | + def run(self) -> list[nodes.Node]: # noqa: D102 |
| 33 | + if self.arguments[0] == "all": |
| 34 | + return self._render_overview() |
| 35 | + |
| 36 | + if not self.arguments[0] not in self._array_support: |
| 37 | + self.error( |
| 38 | + f"API not in `array_support`, add it in `docs/conf.py`: {self.arguments[0]}" |
| 39 | + ) |
| 40 | + array_types = list(_docs.parse(*self._array_support[self.arguments[0]])) |
| 41 | + headers = ( |
| 42 | + "Array type", |
| 43 | + "supported", |
| 44 | + "… experimentally in dask :class:`~dask.array.Array`", |
| 45 | + ) |
| 46 | + data: list[tuple[_docs.Inner, bool, bool]] = [] |
| 47 | + for array_type in ALL_INNER: |
| 48 | + dask_array_type = _docs.DaskArray(array_type) |
| 49 | + data.append(( |
| 50 | + array_type, |
| 51 | + array_type in array_types, |
| 52 | + dask_array_type in array_types, |
| 53 | + )) |
| 54 | + |
| 55 | + title = nodes.title("", "", *self.parse_inline(":ref:`array-support`")[0]) |
| 56 | + rows = self._render_support_data(data) |
| 57 | + return self._render_table(headers, rows, title=title) |
| 58 | + |
| 59 | + def _render_overview(self) -> list[nodes.Node]: |
| 60 | + headers = ["Function", *(at.rst(short=True) for at in ALL_INNER)] |
| 61 | + rows: list[nodes.row] = [] |
| 62 | + for fn, (include, exclude) in self._array_support.items(): |
| 63 | + row_header, _ = self.parse_inline(f":func:`scanpy.{fn}`") |
| 64 | + ats = frozenset(_docs.parse(include, exclude)) |
| 65 | + cells: list[Sequence[nodes.Node]] = [ |
| 66 | + row_header, |
| 67 | + *( |
| 68 | + self._render_support(at in ats, dask=dt in ats) |
| 69 | + for at, dt in zip( |
| 70 | + ALL_INNER, map(_docs.DaskArray, ALL_INNER), strict=True |
| 71 | + ) |
| 72 | + ), |
| 73 | + ] |
| 74 | + rows.append( |
| 75 | + nodes.row( |
| 76 | + "", |
| 77 | + *( |
| 78 | + nodes.entry("", nodes.paragraph("", "", *cell)) |
| 79 | + for cell in cells |
| 80 | + ), |
| 81 | + ) |
| 82 | + ) |
| 83 | + return self._render_table(headers, rows) |
| 84 | + |
| 85 | + def _render_support_data( |
| 86 | + self, |
| 87 | + data: list[tuple[_docs.Inner, bool, bool]], |
| 88 | + ) -> Generator[nodes.row, None, None]: |
| 89 | + for t, group in groupby(data, key=lambda r: type(r[0])): |
| 90 | + group = list(group) # noqa: PLW2901 |
| 91 | + if ( # if all sparse types have the same support, just one row |
| 92 | + t is _docs.ScipySparse |
| 93 | + and (support := one({s for _, s, _ in group})) is not None |
| 94 | + and (in_dask := one({d for _, _, d in group})) is not None |
| 95 | + ): |
| 96 | + refs: list[nodes.Node] = [ |
| 97 | + nodes.inline("", "scipy.sparse.{"), |
| 98 | + *self.parse_inline(":class:`csr <scipy.sparse.csr_array>`")[0], |
| 99 | + nodes.inline("", ","), |
| 100 | + *self.parse_inline(":class:`csc <scipy.sparse.csc_matrix>`")[0], |
| 101 | + nodes.inline("", "}_{"), |
| 102 | + *self.parse_inline(":class:`array <scipy.sparse.csc_array>`")[0], |
| 103 | + nodes.inline("", ","), |
| 104 | + *self.parse_inline(":class:`matrix <scipy.sparse.csr_matrix>`")[0], |
| 105 | + nodes.inline("", "}"), |
| 106 | + ] |
| 107 | + header = [nodes.literal("", "", *refs)] |
| 108 | + yield self._render_row(header, support=support, in_dask=in_dask) |
| 109 | + else: # otherwise, show them individually |
| 110 | + for array_type, support, in_dask in group: |
| 111 | + yield self._render_row( |
| 112 | + self._render_array_type(array_type), |
| 113 | + support=support, |
| 114 | + in_dask=in_dask, |
| 115 | + ) |
| 116 | + |
| 117 | + def _render_row( |
| 118 | + self, header: Sequence[nodes.Node], *, support: bool, in_dask: bool |
| 119 | + ) -> nodes.row: |
| 120 | + cells: list[Sequence[nodes.Node]] = [ |
| 121 | + header, |
| 122 | + self._render_support(support), |
| 123 | + self._render_support(in_dask), |
| 124 | + ] |
| 125 | + children = (nodes.entry("", nodes.paragraph("", "", *cell)) for cell in cells) |
| 126 | + return nodes.row("", *children) |
| 127 | + |
| 128 | + def _render_table( |
| 129 | + self, |
| 130 | + headers: Collection[str], |
| 131 | + rows: Iterable[nodes.row], |
| 132 | + *, |
| 133 | + title: nodes.title | None = None, |
| 134 | + ) -> list[nodes.Node]: |
| 135 | + colspecs = [ |
| 136 | + nodes.colspec(stub=True), |
| 137 | + *(nodes.colspec() for _ in range(len(headers) - 1)), |
| 138 | + ] |
| 139 | + header_nodes = [ |
| 140 | + nodes.entry("", nodes.paragraph("", "", *self.parse_inline(t)[0])) |
| 141 | + for t in headers |
| 142 | + ] |
| 143 | + thead = nodes.thead("", nodes.row("", *header_nodes)) |
| 144 | + tbody = nodes.tbody("", *rows) |
| 145 | + return [ |
| 146 | + nodes.table( |
| 147 | + "", |
| 148 | + *([title] if title else []), |
| 149 | + nodes.tgroup("", *colspecs, thead, tbody, cols=len(colspecs)), |
| 150 | + ids=["array-support"], |
| 151 | + ) |
| 152 | + ] |
| 153 | + |
| 154 | + def _render_support( |
| 155 | + self, |
| 156 | + support: bool, # noqa: FBT001 |
| 157 | + /, |
| 158 | + *, |
| 159 | + dask: bool = False, |
| 160 | + ) -> Sequence[nodes.Node]: |
| 161 | + dask_expl = "Also supports this type as chunk in a dask Array" |
| 162 | + return [ |
| 163 | + nodes.Text(("✅" if support else "❌") + " " * dask), |
| 164 | + *([nodes.abbreviation(text="⚡", explanation=dask_expl)] if dask else []), |
| 165 | + ] |
| 166 | + |
| 167 | + def _render_array_type(self, array_type: _docs.ArrayType, /) -> list[nodes.Node]: |
| 168 | + nodes_, msgs = self.parse_inline(array_type.rst()) |
| 169 | + assert not msgs, msgs |
| 170 | + return nodes_ |
| 171 | + |
| 172 | + |
| 173 | +def one[T](arg: Collection[T]) -> T | None: |
| 174 | + """Return the only item in `arg` or None if `arg` is not of length 1.""" |
| 175 | + try: |
| 176 | + [item] = arg |
| 177 | + except ValueError: |
| 178 | + return None |
| 179 | + return item |
| 180 | + |
| 181 | + |
| 182 | +def setup(app: Sphinx) -> None: |
| 183 | + """App setup hook.""" |
| 184 | + app.add_directive("array-support", ArraySupport) |
| 185 | + app.add_config_value("array_support", {}, "env") |
0 commit comments