-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathper_dataset_benchmark.py
More file actions
54 lines (47 loc) · 1.69 KB
/
Copy pathper_dataset_benchmark.py
File metadata and controls
54 lines (47 loc) · 1.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import json
import subprocess
import collections
import ast
import pandas as pd
# Validation set that is split per dataset and fed to SynRBL.
src_file = "./Data/Validation_set/validation_set.csv"


def split_by_dataset(records):
    """Group validation records by the first dataset each record lists.

    Args:
        records: Iterable of mappings. Each must carry a ``"datasets"`` entry
            holding the string form of a Python list literal, e.g.
            ``"['name_a', 'name_b']"``.

    Returns:
        A plain ``dict`` mapping a dataset name to the list of records whose
        ``datasets`` list starts with that name. Keys appear in order of
        first occurrence and records keep their input order.

    Raises:
        ValueError, SyntaxError: If a ``datasets`` value is not a valid
            Python literal.
        IndexError: If a ``datasets`` list is empty.
    """
    splits = collections.defaultdict(list)
    for entry in records:
        # The column stores a stringified list; literal_eval parses it
        # without executing arbitrary code.
        names = ast.literal_eval(entry["datasets"])
        # Only the first listed dataset decides which split a record joins.
        splits[names[0]].append(entry)
    # Return a plain dict so later lookups cannot silently create new keys.
    return dict(splits)


def build_run_command(data_file):
    """Return the ``synrbl run`` command line for one per-dataset CSV file."""
    return [
        "python3", "-m", "synrbl", "run",
        "--out-columns", "expected_reaction",
        "--cache",
        data_file,
    ]


def build_benchmark_command(dataset_name):
    """Return ``(command, benchmark_file)`` for benchmarking one dataset.

    The command reads ``valset_<dataset_name>_out.csv`` and is told to write
    its result to ``valset_<dataset_name>_benchmark.json``.
    NOTE(review): the ``_out.csv`` file is presumably produced by the earlier
    ``synrbl run`` step -- confirm against the SynRBL CLI.
    """
    benchmark_file = "valset_{}_benchmark.json".format(dataset_name)
    command = [
        "python3", "-m", "synrbl", "benchmark",
        "--target-col", "expected_reaction",
        "--min-confidence", "0",
        "-o", benchmark_file,
        "valset_{}_out.csv".format(dataset_name),
    ]
    return command, benchmark_file


def run_checked(command, failure_message):
    """Run *command* to completion and fail loudly on a non-zero exit code.

    Args:
        command: Argument list passed to ``subprocess.run`` (no shell).
        failure_message: Format string with one ``{}`` placeholder that
            receives the exit code.

    Raises:
        RuntimeError: If the process exits with a non-zero return code.
    """
    completed = subprocess.run(command)
    if completed.returncode != 0:
        raise RuntimeError(failure_message.format(completed.returncode))


def main():
    """Split the validation set per dataset, run SynRBL, collect benchmarks.

    Writes ``valset_<name>.csv`` for every dataset, runs ``synrbl run`` and
    ``synrbl benchmark`` on each, and merges the per-dataset benchmark JSON
    files into ``benchmark_result.json`` keyed by dataset name.

    Raises:
        RuntimeError: If any SynRBL invocation exits with a non-zero code.
    """
    dataset = pd.read_csv(src_file).to_dict("records")
    db_splits = split_by_dataset(dataset)

    # Write every per-dataset CSV before launching any SynRBL process.
    data_files = []
    for name, entries in db_splits.items():
        data_file = "valset_{}.csv".format(name)
        pd.DataFrame(entries).to_csv(data_file, index=False)
        data_files.append(data_file)

    for data_file in data_files:
        run_checked(
            build_run_command(data_file),
            "SynRBL returned with exit code {}",
        )

    # Key the result files by dataset name instead of zipping parallel lists.
    benchmark_files = {}
    for name in db_splits:
        command, benchmark_file = build_benchmark_command(name)
        run_checked(command, "SynRBL benchmark returned with exit code {}")
        benchmark_files[name] = benchmark_file

    benchmark_results = {}
    for name, benchmark_file in benchmark_files.items():
        with open(benchmark_file, "r", encoding="utf-8") as f:
            benchmark_results[name] = json.load(f)

    with open("benchmark_result.json", "w", encoding="utf-8") as f:
        json.dump(benchmark_results, f, indent=4)


if __name__ == "__main__":
    main()