Skip to content

Commit 00b288a

Browse files
authored
Merge pull request #302 from robinfallegger/patch-3
Fix index retention in _tensor_truth function
2 parents 6118a32 + 54b4453 commit 00b288a

2 files changed

Lines changed: 4 additions & 2 deletions

File tree

src/decoupler/bm/_run.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ def _tensor_scores(
5050

5151
def _tensor_truth(obs: pd.DataFrame, srcs: np.ndarray) -> pd.DataFrame:
5252
# Explode nested perturbs and pivot into mat
53-
grts = obs.explode("source").pivot(columns="source", values="type_p").notna().astype(float).fillna(0.0)
53+
grts = (
54+
obs.explode("source").pivot(columns="source", values="type_p").notna().astype(float).fillna(0.0).loc[obs.index]
55+
)
5456
miss_srcs = srcs[~np.isin(srcs, grts.columns)]
5557
miss_srcs = pd.DataFrame(0, index=grts.index, columns=miss_srcs)
5658
grts = pd.concat([grts, miss_srcs], axis=1)

tests/bm/test_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
[["auc"], None, "expr", False, 0.05, 5, False],
1414
[["auc", "fscore"], "group", "expr", False, 0.05, 5, False],
1515
[["auc", "fscore", "qrank"], None, "source", False, 0.05, 2, False],
16-
[["auc", "fscore", "qrank"], "group", "source", False, 0.05, 1, False],
16+
[["auc", "fscore", "qrank"], "class", "source", False, 0.05, 1, False],
1717
[["auc", "fscore", "qrank"], "bm_group", "expr", True, 0.05, 5, False],
1818
[["auc", "fscore", "qrank"], "source", "expr", True, 0.05, 5, False],
1919
],

0 commit comments

Comments
 (0)