Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/geome/ann2data/base/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(
adata2iterable: Callable[[AnnData], Iterable[AnnData]] | None = None,
preprocess: Callable[[AnnData], AnnData] | None = None,
transform: Callable[[AnnData], AnnData] | None = None,
save_preprocessed_adata: bool = False,
Comment thread
FrancescaDr marked this conversation as resolved.
Outdated
*args: Any,
**kwargs: Any,
) -> None:
Expand All @@ -37,6 +38,7 @@ def __init__(
self._preprocess = preprocess
self.fields = fields
self._transform = transform
self.save_preprocessed_adata = save_preprocessed_adata

@abstractmethod
def merge_field(self, adata: AnnData, field: str, locations: list[str]) -> torch.Tensor:
Expand Down Expand Up @@ -82,6 +84,7 @@ def __call__(self, adata: AnnData | Iterable[AnnData]) -> Iterable[Data]:
PyTorch Geometric compatible data object.

"""
print("call new")
# do the given preprocessing steps.
if self._preprocess is not None:
adata = self._preprocess(adata)
Expand All @@ -91,10 +94,17 @@ def __call__(self, adata: AnnData | Iterable[AnnData]) -> Iterable[Data]:
adata_iter = self._adata2iterable(adata)

# iterate trough adata.
data_objects = []
for subadata in adata_iter:
if self._transform is not None:
subadata = self._transform(subadata)
yield self.create_data_obj(subadata)
data_objects.append(self.create_data_obj(subadata))

# Return the data objects and adata or None
if self.save_preprocessed_adata:
return data_objects, adata
else:
return data_objects, None

def to_list(self, adata: AnnData | Iterable[AnnData]) -> list[Data]:
"""Convert an AnnData object to a list of PyTorch compatible data objects.
Expand Down
6 changes: 4 additions & 2 deletions src/geome/ann2data/basic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Callable
from typing import Any, Callable

import numpy as np
import pandas as pd
Expand All @@ -22,6 +22,8 @@ def __init__(
adata2iter: Callable[[AnnData], AnnData] | None = None,
preprocess: list[Callable[[AnnData], AnnData]] | None = None,
transform: list[Callable[[AnnData], AnnData]] | None = None,
*args: Any,
**kwargs: Any,
) -> None:
"""Convert anndata object into a dictionary of arrays.

Expand All @@ -46,7 +48,7 @@ def __init__(
transform: List of functions to transform the AnnData object after preprocessing.
edge_index_key: Key for the edge index in the converted data. Defaults to 'edge_index'.
"""
super().__init__(fields, adata2iter, preprocess, transform)
super().__init__(fields, adata2iter, preprocess, transform, *args, **kwargs)

self._preprocess = preprocess
self._transform = transform
Expand Down
22 changes: 12 additions & 10 deletions src/geome/iterables/to_category_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from collections.abc import Iterator
from dataclasses import dataclass
from typing import Literal
from typing import Any, Callable, Literal, Optional

from anndata import AnnData

Expand All @@ -21,12 +21,12 @@ class ToCategoryIterator(ToIterable):
axis (int | str): The axis along which to iterate over the categories. Can be either 0, 1, "obs" or "var".
0 or "obs" means the categories are in the observation axis.
1 or "var" means the categories are in the variable axis.
preserve_categories (bool): Preserves the categories in the resulting AnnData obs and var Series if `preserve_categories` is True.
preserve_categories (list): If not None, preserves the indicated categories from the Anndata 'obs' and 'var'
"""

category: str
axis: Literal[0, 1, "obs", "var"] = "obs"
preserve_categories: bool = True
preserve_categories: Optional[list[str]] = None
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this idea but I think it might be better if we do it like this

preserve_categories: Optional[list[str]] = []

So empty list by default which would mean no preservation. And if it is None that would mean all categories. So then it would make sense to change the variable name to preserved_categories. Also to make the typing more modern finally it would look like this

preserved_categories: Sequence[str] | None = tuple()

here default is tuple instead because its not good to have default arguments as mutuable objects. Could you also update docstring with that?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review.

If we change preserved_categories to a tuple then how are we to distinguish between the "obs" and "var" variables?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But we specify if its a var or obs in axis right?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes true. I am now questioning whether this even works for .var. In the Ann2DataByCategory object the ToCategoryIterator is by default initialised with obs. Also in thetest_to_category_iterator() there is nothing tested for "var". Am I overlooking where the ToCategoryIteratorwould be initialised withaxis = var`?


def __post_init__(self):
if self.axis not in (0, 1, "obs", "var"):
Expand All @@ -51,13 +51,15 @@ def __call__(self, adata: AnnData) -> Iterator[AnnData]:
cats_df = get_from_loc(adata, f"{self.axis}/{self.category}")
cats = cats_df.dtypes.categories
preserved_categories = {"obs": {}, "var": {}}
if self.preserve_categories:
for axis in ("obs", "var"):
adata_axis = getattr(adata, axis)
if adata_axis is not None:
for key in adata_axis.keys():
if adata_axis[key].dtype.name == "category":
preserved_categories[axis][key] = adata_axis[key].cat.categories

if self.preserve_categories is not None:
for key in self.preserve_categories:
for axis in ("obs", "var"):
adata_axis = getattr(adata, axis)
if adata_axis is not None:
if key in adata_axis.keys():
if adata_axis[key].dtype.name == "category":
preserved_categories[axis][key] = adata_axis[key].cat.categories

for cat in cats:
# TODO(syelman): is this wise? Maybe create copy only if preserve_categories is True?
Expand Down
2 changes: 2 additions & 0 deletions src/geome/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .base.transform import Transform
from .categorize import Categorize
from .compose import Compose
from .one_hot_encode import SaveOneHotEncodeLabels
from .subset import Subset

__all__ = [
Expand All @@ -15,4 +16,5 @@
"AddEdgeIndex",
"AddEdgeIndexFromAdj",
"Subset",
"SaveOneHotEncodeLabels",
]
70 changes: 70 additions & 0 deletions src/geome/transforms/one_hot_encode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Literal

import pandas as pd
from anndata import AnnData

from .base.transform import Transform


@dataclass
class SaveOneHotEncodeLabels(Transform):
"""One-hot encode specified columns in an AnnData object and store them in the specified matrix slot.

Args:
-----
keys (Union[str, List[str]]): Columns to be one-hot encoded.
axis (Literal["obs", "var"]): Axis on which the columns are located. 'obs' for observation, 'var' for variables.
key_added (str): Base key under which the one-hot encoded data and label mappings will be stored.

Methods
-------
__call__(adata: AnnData) -> None:
Converts the specified columns to one-hot encoded format and updates `adata` accordingly.
"""

keys: str | list
axis: Literal["obs", "var"]
key_added: str

def __post_init__(self):
if isinstance(self.keys, str):
self.keys = [self.keys]

def __call__(self, adata: AnnData) -> None:
"""
One-hot encode the specified columns and store the result in the AnnData object.

Parameters
----------
adata : AnnData
The annotated data matrix to be updated with one-hot encoded data.

Returns
-------
None
"""
matrix_key = f"{self.axis}m" # e.g., 'obsm' or 'varm'
# encoded_data = {}
label_mappings = {}
encoded_data_list = []

for key in self.keys:
# Generate one-hot encoding
categories = pd.get_dummies(getattr(adata, self.axis)[key])
# encoded_data[key] = categories
encoded_data_list.append(categories)
# Store mapping of codes to labels
label_mappings[key] = categories.columns.tolist()

encoded_data_combined = pd.concat(encoded_data_list, axis=1)

# Save the encoded data and mappings in the appropriate AnnData structure
getattr(adata, matrix_key)[self.key_added] = pd.DataFrame(encoded_data_combined)
adata.uns[f"{self.key_added}_mappings"] = label_mappings

print(label_mappings)

return adata
Comment thread
FrancescaDr marked this conversation as resolved.
Empty file.