-
Notifications
You must be signed in to change notification settings - Fork 4
Features for InterScale #62
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
e38e81e
f3f5e31
7fc9ddd
79f836f
09e582f
b50a76f
bd46b54
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,7 +2,7 @@ | |
|
|
||
| from collections.abc import Iterator | ||
| from dataclasses import dataclass | ||
| from typing import Literal | ||
| from typing import Any, Callable, Literal, Optional | ||
|
|
||
| from anndata import AnnData | ||
|
|
||
|
|
@@ -21,12 +21,12 @@ class ToCategoryIterator(ToIterable): | |
| axis (int | str): The axis along which to iterate over the categories. Can be either 0, 1, "obs" or "var". | ||
| 0 or "obs" means the categories are in the observation axis. | ||
| 1 or "var" means the categories are in the variable axis. | ||
| preserve_categories (bool): Preserves the categories in the resulting AnnData obs and var Series if `preserve_categories` is True. | ||
| preserve_categories (list): If not None, preserves the indicated categories from the Anndata 'obs' and 'var' | ||
| """ | ||
|
|
||
| category: str | ||
| axis: Literal[0, 1, "obs", "var"] = "obs" | ||
| preserve_categories: bool = True | ||
| preserve_categories: Optional[list[str]] = None | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I like this idea, but I think it might be better if we do it like this. So, an empty list by default, which would mean no preservation. And it is `preserved_categories: Sequence[str] | None = tuple()` — here the default is a tuple instead, because it's not good to have mutable objects as default arguments. Could you also update the docstring with that?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks for the review. If we change preserved_categories to a tuple, then how are we to distinguish between the "obs" and "var" variables?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. But we specify if it's a var or obs in
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, true. I am now questioning whether this even works for |
||
|
|
||
| def __post_init__(self): | ||
| if self.axis not in (0, 1, "obs", "var"): | ||
|
|
@@ -51,13 +51,15 @@ def __call__(self, adata: AnnData) -> Iterator[AnnData]: | |
| cats_df = get_from_loc(adata, f"{self.axis}/{self.category}") | ||
| cats = cats_df.dtypes.categories | ||
| preserved_categories = {"obs": {}, "var": {}} | ||
| if self.preserve_categories: | ||
| for axis in ("obs", "var"): | ||
| adata_axis = getattr(adata, axis) | ||
| if adata_axis is not None: | ||
| for key in adata_axis.keys(): | ||
| if adata_axis[key].dtype.name == "category": | ||
| preserved_categories[axis][key] = adata_axis[key].cat.categories | ||
|
|
||
| if self.preserve_categories is not None: | ||
| for key in self.preserve_categories: | ||
| for axis in ("obs", "var"): | ||
| adata_axis = getattr(adata, axis) | ||
| if adata_axis is not None: | ||
| if key in adata_axis.keys(): | ||
| if adata_axis[key].dtype.name == "category": | ||
| preserved_categories[axis][key] = adata_axis[key].cat.categories | ||
|
|
||
| for cat in cats: | ||
| # TODO(syelman): is this wise? Maybe create copy only if preserve_categories is True? | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,70 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from dataclasses import dataclass | ||
| from typing import Literal | ||
|
|
||
| import pandas as pd | ||
| from anndata import AnnData | ||
|
|
||
| from .base.transform import Transform | ||
|
|
||
|
|
||
@dataclass
class SaveOneHotEncodeLabels(Transform):
    """One-hot encode columns of an AnnData axis and store the result on the object.

    For every requested column, ``pandas.get_dummies`` is applied to the column
    taken from ``adata.obs`` or ``adata.var``. The resulting indicator frames are
    concatenated column-wise and written to ``adata.obsm[key_added]`` or
    ``adata.varm[key_added]``. The list of indicator column labels produced for
    each source column is written to ``adata.uns[f"{key_added}_mappings"]``.

    Parameters
    ----------
    keys : str | list
        Column name or list of column names to one-hot encode. A single string
        is wrapped into a one-element list in ``__post_init__``.
    axis : Literal["obs", "var"]
        Axis on which the columns are located: ``"obs"`` for observations,
        ``"var"`` for variables.
    key_added : str
        Key under which the encoded matrix is stored; the label mappings are
        stored under ``f"{key_added}_mappings"``.
    """

    keys: str | list
    axis: Literal["obs", "var"]
    key_added: str

    def __post_init__(self):
        """Normalize ``keys`` so that the rest of the class can treat it as a list."""
        if isinstance(self.keys, str):
            self.keys = [self.keys]

    def __call__(self, adata: AnnData) -> AnnData:
        """One-hot encode the configured columns and store the result in ``adata``.

        Parameters
        ----------
        adata : AnnData
            The annotated data matrix; it is modified in place.

        Returns
        -------
        AnnData
            The same ``adata`` object that was passed in, after the encoded
            matrix and the label mappings have been written to it.
        """
        matrix_key = f"{self.axis}m"  # e.g., 'obsm' or 'varm'
        annotations = getattr(adata, self.axis)

        label_mappings = {}
        encoded_frames = []
        for key in self.keys:
            indicators = pd.get_dummies(annotations[key])
            encoded_frames.append(indicators)
            # The indicator column labels, in column order, for this source column.
            label_mappings[key] = indicators.columns.tolist()

        # NOTE(review): get_dummies is called without a prefix, so two source
        # columns that share a label would yield identically named columns in
        # the combined frame — confirm whether callers rely on the bare labels
        # before adding a prefix.
        combined = pd.concat(encoded_frames, axis=1)

        # Save the encoded data and mappings in the appropriate AnnData structure.
        getattr(adata, matrix_key)[self.key_added] = combined
        adata.uns[f"{self.key_added}_mappings"] = label_mappings

        return adata
|
FrancescaDr marked this conversation as resolved.
|
Uh oh!
There was an error while loading. Please reload this page.