-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtest-models.py
More file actions
57 lines (51 loc) · 1.91 KB
/
Copy pathtest-models.py
File metadata and controls
57 lines (51 loc) · 1.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
"""
COMMAND LINE ARGUMENTS -
1. Model family
2. Model name (or path for a saved model)
3. Path to the directory where predictions should be written
4. Test Dataset (optional)
"""
import sys
import os
from numpy import argmax
import pandas as pd
from sklearn.metrics import accuracy_score
from simpletransformers.classification import ClassificationModel
# Ensure the output directory given as the 3rd command-line argument exists.
# exist_ok=True makes this a single call and avoids the race between a
# separate os.path.exists() check and the directory creation.
os.makedirs(sys.argv[3], exist_ok=True)
# Build the binary (num_labels=2) classifier to evaluate. The model family and
# the model name/path come from the 1st and 2nd command-line arguments, and the
# model is placed on CUDA device 0.
model = ClassificationModel(
    sys.argv[1],
    sys.argv[2],
    num_labels=2,
    use_cuda=True,
    cuda_device=0,
    args={
        "n_gpu": 1,
        # simpletransformers documents this option as `output_dir`; the
        # previous `op_dir` key is not a documented option, so the directory
        # from the 3rd command-line argument was never used for outputs.
        "output_dir": sys.argv[3],
        # NOTE(review): presumably forces features to be rebuilt from the
        # input text rather than loaded from a cache - confirm against the
        # installed simpletransformers version.
        "reprocess_input_data": True,
    },
)
MSRP_TRAIN_PATH = "/raid/datasets/msrp/msr_paraphrase_train.txt"
MSRP_TEST_PATH = "/raid/datasets/msrp/msr_paraphrase_test.txt"


def load_msrp(path):
    """Load a tab-separated MSRP file into a sentence-pair DataFrame.

    The first line of the file is treated as a header and skipped. For every
    remaining non-blank line, column 0 is parsed as the integer label and
    columns 3 and 4 are taken as the two sentences.

    Args:
        path: Path to the tab-separated MSRP file.

    Returns:
        A DataFrame with columns ``text_a``, ``text_b`` and ``labels``, one
        row per data line, indexed 0..n-1.

    Raises:
        IndexError: If a non-blank data line has fewer than five fields.
        ValueError: If the label field is not an integer.
    """
    rows = []
    with open(path, "r", encoding="utf-8") as handle:
        next(handle, None)  # skip the header line
        for line in handle:
            stripped = line.strip()
            if not stripped:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            fields = stripped.split("\t")
            rows.append((fields[3], fields[4], int(fields[0])))
    # Build the frame once; appending row by row with .loc re-allocates on
    # every insertion.
    return pd.DataFrame(rows, columns=["text_a", "text_b", "labels"])


def swap_sentence_order(df):
    """Return a copy of *df* with the contents of ``text_a`` and ``text_b`` exchanged.

    The column order (``text_a``, ``text_b``, ``labels``) is preserved and the
    input frame is not modified.
    """
    swapped = df.rename(columns={"text_a": "text_b", "text_b": "text_a"})
    return swapped[["text_a", "text_b", "labels"]]


def evaluate_both_orders(model, df):
    """Evaluate *model* on *df* in the given and in the swapped sentence order.

    Prints the dataset shape, the shape of the rows whose label is greater
    than 0, and the result of each evaluation run.

    Args:
        model: Object exposing ``eval_model(df, acc=...)`` that returns a
            3-tuple whose first element is the result.
        df: DataFrame with ``text_a``, ``text_b`` and ``labels`` columns.

    Returns:
        A ``(result, result_rev)`` tuple for the original and swapped order.
    """
    print(df.shape)
    print(df[df.labels > 0].shape)
    result, _, _ = model.eval_model(df, acc=accuracy_score)
    print(result)
    result_rev, _, _ = model.eval_model(swap_sentence_order(df), acc=accuracy_score)
    print(result_rev)
    return result, result_rev


# Guarded so the helpers above can be imported without triggering dataset
# I/O; when the file is run as a script the behaviour is unchanged: the train
# file is evaluated first, then the test file.
if __name__ == "__main__":
    for dataset_path in (MSRP_TRAIN_PATH, MSRP_TEST_PATH):
        evaluate_both_orders(model, load_msrp(dataset_path))