forked from PaddlePaddle/PaddleHelix
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
114 lines (92 loc) · 3.11 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import random
import argparse
import time
from datetime import datetime
from tqdm import tqdm
import paddle
paddle.disable_static()
import paddle.nn.functional as F
import numpy as np
from model import ProteinSIGN
from dataset import GoTermDataset, GoTermDataLoader
from custom_metrics import do_compute_metrics
from utils import add_saved_args_and_params
def do_compute(model, batch):
    """Run ``model`` on one batch and pair its output with the batch labels.

    Every element of ``batch`` except the last is passed positionally to
    ``model``; the last element is returned untouched as the labels.

    Returns:
        A ``(model_output, labels)`` tuple.
    """
    model_inputs = batch[:-1]
    model_output = model(*model_inputs)
    labels = batch[-1]
    return model_output, labels
def run_batch(model, data_loader, desc):
    """Run ``model`` over every batch of ``data_loader`` and score the output.

    Args:
        model: Callable handed to ``do_compute`` together with each batch.
        data_loader: Iterable of batches in the layout ``do_compute`` expects
            (inputs first, labels last).
        desc: Text shown as the tqdm progress-bar description.

    Returns:
        Whatever ``do_compute_metrics(ground_truth, probabilities)`` returns.

    Raises:
        ValueError: If ``data_loader`` yields no batches, since there is then
            nothing to concatenate or score.
    """
    probabilities = []
    ground_truth = []
    for batch in tqdm(data_loader, desc=f"{desc}"):
        logits, labels = do_compute(model, batch)
        # Store sigmoid-transformed outputs, not the raw logits.
        probabilities.append(F.sigmoid(logits).tolist())
        ground_truth.append(labels.tolist())

    # np.concatenate rejects an empty sequence with an opaque message; fail
    # early with one that names the actual problem (same exception type).
    if not probabilities:
        raise ValueError(
            f"run_batch({desc!r}): data loader produced no batches to evaluate"
        )

    # Per-batch results are joined along the first axis into one array each.
    probabilities = np.concatenate(probabilities)
    ground_truth = np.concatenate(ground_truth)
    return do_compute_metrics(ground_truth, probabilities)
def test(model, test_data_loader):
    """Evaluate ``model`` on ``test_data_loader`` and print the metrics.

    The model is put into eval mode and the pass runs inside
    ``paddle.no_grad()``. The first two values returned by ``run_batch`` are
    printed as ``f_max`` and ``auprc``, each with four decimal places.
    """
    model.eval()
    with paddle.no_grad():
        scores = run_batch(model, test_data_loader, "test")

    print("#### Test results")
    report_template = "f_max: {0:.4f}, auprc: {1:.4f}"
    print(report_template.format(*scores))
if __name__ == "__main__":
    # Script entry point: parse options, restore a saved ProteinSIGN model and
    # report its metrics on the chains listed in --test_file.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cuda",
        type=str,
        default="0",
        help="GPU device index to run on, or -1 to run on CPU",
    )
    parser.add_argument(
        "--test_file",
        type=str,
        default="./data/nrPDB-GO_2019.06.18_test.txt",
        help="File containing test protein chains, one per line",
    )
    parser.add_argument(
        "--protein_chain_graphs",
        type=str,
        default="./data/chain_graphs",
        help="Path to graph representations of proteins",
    )
    parser.add_argument(
        "--label_data_path",
        type=str,
        required=True,
        help="Mapping containing protein chains with associated labels",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        required=True,
        help="Path to the saved model checkpoint to evaluate",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Number of samples per evaluation batch",
    )
    args = parser.parse_args()

    args.activation = F.relu

    # The task name is the label file's base name without its extension.
    label_file_name = os.path.split(args.label_data_path)[-1]
    args.task = os.path.splitext(label_file_name)[0]

    if int(args.cuda) == -1:
        paddle.set_device("cpu")
    else:
        paddle.set_device(f"gpu:{args.cuda}")

    # Close the file deterministically and skip blank lines so that a trailing
    # newline does not turn into an empty chain identifier.
    with open(args.test_file) as chain_file:
        stripped_lines = (line.strip() for line in chain_file)
        test_chain_list = [chain for chain in stripped_lines if chain]

    saved_state_dict = paddle.load(args.model_name)
    # In-place assignment: mutates ``args``. Presumably this is where
    # num_angle, n_channels and cmap_thresh come from, since they are read
    # below but are not CLI options of this script -- confirm in utils.
    add_saved_args_and_params(args, saved_state_dict)

    test_dataset = GoTermDataset(
        test_chain_list,
        args.num_angle,
        args.n_channels,
        args.protein_chain_graphs,
        args.cmap_thresh,
        args.label_data_path,
    )
    test_loader = GoTermDataLoader(test_dataset, batch_size=args.batch_size)

    # The label count is taken from the dataset before the model is built.
    args.n_labels = test_dataset.n_labels
    model = ProteinSIGN(args)
    model.set_state_dict(saved_state_dict["model"])
    model.eval()

    print(f"\n{args.task}: Testing on {len(test_dataset)} protein samples.")
    print(f"Starting at {datetime.now()}\n")
    print(args)

    test(model, test_loader)