Skip to content

Commit 6aed847

Browse files
committed
Add make_factor_obs and bugfixes
1 parent 0483f19 commit 6aed847

File tree

5 files changed

+58
-23
lines changed

5 files changed

+58
-23
lines changed

src/scdef/models/_scdef.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1556,7 +1556,8 @@ def filter_factors(
15561556
self,
15571557
thres: Optional[float] = 0.0,
15581558
iqr_mult: Optional[float] = 0.0,
1559-
min_cells: Optional[float] = 0.001,
1559+
min_cells_upper: Optional[float] = 0.001,
1560+
min_cells_lower: Optional[float] = 0.0,
15601561
filter_up: Optional[bool] = True,
15611562
normalized: Optional[bool] = False,
15621563
):
@@ -1568,17 +1569,20 @@ def filter_factors(
15681569
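min_cells_upper, min_cells_lower: minimum number of cells that each factor must have attached to it for it to be kept; min_cells_lower is applied to factors in the first layer (layer 0) and min_cells_upper to factors in all other layers. If between 0 and 1, fraction of the number of cells (with a floor of 10 cells). Otherwise, absolute value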
15691570
filter_up: whether to remove factors in upper layers via inter-layer attachments
15701571
"""
1571-
if min_cells != 0:
1572-
if min_cells < 1.0:
1573-
min_cells = max(min_cells * self.adata.shape[0], 10)
1572+
if min_cells_upper != 0:
1573+
if min_cells_upper < 1.0:
1574+
min_cells_upper = max(min_cells_upper * self.adata.shape[0], 10)
1575+
if min_cells_lower != 0:
1576+
if min_cells_lower < 1.0:
1577+
min_cells_lower = max(min_cells_lower * self.adata.shape[0], 10)
15741578

15751579
self.factor_lists = []
15761580
for i, layer_name in enumerate(self.layer_names):
15771581
if i == 0:
15781582
keep = self.get_effective_factors(
15791583
thres=thres,
15801584
iqr_mult=iqr_mult,
1581-
min_cells=min_cells,
1585+
min_cells=min_cells_lower,
15821586
normalized=normalized,
15831587
)
15841588
else:
@@ -1590,7 +1594,7 @@ def filter_factors(
15901594
]
15911595
)
15921596
keep = np.array(range(self.layer_sizes[i]))[
1593-
np.where(counts >= min_cells)[0]
1597+
np.where(counts >= min_cells_upper)[0]
15941598
]
15951599
if filter_up:
15961600
mat = self.pmeans[f"{layer_name}W"][keep]

src/scdef/plotting/graph.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -939,22 +939,22 @@ def plot_biological_hierarchy(model, **kwargs):
939939
return g
940940

941941

942-
def plot_technical_hierarchy(model, **kwargs):
942+
def plot_technical_hierarchy(model, show_signatures=True, **kwargs):
943943
technical_signature = None
944944
technical_scores = None
945-
if "show_signatures" in kwargs:
946-
if kwargs["show_signatures"]:
947-
technical_signature, technical_scores = get_technical_signature(
948-
model,
949-
return_scores=True,
950-
top_genes=None,
951-
)
945+
if show_signatures:
946+
technical_signature, technical_scores = get_technical_signature(
947+
model,
948+
return_scores=True,
949+
top_genes=None,
950+
)
952951
g = make_technical_hierarchy_graph(
953952
model,
954953
hierarchy=model.adata.uns["technical_hierarchy"],
955954
root_gene_rankings=[technical_signature],
956955
root_gene_scores=[technical_scores],
957956
root_name="tech_top",
957+
show_signatures=show_signatures,
958958
**kwargs,
959959
)
960960
return g

src/scdef/tools/factor.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,23 @@
11
import numpy as np
22
import pandas as pd
3+
from .hierarchy import get_hierarchy, compute_hierarchy_scores
4+
5+
6+
def make_factor_obs(model):
7+
res = compute_hierarchy_scores(model)
8+
model.adata.uns["factor_obs"] = res["per_factor"].set_index("child_factor")
9+
model.adata.uns["factor_obs"]["ARD"] = np.array(
10+
[np.nan] * len(model.adata.uns["factor_obs"])
11+
)
12+
model.adata.uns["factor_obs"]["BRD"] = np.array(
13+
[np.nan] * len(model.adata.uns["factor_obs"])
14+
)
15+
model.adata.uns["factor_obs"].loc[model.factor_names[0], "ARD"] = np.asarray(
16+
model.pmeans["factor_means"]
17+
)[model.factor_lists[0]].ravel()
18+
model.adata.uns["factor_obs"].loc[model.factor_names[0], "BRD"] = np.asarray(
19+
model.pmeans["factor_concentrations"]
20+
)[model.factor_lists[0]].ravel()
321

422

523
def set_factor_signatures(model, signatures=None, top_genes=10):
@@ -10,16 +28,30 @@ def set_factor_signatures(model, signatures=None, top_genes=10):
1028

1129

1230
def set_technical_factors(model, factors=None):
13-
"""Set the technical factors of the model."""
31+
"""Set the technical factors of the model. They must be layer 0 factors."""
1432
# in model.adata.uns["factor_obs"], annotate as technical or not.
33+
all_factor_names = [name for names in model.factor_names for name in names][
34+
:-1
35+
] # do not use root
1536
if "factor_obs" not in model.adata.uns:
1637
# factor_obs has not been created yet: build it with make_factor_obs
17-
all_factor_names = [name for names in model.factor_names for name in names]
18-
model.adata.uns["factor_obs"] = pd.DataFrame(index=all_factor_names)
38+
make_factor_obs(model)
1939
model.adata.uns["factor_obs"]["technical"] = np.array(
2040
[False] * len(all_factor_names)
2141
)
22-
model.adata.uns["factor_obs"]["technical"].loc[factors] = True
42+
model.adata.uns["factor_obs"].loc[factors, "technical"] = True
43+
44+
# Get complete hierarchy
45+
complete_hierarchy = get_hierarchy(model, simplified=False)
46+
# Traverse hierarchy. If all the children of a factor are technical, set the factor as technical.
47+
for factor, children in complete_hierarchy.items():
48+
if all(
49+
[
50+
model.adata.uns["factor_obs"].loc[child, "technical"]
51+
for child in children
52+
]
53+
):
54+
model.adata.uns["factor_obs"].loc[factor, "technical"] = True
2355

2456

2557
def __build_consensus_signature(var_names, gene_scores_array, sizes_array):

src/scdef/tools/hierarchy.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,6 @@ def _usage_weights(child_layer_idx, child_factor_indices):
159159
np.sum(scores_concat * weights_concat) / (np.sum(weights_concat) + eps)
160160
)
161161

162-
if "factor_obs" not in model.adata.uns:
163-
model.adata.uns["factor_obs"] = per_factor.set_index("child_factor")
164-
165162
return {
166163
"per_factor": per_factor,
167164
"per_transition": per_transition,
@@ -185,7 +182,9 @@ def make_technical_hierarchy(model):
185182
"""Make the technical hierarchy of the model."""
186183
technical_factors = model.adata.uns["factor_obs"][
187184
model.adata.uns["factor_obs"]["technical"]
188-
].index.tolist() # layer 0
185+
][
186+
model.adata.uns["factor_obs"]["child_layer"] == "L0"
187+
].index.tolist() # only layer 0 factors
189188
# technical hierarchy is a root with all technical factors as direct children.
190189
# connection weights are proportional to the usage of each factor
191190
technical_hierarchy = dict()

tests/test_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def test_scdef():
9797

9898
model.fit(n_epoch=3)
9999

100-
model.filter_factors(thres=0.0, min_cells=0) # make sure we keep factors
100+
model.filter_factors(thres=0.0, min_cells_lower=0) # make sure we keep factors
101101

102102
model.logger.info(model.factor_lists)
103103

0 commit comments

Comments
 (0)