Commit 84aaf5c

Release 0.1.5 (#25)

* Refactor benchmark
* Add CELU for MA (mean activation)
* Update version
* Revert callable for latent mean / var activation
* Update CHANGELOG.md
* [pre-commit.ci] pre-commit autoupdate (#24)
  updates: [github.com/astral-sh/ruff-pre-commit: v0.11.5 → v0.11.8](astral-sh/ruff-pre-commit@v0.11.5...v0.11.8)
  Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  For more information, see https://pre-commit.ci
* Update _benchmark.py (align with ruff)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 71c6a0a · commit 84aaf5c

File tree

11 files changed: +334, -38 lines

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -11,7 +11,7 @@ repos:
     hooks:
       - id: prettier
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.11.5
+    rev: v0.11.8
     hooks:
       - id: ruff
         types_or: [python, pyi, jupyter]
```

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
```diff
@@ -4,6 +4,15 @@

 -

+## [0.1.5] - 2025-05-09
+
+- Refactor benchmarking code for better reusability
+- Revert callable for mean and var activation
+
+## [0.1.4] - 2025-04-17
+
+- Limit anndata version for compatibility with old scvi-tools
+
 ## [0.1.3] - 2025-02-12

 - Introduce mean activation to make non-negative latents possible (docs will come later)
```

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -4,7 +4,7 @@ requires = ["hatchling"]

 [project]
 name = "drvi-py"
-version = "0.1.3"
+version = "0.1.5"
 description = "Disentangled Generative Representation of Single Cell Omics"
 readme = "README.md"
 requires-python = ">=3.10,<3.13"
```

src/drvi/scvi_tools_based/module/_drvi.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -1,4 +1,4 @@
-from collections.abc import Iterable, Sequence
+from collections.abc import Callable, Iterable, Sequence
 from typing import Literal

 import numpy as np
@@ -148,8 +148,8 @@ def __init__(
         ] = "pnb_softmax",
         prior: Literal["normal", "gmm_x", "vamp_x"] = "normal",
         prior_init_dataloader: DataLoader | None = None,
-        var_activation: Literal["exp", "pow2"] = "exp",
-        mean_activation: str = "identity",
+        var_activation: Callable | Literal["exp", "pow2", "2sig"] = "exp",
+        mean_activation: Callable | str = "identity",
         encoder_layer_factory: LayerFactory = None,
         decoder_layer_factory: LayerFactory = None,
         extra_encoder_kwargs: dict | None = None,
```

src/drvi/scvi_tools_based/nn/_base_components.py

Lines changed: 14 additions & 5 deletions
```diff
@@ -1,6 +1,6 @@
 import collections
 import math
-from collections.abc import Iterable, Sequence
+from collections.abc import Callable, Iterable, Sequence
 from typing import Literal

 import torch
@@ -421,8 +421,8 @@ def __init__(
         dropout_rate: float = 0.1,
         distribution: str = "normal",
         var_eps: float = 1e-4,
-        var_activation: Literal["exp", "pow2"] = "exp",
-        mean_activation: str = "identity",
+        var_activation: Callable | Literal["exp", "pow2"] = "exp",
+        mean_activation: Callable | str = "identity",
         layer_factory: LayerFactory = None,
         covariate_modeling_strategy: Literal[
             "one_hot",
@@ -499,8 +499,11 @@ def __init__(
             self.var_activation = torch.exp
         elif var_activation == "pow2":
             self.var_activation = lambda x: torch.pow(x, 2)
+        elif var_activation == "2sig":
+            self.var_activation = lambda x: 2 * torch.sigmoid(x)
         else:
-            raise NotImplementedError()
+            assert callable(var_activation)
+            self.var_activation = var_activation

         if mean_activation == "identity":
             self.mean_activation = nn.Identity()
@@ -516,8 +519,14 @@ def __init__(
                 mean_activation = "elu_1.0"
             alpha = float(mean_activation.split("elu_")[1])
             self.mean_activation = nn.ELU(alpha=alpha)
+        elif mean_activation.startswith("celu"):
+            if mean_activation == "celu":
+                mean_activation = "celu_1.0"
+            alpha = float(mean_activation.split("celu_")[1])
+            self.mean_activation = nn.CELU(alpha=alpha)
         else:
-            raise NotImplementedError()
+            assert callable(mean_activation)
+            self.mean_activation = mean_activation

     def forward(self, x: torch.Tensor, cat_full_tensor: torch.Tensor, cont_full_tensor: torch.Tensor = None):
         r"""The forward computation for a single sample.
```

src/drvi/utils/metrics/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -1,4 +1,5 @@
 from ._aggregation import latent_matching_score, most_similar_averaging_score, most_similar_gap_score
+from ._benchmark import DiscreteDisentanglementBenchmark
 from ._pairwise import (
     global_dim_mutual_info_score,
     local_mutual_info_score,
@@ -14,4 +15,5 @@
     "most_similar_averaging_score",
     "latent_matching_score",
     "most_similar_gap_score",
+    "DiscreteDisentanglementBenchmark",
 ]
```

src/drvi/utils/metrics/_benchmark.py

Lines changed: 170 additions & 0 deletions
```diff
@@ -0,0 +1,170 @@
+import pickle
+
+import numpy as np
+import pandas as pd
+
+from drvi.utils.metrics._aggregation import latent_matching_score, most_similar_averaging_score, most_similar_gap_score
+from drvi.utils.metrics._pairwise import local_mutual_info_score, nn_alignment_score, spearman_correlataion_score
+
+AVAILABLE_METRICS = {
+    "ASC": spearman_correlataion_score,
+    "SPN": nn_alignment_score,
+    "SMI": local_mutual_info_score,
+}
+
+
+AVAILABLE_AGGREGATION_METHODS = {
+    "LMS": latent_matching_score,
+    "MSAS": most_similar_averaging_score,
+    "MSGS": most_similar_gap_score,
+}
+
+
+class DiscreteDisentanglementBenchmark:
+    version = "v1"
+
+    def __init__(
+        self,
+        embed,
+        discrete_target=None,
+        one_hot_target=None,
+        dim_titles=None,
+        metrics=("SMI", "SPN", "ASC"),
+        aggregation_methods=("LMS", "MSAS", "MSGS"),
+    ):
+        if discrete_target is None and one_hot_target is None:
+            raise ValueError("Either discrete_target or one_hot_target must be provided.")
+        if discrete_target is not None and one_hot_target is not None:
+            raise ValueError("Only one of discrete_target or one_hot_target should be provided.")
+
+        if discrete_target is not None:
+            if isinstance(discrete_target, pd.Series):
+                discrete_target = discrete_target.astype("category")
+            elif isinstance(discrete_target, np.ndarray):
+                discrete_target = pd.Series(discrete_target, dtype="category")
+            else:
+                raise ValueError("discrete_target must be a pandas Series or numpy array")
+            one_hot_target = pd.DataFrame(
+                np.eye(len(discrete_target.cat.categories))[discrete_target.cat.codes],
+                columns=discrete_target.cat.categories,
+            )
+
+        if isinstance(one_hot_target, pd.DataFrame):
+            pass
+        elif isinstance(one_hot_target, np.ndarray):
+            one_hot_target = pd.DataFrame(
+                one_hot_target, columns=[f"process_{i}" for i in range(one_hot_target.shape[1])]
+            )
+        else:
+            raise ValueError("one_hot_target must be a pandas DataFrame or numpy array")
+
+        if dim_titles is None:
+            dim_titles = [f"dim_{d}" for d in range(embed.shape[1])]
+
+        self.embed = embed.copy()
+        self.one_hot_target = one_hot_target.copy()
+        self.dim_titles = dim_titles
+        self.metrics = metrics
+        self.aggregation_methods = aggregation_methods
+
+        self.results = {}
+        self.aggregated_results = {}
+
+    @staticmethod
+    def _compute_metrics(embed, one_hot_target, dim_titles=None, metrics=()):
+        if dim_titles is None:
+            dim_titles = [f"dim_{d}" for d in range(embed.shape[1])]
+
+        results = {}
+        for metric_name in metrics:
+            result_df = pd.DataFrame(
+                AVAILABLE_METRICS[metric_name](embed, gt_one_hot=one_hot_target.values),
+                index=dim_titles,
+                columns=one_hot_target.columns,
+            )
+            results[metric_name] = result_df
+
+        return results
+
+    @staticmethod
+    def _aggregate_metrics(results, aggregation_methods=()):
+        aggregated_results = {}
+        for aggregation_method in aggregation_methods:
+            for metric_name in results:
+                aggregated_results[f"{aggregation_method}-{metric_name}"] = AVAILABLE_AGGREGATION_METHODS[
+                    aggregation_method
+                ](results[metric_name].values)
+        return aggregated_results
+
+    def is_complete(self):
+        for metric in self.metrics:
+            if metric not in self.results:
+                return False
+        for aggregation_method in self.aggregation_methods:
+            for metric in self.metrics:
+                if f"{aggregation_method}-{metric}" not in self.aggregated_results:
+                    return False
+        return True
+
+    def evaluate(self):
+        if not self.is_complete():
+            remaining_metrics = [metric for metric in self.metrics if metric not in self.results]
+            self.results = {
+                **self.results,
+                **self._compute_metrics(self.embed, self.one_hot_target, self.dim_titles, remaining_metrics),
+            }
+        # Aggregation is cheap. Do it always.
+        self.aggregated_results = {
+            **self.aggregated_results,
+            **self._aggregate_metrics(self.results, self.aggregation_methods),
+        }
+
+    def get_results(self):
+        return {
+            f"{aggregation_method}-{metric}": self.aggregated_results[f"{aggregation_method}-{metric}"]
+            for aggregation_method in self.aggregation_methods
+            for metric in self.metrics
+        }
+
+    def get_results_details(self):
+        return {f"{metric}": self.results[metric] for metric in self.metrics}
+
+    def save(self, path):
+        data = {
+            "version": self.version,
+            "results": self.results,
+            "aggregated_results": self.aggregated_results,
+            "metrics": self.metrics,
+            "aggregation_methods": self.aggregation_methods,
+            "dim_titles": self.dim_titles,
+        }
+
+        with open(path, "wb") as f:
+            pickle.dump(data, f)
+
+    @classmethod
+    def load(cls, path, embed, discrete_target=None, one_hot_target=None, metrics=None, aggregation_methods=None):
+        with open(path, "rb") as f:
+            data = pickle.load(f)
+
+        assert cls.version == data["version"]
+        if metrics is None:
+            metrics = data["metrics"]
+        if aggregation_methods is None:
+            aggregation_methods = data["aggregation_methods"]
+        instance = cls(embed, discrete_target, one_hot_target, data["dim_titles"], metrics, aggregation_methods)
+        instance.results = data["results"]
+        instance.aggregated_results = data["aggregated_results"]
+        return instance
+
+    @classmethod
+    def load_results(cls, path):
+        with open(path, "rb") as f:
+            data = pickle.load(f)
+        return data["aggregated_results"]
+
+    @classmethod
+    def load_results_details(cls, path):
+        with open(path, "rb") as f:
+            data = pickle.load(f)
+        return data["results"]
```

src/drvi/utils/metrics/_pairwise.py

Lines changed: 46 additions & 28 deletions
```diff
@@ -4,53 +4,71 @@
 from sklearn.feature_selection import mutual_info_classif


-def _nn_alignment_score_per_dim(var_continues, ct_cat_series):
+def check_discrete_metric_input(gt_cat_series=None, gt_one_hot=None):
+    if gt_cat_series is not None and gt_one_hot is not None:
+        raise ValueError("Only one of gt_cat_series or gt_one_hot should be provided.")
+    if gt_cat_series is None and gt_one_hot is None:
+        raise ValueError("Either gt_cat_series or gt_one_hot must be provided.")
+
+
+def get_one_hot_encoding(gt_cat_series):
+    return np.eye(len(gt_cat_series.cat.categories))[gt_cat_series.cat.codes]
+
+
+def _nn_alignment_score_per_dim(var_continues, gt_01):
     order = var_continues.argsort()
-    ct_cat_series = ct_cat_series[order]
-    ct_01 = np.eye(len(ct_cat_series.cat.categories))[ct_cat_series.cat.codes]
+    gt_01 = gt_01[order]
     alignment = np.clip(
         (
-            np.sum(ct_01[:-1, :] * ct_01[1:, :], axis=0) / (np.sum(ct_01, axis=0) - 1)
+            np.sum(gt_01[:-1, :] * gt_01[1:, :], axis=0) / (np.sum(gt_01, axis=0) - 1)
         )  # fraction of cells of this type that are next to a cell of the same type
-        - (np.sum(ct_01, axis=0) / ct_01.shape[0]),  # cancel random neighbors when CT is frequent
+        - (np.sum(gt_01, axis=0) / gt_01.shape[0]),  # cancel random neighbors when GT (ground-truth) is frequent
         0,
         None,
-    ) / (1 - (np.sum(ct_01, axis=0) / ct_01.shape[0]))
+    ) / (1 - (np.sum(gt_01, axis=0) / gt_01.shape[0]))
     return alignment


-def _local_mutual_info_score_per_binary_ct(all_vars_continues, ct_binary):
-    mi_score = mutual_info_classif(all_vars_continues, ct_binary, n_jobs=-1)
-    ct_prob = np.sum(ct_binary == 1) / ct_binary.shape[0]
-    ct_entropy = stats.entropy([ct_prob, 1 - ct_prob])
-    return mi_score / ct_entropy
-
+def nn_alignment_score(all_vars_continues, gt_cat_series=None, gt_one_hot=None):
+    check_discrete_metric_input(gt_cat_series, gt_one_hot)
+    gt_01 = get_one_hot_encoding(gt_cat_series) if gt_cat_series is not None else gt_one_hot

-def nn_alignment_score(all_vars_continues, ct_cat_series):
     n_vars = all_vars_continues.shape[1]
-    result = np.zeros([n_vars, len(ct_cat_series.cat.categories)])
+    result = np.zeros([n_vars, gt_01.shape[1]])
     for i in range(n_vars):
-        result[i, :] = _nn_alignment_score_per_dim(all_vars_continues[:, i], ct_cat_series)
+        result[i, :] = _nn_alignment_score_per_dim(all_vars_continues[:, i], gt_01)
     return result


-def local_mutual_info_score(all_vars_continues, ct_cat_series):
+def _local_mutual_info_score_per_binary_gt(all_vars_continues, gt_binary):
+    mi_score = mutual_info_classif(all_vars_continues, gt_binary, n_jobs=-1)
+    gt_prob = np.sum(gt_binary == 1) / gt_binary.shape[0]
+    gt_entropy = stats.entropy([gt_prob, 1 - gt_prob])
+    return mi_score / gt_entropy
+
+
+def local_mutual_info_score(all_vars_continues, gt_cat_series=None, gt_one_hot=None):
+    check_discrete_metric_input(gt_cat_series, gt_one_hot)
+    gt_01 = get_one_hot_encoding(gt_cat_series) if gt_cat_series is not None else gt_one_hot
+
     n_vars = all_vars_continues.shape[1]
-    result = np.zeros([n_vars, len(ct_cat_series.cat.categories)])
-    ct_01 = np.eye(len(ct_cat_series.cat.categories))[ct_cat_series.cat.codes].T
-    for j in range(ct_01.shape[0]):
-        result[:, j] = _local_mutual_info_score_per_binary_ct(all_vars_continues, ct_01[j])
+    result = np.zeros([n_vars, gt_01.shape[1]])
+    for j in range(gt_01.shape[1]):
+        result[:, j] = _local_mutual_info_score_per_binary_gt(all_vars_continues, gt_01[:, j])
     return result


-def global_dim_mutual_info_score(all_vars_continues, ct_cat_series):
-    mi_score = mutual_info_classif(all_vars_continues, ct_cat_series)
-    ct_entropy = stats.entropy(pd.Series(ct_cat_series).value_counts(normalize=True, sort=False))
-    return mi_score / ct_entropy
-
+def spearman_correlataion_score(all_vars_continues, gt_cat_series=None, gt_one_hot=None):
+    check_discrete_metric_input(gt_cat_series, gt_one_hot)
+    gt_01 = get_one_hot_encoding(gt_cat_series) if gt_cat_series is not None else gt_one_hot

-def spearman_correlataion_score(all_vars_continues, ct_cat_series):
     n_vars = all_vars_continues.shape[1]
-    ct_01 = np.eye(len(ct_cat_series.cat.categories))[ct_cat_series.cat.codes]
-    result = np.abs(stats.spearmanr(all_vars_continues, ct_01).statistic[:n_vars, n_vars:])
+    result = np.abs(stats.spearmanr(all_vars_continues, gt_01).statistic[:n_vars, n_vars:])
     return result
+
+
+def global_dim_mutual_info_score(all_vars_continues, gt_cat_series):
+    # This metric is not used in any analysis, but is provided for completeness.
+    mi_score = mutual_info_classif(all_vars_continues, gt_cat_series)
+    gt_entropy = stats.entropy(pd.Series(gt_cat_series).value_counts(normalize=True, sort=False))
+    return mi_score / gt_entropy
```
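After this refactor, each pairwise metric accepts either a categorical Series (`gt_cat_series`) or a precomputed one-hot matrix (`gt_one_hot`), never both. A toy sketch of the two equivalent call styles; it imports from the private `_pairwise` module because the public re-export list above is shown only partially:

```python
import numpy as np
import pandas as pd

from drvi.utils.metrics._pairwise import nn_alignment_score

rng = np.random.default_rng(0)
embed = rng.normal(size=(500, 8))  # cells x latent dims
labels = pd.Series(rng.choice(["A", "B"], size=500)).astype("category")

# Both call styles yield the same (n_dims, n_categories) score matrix:
by_series = nn_alignment_score(embed, gt_cat_series=labels)
one_hot = np.eye(len(labels.cat.categories))[labels.cat.codes]
by_one_hot = nn_alignment_score(embed, gt_one_hot=one_hot)
assert np.allclose(by_series, by_one_hot)
```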

tests/utils/__init__.py

Whitespace-only changes.

tests/utils/metrics/__init__.py

Whitespace-only changes.
