Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions src/geome/ann2data/base/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,31 +70,61 @@ def create_data_obj(self, adata: AnnData) -> Data:
obj[field] = self.merge_field(adata, field, locations)
return Data(**obj)

def __call__(self, adata: AnnData | Iterable[AnnData]) -> Iterable[Data]:
def __call__(self,
adata: AnnData | Iterable[AnnData],
save_subadata: bool = False,
save_data: bool = False,
save_preprocessed: bool = False) -> Iterable[Data]:
"""Convert an AnnData object to a PyTorch compatible data object.

Args:
----
adata: The AnnData object to be converted.
save_subadata: If True, save the subadata objects.
save_data: If True, save the data objects.
save_preprocessed: If True, save the preprocessed AnnData object.

Yields
------
PyTorch Geometric compatible data object.

"""
if save_subadata or save_data or save_preprocessed:
res = {}
else:
res = None

# do the given preprocessing steps.
if self._preprocess is not None:
adata = self._preprocess(adata)
if save_preprocessed:
res['preprocessed'] = adata
# convert adata to iterable if it is specified
adata_iter = adata
if self._adata2iterable is not None:
adata_iter = self._adata2iterable(adata)

if save_subadata:
res['subadata'] = []
if save_data:
res['data'] = []

# iterate through adata.
data_objects = []
for subadata in adata_iter:
if self._transform is not None:
subadata = self._transform(subadata)
yield self.create_data_obj(subadata)

current_data_obj = self.create_data_obj(subadata)
data_objects.append(current_data_obj)

if save_subadata:
res['subadata'].append(subadata)
if save_data:
res['data'].append(current_data_obj)

return data_objects, res


def to_list(self, adata: AnnData | Iterable[AnnData]) -> list[Data]:
"""Convert an AnnData object to a list of PyTorch compatible data objects.
Expand All @@ -108,3 +138,4 @@ def to_list(self, adata: AnnData | Iterable[AnnData]) -> list[Data]:
A list of PyTorch Geometric compatible data objects.
"""
return list(self(adata))

6 changes: 4 additions & 2 deletions src/geome/ann2data/basic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Callable
from typing import Any, Callable

import numpy as np
import pandas as pd
Expand All @@ -22,6 +22,8 @@ def __init__(
adata2iter: Callable[[AnnData], AnnData] | None = None,
preprocess: list[Callable[[AnnData], AnnData]] | None = None,
transform: list[Callable[[AnnData], AnnData]] | None = None,
*args: Any,
**kwargs: Any,
) -> None:
"""Convert anndata object into a dictionary of arrays.

Expand All @@ -46,7 +48,7 @@ def __init__(
transform: List of functions to transform the AnnData object after preprocessing.
edge_index_key: Key for the edge index in the converted data. Defaults to 'edge_index'.
"""
super().__init__(fields, adata2iter, preprocess, transform)
super().__init__(fields, adata2iter, preprocess, transform, *args, **kwargs)

self._preprocess = preprocess
self._transform = transform
Expand Down
22 changes: 12 additions & 10 deletions src/geome/iterables/to_category_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from collections.abc import Iterator
from dataclasses import dataclass
from typing import Literal
from typing import Any, Callable, Literal, Optional

from anndata import AnnData

Expand All @@ -21,12 +21,12 @@ class ToCategoryIterator(ToIterable):
axis (int | str): The axis along which to iterate over the categories. Can be either 0, 1, "obs" or "var".
0 or "obs" means the categories are in the observation axis.
1 or "var" means the categories are in the variable axis.
preserve_categories (bool): Preserves the categories in the resulting AnnData obs and var Series if `preserve_categories` is True.
preserve_categories (list): If not None, preserves the categories of the listed categorical columns found in the AnnData `obs` and `var`.
"""

category: str
axis: Literal[0, 1, "obs", "var"] = "obs"
preserve_categories: bool = True
preserve_categories: Optional[list[str]] = []

def __post_init__(self):
if self.axis not in (0, 1, "obs", "var"):
Expand All @@ -51,13 +51,15 @@ def __call__(self, adata: AnnData) -> Iterator[AnnData]:
cats_df = get_from_loc(adata, f"{self.axis}/{self.category}")
cats = cats_df.dtypes.categories
preserved_categories = {"obs": {}, "var": {}}
if self.preserve_categories:
for axis in ("obs", "var"):
adata_axis = getattr(adata, axis)
if adata_axis is not None:
for key in adata_axis.keys():
if adata_axis[key].dtype.name == "category":
preserved_categories[axis][key] = adata_axis[key].cat.categories

if self.preserve_categories is not None:
for key in self.preserve_categories:
for axis in ("obs", "var"):
adata_axis = getattr(adata, axis)
if adata_axis is not None:
if key in adata_axis.keys():
if adata_axis[key].dtype.name == "category":
preserved_categories[axis][key] = adata_axis[key].cat.categories

for cat in cats:
# TODO(syelman): is this wise? Maybe create copy only if preserve_categories is True?
Expand Down
2 changes: 2 additions & 0 deletions src/geome/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .base.transform import Transform
from .categorize import Categorize
from .compose import Compose
from .one_hot_encode import SaveOneHotEncodeLabels
from .subset import Subset

__all__ = [
Expand All @@ -15,4 +16,5 @@
"AddEdgeIndex",
"AddEdgeIndexFromAdj",
"Subset",
"SaveOneHotEncodeLabels",
]
70 changes: 70 additions & 0 deletions src/geome/transforms/one_hot_encode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Literal

import pandas as pd
from anndata import AnnData

from .base.transform import Transform


@dataclass
class SaveOneHotEncodeLabels(Transform):
    """One-hot encode columns of an AnnData axis and store the result on the object.

    The indicator columns of all requested keys are concatenated column-wise and
    written to ``adata.obsm[key_added]`` (``adata.varm[key_added]`` for ``axis="var"``).
    The labels behind the indicator columns are written to
    ``adata.uns[f"{key_added}_mappings"]`` as ``{key: [label, ...]}``.

    Args:
    -----
    keys (str | list[str]): Column(s) to be one-hot encoded. A single string is wrapped into a list.
    axis (Literal["obs", "var"]): Axis on which the columns are located. 'obs' for observation, 'var' for variables.
    key_added (str): Base key under which the one-hot encoded data and label mappings will be stored.

    Raises
    ------
    ValueError
        If `keys` is empty or `axis` is neither 'obs' nor 'var'.
    """

    keys: str | list[str]
    axis: Literal["obs", "var"]
    key_added: str

    def __post_init__(self):
        # Normalize so that __call__ can always iterate over a list of column names.
        if isinstance(self.keys, str):
            self.keys = [self.keys]
        else:
            self.keys = list(self.keys)
        # Fail at construction time instead of with an opaque pd.concat error later.
        if not self.keys:
            raise ValueError("`keys` must contain at least one column name.")
        if self.axis not in ("obs", "var"):
            raise ValueError(f"`axis` must be 'obs' or 'var', got {self.axis!r}.")

    def __call__(self, adata: AnnData) -> AnnData:
        """
        One-hot encode the specified columns and store the result in the AnnData object.

        Parameters
        ----------
        adata : AnnData
            The annotated data matrix to be updated with one-hot encoded data.

        Returns
        -------
        AnnData
            The same `adata` object, modified in place.
        """
        annotations = getattr(adata, self.axis)  # adata.obs or adata.var

        # One indicator frame per requested column, kept in the order of `self.keys`.
        encoded = [pd.get_dummies(annotations[key]) for key in self.keys]

        # The column labels of each indicator frame are the labels of that key.
        label_mappings = {key: dummies.columns.tolist() for key, dummies in zip(self.keys, encoded)}

        # 'obs' -> 'obsm', 'var' -> 'varm'
        getattr(adata, f"{self.axis}m")[self.key_added] = pd.concat(encoded, axis=1)
        adata.uns[f"{self.key_added}_mappings"] = label_mappings

        return adata
Empty file.