Skip to content

Commit d1bccd0

Browse files
committed
Refactored benchmark grn dict
1 parent 825baf4 commit d1bccd0

6 files changed

Lines changed: 130 additions & 97 deletions

File tree

src/gretapy/_utils.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -114,6 +114,7 @@ def show_terms(organism: str | None = None) -> pd.DataFrame:
114114
assert organism in organisms, f"organism={organism} not available ({organisms})"
115115
df = df[df["organism"] == organism].drop(columns="organism")
116116
df = df.reset_index(drop=True)
117+
df["db_name"] = df["db_name"].str.replace("Human Protein Atlas (HPA)", "HPA", regex=False)
117118
return df
118119

119120

src/gretapy/config.py

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -112,11 +112,11 @@
112112
},
113113
# Mechanistic
114114
"KnockTF (scoring)": {
115-
"fname": "hg38_prt_knocktf.h5ad",
115+
"fname": "hg38_prt_knocktf.h5ad.gz",
116116
"metric": "TF Scoring",
117117
},
118118
"KnockTF (forecasting)": {
119-
"fname": "hg38_prt_knocktf.h5ad",
119+
"fname": "hg38_prt_knocktf.h5ad.gz",
120120
"metric": "Perturbation Forecasting",
121121
},
122122
"Boolean rules": {
@@ -2424,7 +2424,7 @@
24242424
"metric": None,
24252425
},
24262426
"DoRothEA": {
2427-
"fname": "hg38_gst_dorothea.csv.gz",
2427+
"fname": "mm10_gst_dorothea.csv.gz",
24282428
"metric": None,
24292429
},
24302430
# Literature
@@ -2446,19 +2446,19 @@
24462446
"metric": "TF Binding",
24472447
},
24482448
"ENCODE Blacklist": {
2449-
"fname": "hg38_cre_blacklist.bed.gz",
2449+
"fname": "mm10_cre_blacklist.bed.gz",
24502450
"metric": "CREs",
24512451
},
24522452
"ENCODE CREs": {
2453-
"fname": "hg38_cre_encode.bed.gz",
2453+
"fname": "mm10_cre_encode.bed.gz",
24542454
"metric": "CREs",
24552455
},
24562456
"phastCons": {
2457-
"fname": "hg38_cre_phastcons.bed.gz",
2457+
"fname": "mm10_cre_phastcons.bed.gz",
24582458
"metric": "CREs",
24592459
},
24602460
"Promoters": {
2461-
"fname": "hg38_cre_promoters.bed.gz",
2461+
"fname": "mm10_cre_promoters.bed.gz",
24622462
"metric": "CREs",
24632463
},
24642464
# Predictive
@@ -2488,11 +2488,11 @@
24882488
},
24892489
# Mechanistic
24902490
"KnockTF (scoring)": {
2491-
"fname": "m10_prt_knocktf.h5ad.gz",
2491+
"fname": "mm10_prt_knocktf.h5ad.gz",
24922492
"metric": "TF Scoring",
24932493
},
24942494
"KnockTF (forecasting)": {
2495-
"fname": "m10_prt_knocktf.h5ad.gz",
2495+
"fname": "mm10_prt_knocktf.h5ad.gz",
24962496
"metric": "Perturbation Forecasting",
24972497
},
24982498
"Boolean rules": {

src/gretapy/ds/_db.py

Lines changed: 14 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,7 @@
1+
import gzip
12
import os
23
import shutil
4+
import tempfile
35

46
import anndata as ad
57
import decoupler as dc
@@ -23,17 +25,22 @@ def _download_db(
2325
fname = DATA[organism]["dbs"][db_name]["fname"]
2426
path_fname = os.path.join(PATH_DATA, fname)
2527
if not os.path.isfile(path_fname):
26-
if fname != "hg38_prt_knocktf.h5ad":
27-
url = URL_STR + fname + URL_END
28-
data = _download(url, verbose=verbose)
29-
data.seek(0) # Move pointer to beginning
28+
url = URL_STR + fname + URL_END
29+
data = _download(url, verbose=verbose)
30+
data.seek(0)
31+
if not '.h5ad' in fname:
3032
with open(path_fname, "wb") as f:
3133
shutil.copyfileobj(data, f)
32-
m = f"Database {db_name} saved in {path_fname}"
33-
_log(m, level="info", verbose=verbose)
3434
else:
35-
adata = dc.ds.knocktf(thr_fc=100_000, verbose=verbose) # Do not filter here
35+
with tempfile.NamedTemporaryFile(suffix=".h5ad", delete=False) as tmp:
36+
tmp_path = tmp.name
37+
with gzip.GzipFile(fileobj=data) as gz:
38+
shutil.copyfileobj(gz, tmp)
39+
adata = ad.read_h5ad(tmp_path)
3640
adata.write(path_fname)
41+
os.remove(tmp_path)
42+
m = f"Database {db_name} saved in {path_fname}"
43+
_log(m, level="info", verbose=verbose)
3744
else:
3845
m = f"Database {db_name} found in {path_fname}"
3946
_log(m, level="info", verbose=verbose)

src/gretapy/pp/_check.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -261,6 +261,6 @@ def _check_terms(
261261
diff_terms = list(db_terms - og_db_terms)
262262
n_diff = len(diff_terms)
263263
assert n_diff == 0, (
264-
f"{n_diff} terms do not exist in db={db}: {diff_terms[:5]} View available options: gretapy.show_terms()"
264+
f"{n_diff} terms do not exist in organism={organism}, dataset={dataset}, db={db}:\n{diff_terms[:5]} View available options: gretapy.show_terms()"
265265
)
266266
return terms

src/gretapy/tl/_eval.py

Lines changed: 100 additions & 75 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,6 @@
1010
from gretapy.ds._db import read_db
1111
from gretapy.pp._check import (
1212
_check_dataset,
13-
_check_datasets,
1413
_check_dts_grn,
1514
_check_grn,
1615
_check_metrics,
@@ -48,8 +47,8 @@ def _format_label(grn_name: str | None = None, dataset_name: str | None = None)
4847

4948

5049
def benchmark(
51-
organism: str,
5250
grns: dict,
51+
organism: str | None = None,
5352
datasets: list | dict | None = None,
5453
terms: dict | None = None,
5554
metrics: str | list | None = None,
@@ -61,19 +60,20 @@ def benchmark(
6160
6261
Parameters
6362
----------
64-
organism
65-
Which organism to use (e.g., "hg38", "mm10").
6663
grns
67-
Dictionary mapping GRN names to per-dataset GRN DataFrames.
68-
Structure: ``{grn_name: {dataset_name: DataFrame}}``.
64+
Dictionary mapping GRN names to per-organism per-dataset GRN DataFrames.
65+
Structure: ``{grn_name: {organism: {dataset_name: DataFrame}}}``.
66+
organism
67+
Ignored when organism keys are present in ``grns``. Kept for clarity
68+
but organisms are inferred from the second level of ``grns``.
6969
datasets
7070
Dataset(s) to evaluate against. Can be:
71-
- None: Use all datasets available in config for the organism.
72-
- list: A list of dataset names from config.
73-
- dict: A dictionary mapping dataset names to pre-loaded MuData/AnnData objects.
71+
- None: Use all datasets present in the grns dict for each organism.
72+
- list: A whitelist of dataset names (applied across all organisms).
73+
- dict: A flat dictionary mapping dataset names to pre-loaded MuData/AnnData objects.
7474
terms
75-
Optional dictionary specifying filtering terms per dataset and metric.
76-
Structure: ``{dataset_name: {db_name: [terms]}}``.
75+
Optional dictionary specifying filtering terms per organism, dataset, and metric.
76+
Structure: ``{organism: {dataset_name: {db_name: [terms]}}}``.
7777
If None, terms are auto-loaded from config for each dataset.
7878
metrics
7979
Metric(s) to evaluate. Can be category name, metric type, or database name.
@@ -85,7 +85,7 @@ def benchmark(
8585
8686
Returns
8787
-------
88-
DataFrame with columns: grn, dataset, category, metric, db, precision, recall, f01.
88+
DataFrame with columns: grn, organism, dataset, class, task, db, precision, recall, f01.
8989
9090
Example
9191
-------
@@ -94,105 +94,130 @@ def benchmark(
9494
import gretapy as gt
9595
import pandas as pd
9696
97-
# Dataset-specific GRNs
97+
# Multi-organism GRNs
9898
grns = {
9999
"method_a": {
100-
"pbmc10k": pd.read_csv("grn_a_pbmc10k.csv"),
101-
"brain": pd.read_csv("grn_a_brain.csv"),
100+
"hg38": {
101+
"PBMC": pd.read_csv("grn_a_pbmc.csv"),
102+
"Lung": pd.read_csv("grn_a_lung.csv"),
103+
},
104+
"mm10": {
105+
"Palate": pd.read_csv("grn_a_palate.csv"),
106+
},
102107
},
103108
"method_b": {
104-
"pbmc10k": pd.read_csv("grn_b_pbmc10k.csv"),
109+
"hg38": {
110+
"PBMC": pd.read_csv("grn_b_pbmc.csv"),
111+
},
105112
},
106113
}
107-
results = gt.tl.benchmark(
108-
organism="hg38",
109-
grns=grns,
110-
datasets=None, # all datasets from config
111-
)
114+
results = gt.tl.benchmark(grns=grns)
112115
113116
# With pre-loaded datasets
114117
results = gt.tl.benchmark(
115-
organism="hg38",
116118
grns=grns,
117-
datasets={"pbmc10k": mudata_obj, "brain": mudata_obj2},
119+
datasets={"PBMC": mudata_obj, "Lung": mudata_obj2},
118120
)
119121
"""
120-
# Validate organism
121-
_check_organism(organism=organism)
122-
# Validate grns: must be dict[str, dict[str, pd.DataFrame]]
122+
# Validate grns: must be dict[str, dict[str, dict[str, pd.DataFrame]]]
123123
if not isinstance(grns, dict):
124-
raise ValueError(f"grns must be dict[str, dict[str, DataFrame]], got {type(grns)}")
124+
raise ValueError(f"grns must be dict[str, dict[str, dict[str, DataFrame]]], got {type(grns)}")
125125
for grn_name, grn_inner in grns.items():
126126
if not isinstance(grn_inner, dict):
127127
raise ValueError(
128-
f"grns['{grn_name}'] must be a dict mapping dataset names to DataFrames, got {type(grn_inner)}"
128+
f"grns['{grn_name}'] must be a dict mapping organism names to dicts, got {type(grn_inner)}"
129129
)
130+
for org_key, org_inner in grn_inner.items():
131+
if not isinstance(org_inner, dict):
132+
raise ValueError(
133+
f"grns['{grn_name}']['{org_key}'] must be a dict mapping dataset names to DataFrames, "
134+
f"got {type(org_inner)}"
135+
)
130136
grns_dict = grns
131-
# Validate and normalize datasets
132-
datasets_objects = None
133-
if datasets is None or isinstance(datasets, list):
134-
datasets_list = _check_datasets(organism=organism, datasets=datasets)
135-
elif isinstance(datasets, dict):
136-
datasets_list = list(datasets.keys())
137-
datasets_objects = datasets
138-
else:
137+
# Extract and validate organisms from grns
138+
organisms_in_grns = {org for inner in grns_dict.values() for org in inner}
139+
if not organisms_in_grns:
140+
raise ValueError("grns is empty or contains no organism keys. Provide at least one organism.")
141+
if organism is not None:
142+
_log(
143+
f"'organism' parameter ('{organism}') is ignored when organisms are encoded in the grns dict. "
144+
"Organisms are inferred from grns keys.",
145+
level="warning",
146+
verbose=verbose,
147+
)
148+
for org in organisms_in_grns:
149+
_check_organism(organism=org)
150+
# Validate datasets input type
151+
if not (datasets is None or isinstance(datasets, (list, dict))):
139152
raise ValueError(f"datasets must be None, list, or dict, got {type(datasets)}")
140-
# Validate metrics
141-
_check_metrics(organism=organism, metrics=metrics)
142-
# Run benchmark
143-
n_pairs = sum(1 for inner in grns_dict.values() for ds in datasets_list if ds in inner)
153+
datasets_objects = datasets if isinstance(datasets, dict) else None
154+
# Count pairs for logging
155+
n_pairs = sum(
156+
1
157+
for inner in grns_dict.values()
158+
for org_inner in inner.values()
159+
for ds in org_inner
160+
if datasets is None or (isinstance(datasets, list) and ds in datasets) or (isinstance(datasets, dict) and ds in datasets)
161+
)
144162
_log(_SEP, level="info", verbose=verbose)
145163
_log(
146-
f"Starting benchmark: {len(grns_dict)} GRN(s), {len(datasets_list)} dataset(s), {n_pairs} pair(s)",
164+
f"Starting benchmark: {len(grns_dict)} GRN(s), {len(organisms_in_grns)} organism(s), {n_pairs} pair(s)",
147165
level="info",
148166
verbose=verbose,
149167
)
150168
_log(_SEP, level="info", verbose=verbose)
151169
t_start_bench = time.time()
152170
all_results = []
153171
for grn_name, grn_inner in grns_dict.items():
154-
for dataset_name in datasets_list:
155-
if dataset_name not in grn_inner:
156-
continue # skip silently
157-
grn_df = grn_inner[dataset_name]
158-
# Resolve dataset: string name or pre-loaded object
159-
dataset_arg = datasets_objects[dataset_name] if datasets_objects else dataset_name
160-
# Resolve terms before eval
161-
if terms is None:
162-
dataset_terms = _check_terms(organism=organism, dataset=dataset_name, terms=None)
172+
for org, org_inner in grn_inner.items():
173+
# Determine dataset list for this organism
174+
if datasets is None:
175+
ds_list = list(org_inner.keys())
163176
else:
164-
dataset_terms = terms.get(dataset_name, {})
165-
# Warn if no auto-loaded terms for pre-loaded datasets not in config
166-
if terms is None and datasets_objects is not None and not dataset_terms:
167-
_log(
168-
f"No terms auto-loaded for dataset '{dataset_name}' (not in config). "
169-
"Metrics requiring terms will run unfiltered.",
170-
level="warning",
177+
ds_list = [d for d in (datasets if isinstance(datasets, list) else datasets.keys()) if d in org_inner]
178+
# Validate metrics per organism
179+
_check_metrics(organism=org, metrics=metrics)
180+
for dataset_name in ds_list:
181+
grn_df = org_inner[dataset_name]
182+
# Resolve dataset: string name or pre-loaded object
183+
dataset_arg = datasets_objects[dataset_name] if datasets_objects else dataset_name
184+
# Resolve terms before eval (new 3-level structure)
185+
if terms is None:
186+
dataset_terms = _check_terms(organism=org, dataset=dataset_name, terms=None)
187+
else:
188+
dataset_terms = terms.get(org, {}).get(dataset_name, {})
189+
# Warn if no auto-loaded terms for pre-loaded datasets not in config
190+
if terms is None and datasets_objects is not None and not dataset_terms:
191+
_log(
192+
f"No terms auto-loaded for dataset '{dataset_name}' (not in config). "
193+
"Metrics requiring terms will run unfiltered.",
194+
level="warning",
195+
verbose=verbose,
196+
)
197+
# Run evaluation
198+
result = eval_grn_dataset(
199+
organism=org,
200+
grn=grn_df,
201+
dataset=dataset_arg,
202+
terms=dataset_terms,
203+
metrics=metrics,
204+
min_edges=min_edges,
205+
grn_name=grn_name,
206+
dataset_name=dataset_name,
171207
verbose=verbose,
172208
)
173-
# Run evaluation
174-
result = eval_grn_dataset(
175-
organism=organism,
176-
grn=grn_df,
177-
dataset=dataset_arg,
178-
terms=dataset_terms,
179-
metrics=metrics,
180-
min_edges=min_edges,
181-
grn_name=grn_name,
182-
dataset_name=dataset_name,
183-
verbose=verbose,
184-
)
185-
# Add identifiers
186-
if not result.empty:
187-
result.insert(0, "grn", grn_name)
188-
result.insert(1, "dataset", dataset_name)
189-
all_results.append(result)
209+
# Add identifiers
210+
if not result.empty:
211+
result.insert(0, "grn", grn_name)
212+
result.insert(1, "organism", org)
213+
result.insert(2, "dataset", dataset_name)
214+
all_results.append(result)
190215
elapsed = time.time() - t_start_bench
191216
_log(_SEP, level="info", verbose=verbose)
192217
_log(f"Benchmark complete ({len(all_results)} result(s), {elapsed:.1f}s)", level="info", verbose=verbose)
193218
_log(_SEP, level="info", verbose=verbose)
194219
if not all_results:
195-
return pd.DataFrame(columns=["grn", "dataset", "class", "task", "db", "precision", "recall", "f01"])
220+
return pd.DataFrame(columns=["grn", "organism", "dataset", "class", "task", "db", "precision", "recall", "f01"])
196221
return pd.concat(all_results, ignore_index=True)
197222

198223

0 commit comments

Comments
 (0)