Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 31 additions & 7 deletions src/anndata/_core/aligned_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,19 @@ def _validate_value(self, val: Value, key: str) -> Value:
name = f"{self.attrname.title().rstrip('s')} {key!r}"
return coerce_array(val, name=name, allow_df=self._allow_df)

_attrname_override: str | None = None

@property
def attrname(self) -> str:
    """Name of the AnnData attribute this mapping is exposed as.

    Returns ``_attrname_override`` when one has been set on the instance,
    otherwise falls back to ``_default_attrname``.
    """
    # Deliberately NOT abstract: this accessor has a concrete body and
    # subclasses customise the name through the ``_default_attrname`` hook.
    # Keeping ``@abstractmethod`` here would leave every subclass that only
    # implements ``_default_attrname`` impossible to instantiate.
    override = self._attrname_override
    if override is not None:
        return override
    return self._default_attrname

@property
@abstractmethod
def _default_attrname(self) -> str:
    """Attribute name used when no override is set (e.g., 'obsm', 'varp').

    Abstract hook: ``attrname`` returns this value whenever
    ``_attrname_override`` is ``None``.
    """

@property
@abstractmethod
Expand Down Expand Up @@ -151,6 +160,9 @@ def __init__(self, parent_mapping: P, parent_view: AnnData, subset_idx: I) -> No
self.parent_mapping = parent_mapping
self._parent = parent_view
self.subset_idx = subset_idx
# Propagate attrname override from actual to view (for registered sections)
if parent_mapping._attrname_override is not None:
self._attrname_override = parent_mapping._attrname_override
if hasattr(parent_mapping, "_axis"):
# LayersBase has no _axis, the rest does
self._axis = parent_mapping._axis # type: ignore
Expand Down Expand Up @@ -237,7 +249,7 @@ class AxisArraysBase(AlignedMappingBase):
_axis: Literal[0, 1]

@property
def _default_attrname(self) -> str:
    """Default attribute name: ``self.dim`` followed by ``m`` (e.g. 'obsm')."""
    # The former ``def attrname`` header that preceded this def had no body
    # and is dropped; only the renamed hook remains.
    return f"{self.dim}m"

@property
Expand Down Expand Up @@ -311,9 +323,12 @@ class LayersBase(AlignedMappingBase):
"""

_allow_df: ClassVar = False
attrname: ClassVar[Literal["layers"]] = "layers"
axes: ClassVar[tuple[Literal[0], Literal[1]]] = (0, 1)

@property
def _default_attrname(self) -> str:
    """Default attribute name for this mapping; always ``"layers"``."""
    return "layers"


class Layers(AlignedActual, LayersBase):
pass
Expand All @@ -339,7 +354,7 @@ class PairwiseArraysBase(AlignedMappingBase):
_axis: Literal[0, 1]

@property
def _default_attrname(self) -> str:
    """Default attribute name: ``self.dim`` followed by ``p`` (e.g. 'varp')."""
    # The former ``def attrname`` header that preceded this def had no body
    # and is dropped; only the renamed hook remains.
    return f"{self.dim}p"

@property
Expand Down Expand Up @@ -402,8 +417,13 @@ class AlignedMappingProperty[T: AlignedMapping](property):

def construct(self, obj: AnnData, *, store: MutableMapping[str, Value]) -> T:
    """Build the mapping for ``obj`` on top of ``store``.

    ``self.cls`` is called with ``axis=self.axis`` unless ``self.axis`` is
    ``None``, in which case no ``axis`` argument is passed. If the resulting
    mapping's ``_default_attrname`` differs from the name this property is
    registered under, that name is stored as ``_attrname_override``.

    Parameters
    ----------
    obj
        Object the mapping is constructed for; passed through to ``self.cls``.
    store
        Backing storage handed to ``self.cls`` unchanged.

    Returns
    -------
    The newly constructed mapping.
    """
    # No early returns here: the mapping must reach the override step below.
    if self.axis is None:
        mapping = self.cls(obj, store=store)
    else:
        mapping = self.cls(obj, axis=self.axis, store=store)
    # Override attrname for registered sections (e.g., "obst" instead of "obsm")
    if mapping._default_attrname != self.name:
        mapping._attrname_override = self.name
    return mapping

@property
def fget(self) -> Callable[[], None]:
Expand All @@ -420,7 +440,11 @@ def __get__(self, obj: None | AnnData, objtype: type | None = None) -> T:
# this needs to return a `property` instance, e.g. for Sphinx
return self # type: ignore
if not obj.is_view:
return self.construct(obj, store=getattr(obj, f"_{self.name}"))
store = getattr(obj, f"_{self.name}", None)
if store is None:
store = {}
setattr(obj, f"_{self.name}", store)
return self.construct(obj, store=store)
parent_anndata = obj._adata_ref
idxs = (obj._oidx, obj._vidx)
parent: AlignedMapping = getattr(parent_anndata, self.name)
Expand Down
56 changes: 40 additions & 16 deletions src/anndata/_core/anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ class AnnData(metaclass=utils.DeprecationMixinMeta): # noqa: PLW1641
)

_accessors: ClassVar[set[str]] = set()
_registered_sections: ClassVar[dict] = {} # str -> SectionSpec

# view attributes
_adata_ref: AnnData | None
Expand Down Expand Up @@ -242,6 +243,7 @@ def __init__( # noqa: PLR0913
varp: np.ndarray | Mapping[str, Sequence[Any]] | None = None,
oidx: _Index1DNorm | int | np.integer | None = None,
vidx: _Index1DNorm | int | np.integer | None = None,
**extra_sections,
):
# check for any multi-indices that aren’t later checked in coerce_array
for attr, key in [(obs, "obs"), (var, "var"), (X, "X")]:
Expand Down Expand Up @@ -270,6 +272,7 @@ def __init__( # noqa: PLR0913
varp=varp,
filename=filename,
filemode=filemode,
**extra_sections,
)

def _init_as_view(
Expand Down Expand Up @@ -361,6 +364,7 @@ def _init_as_actual( # noqa: PLR0912, PLR0913, PLR0915
shape=None,
filename=None,
filemode=None,
**extra_sections,
):
# view attributes
self._is_view = False
Expand Down Expand Up @@ -391,6 +395,15 @@ def _init_as_actual( # noqa: PLR0912, PLR0913, PLR0915
if any((obs, var, uns, obsm, varm, obsp, varp)):
msg = "If `X` is a dict no further arguments must be provided."
raise ValueError(msg)
# Copy extension sections from source AnnData
# (built-in sections are handled by the explicit unpacking below)
for sec_name, spec in self._registered_sections.items():
if spec.builtin:
continue
if sec_name not in extra_sections:
src_mapping = getattr(X, sec_name, None)
if src_mapping is not None and len(src_mapping) > 0:
extra_sections[sec_name] = dict(src_mapping)
X, obs, var, uns, obsm, varm, obsp, varp, layers, raw = (
X._X,
X.obs,
Expand Down Expand Up @@ -509,6 +522,12 @@ def _init_as_actual( # noqa: PLR0912, PLR0913, PLR0915
# layers
self.layers = layers

# registered sections (e.g., obst, vart from extensions)
for sec_name in self._registered_sections:
value = extra_sections.get(sec_name)
if value is not None:
setattr(self, sec_name, value)

@old_positionals("show_stratified", "with_disk")
def __sizeof__(
self, *, show_stratified: bool = False, with_disk: bool = False
Expand Down Expand Up @@ -545,21 +564,17 @@ def cs_to_bytes(X) -> int:
return sum(sizes.values())

def _gen_repr(self, n_obs, n_vars) -> str:
    """Build the multi-line text representation of this object.

    The first line reports ``n_obs`` × ``n_vars`` and, when ``self.isbacked``
    is true, the backing filename. One further line is appended per section
    yielded by ``iter_sections`` (excluding the ``"X"`` and ``"raw"`` kinds)
    that has at least one key, listing those keys.

    Parameters
    ----------
    n_obs
        Number of observations to report in the header line.
    n_vars
        Number of variables to report in the header line.

    Returns
    -------
    The assembled description string.
    """
    from .section_registry import iter_sections

    backed_at = f" backed at {str(self.filename)!r}" if self.isbacked else ""
    descr = f"AnnData object with n_obs × n_vars = {n_obs} × {n_vars}{backed_at}"
    # Sections come from the registry only; the former hard-coded attribute
    # list is gone so each section is visited (and printed) exactly once.
    for spec, value in iter_sections(self, exclude_kinds={"X", "raw"}):
        try:
            keys = value.keys()
        except Exception:  # noqa: BLE001
            # Best effort: a section whose ``keys()`` call fails is left out
            # of the repr instead of making the repr itself raise.
            continue
        if len(keys) > 0:
            descr += f"\n {spec.name}: {str(list(keys))[1:-1]}"
    return descr

def __repr__(self) -> str:
Expand Down Expand Up @@ -1413,11 +1428,13 @@ def _mutated_copy(self, **kwargs) -> AnnData:
raise NotImplementedError(msg)
new = {}

for key in ["obs", "var", "obsm", "varm", "obsp", "varp", "layers"]:
if key in kwargs:
new[key] = kwargs[key]
from .section_registry import iter_sections

for spec, value in iter_sections(self, kinds={"dataframe", "mapping"}):
if spec.name in kwargs:
new[spec.name] = kwargs[spec.name]
else:
new[key] = getattr(self, key).copy()
new[spec.name] = value.copy()
if "X" in kwargs:
new["X"] = kwargs["X"]
elif self._has_X():
Expand Down Expand Up @@ -2154,6 +2171,13 @@ def _remove_unused_categories_xr(
pass # this is handled automatically by the categorical arrays themselves i.e., they dedup upon access.


# Populate _registered_sections with built-in section specs.
# Must happen after AnnData class definition is complete.
from .section_registry import _init_builtin_sections # noqa: E402

_init_builtin_sections(AnnData)


def _check_2d_shape(X):
"""\
Check shape of array or sparse matrix.
Expand Down
Loading
Loading