-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathevaluation.py
More file actions
80 lines (59 loc) · 2.42 KB
/
Copy pathevaluation.py
File metadata and controls
80 lines (59 loc) · 2.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# !/usr/bin/env python3
"""
Evaluation code for Quora paraphrase detection and sonnet generation.
model_eval_paraphrase is suitable for the dev (and train) dataloaders where the label information is available.
model_test_paraphrase is suitable for the test dataloader where label information is not available.
test_sonnet scores generated sonnets against held-out gold sonnets using the chrF metric.
"""
import torch
from sklearn.metrics import f1_score, accuracy_score
from tqdm import tqdm
import numpy as np
from sacrebleu.metrics import CHRF
from datasets import (
SonnetsDataset,
)
# Passed as tqdm's `disable` argument by the evaluation loops in this module;
# set to True to turn their progress bars off.
TQDM_DISABLE = False
@torch.no_grad()
def model_eval_paraphrase(dataloader, model, device):
    """Evaluate a paraphrase classifier on a dataloader that provides labels.

    Args:
        dataloader: Iterable of batches. Each batch is a mapping with the keys
            'token_ids', 'attention_mask', 'sent_ids' and 'labels'.
        model: Called as ``model(token_ids, attention_mask)``; its output is
            moved to CPU and arg-maxed over axis 1 to get the predicted class.
            The model is switched to eval mode.
        device: Device the token ids and attention mask are moved to before
            the forward pass.

    Returns:
        A tuple ``(acc, f1, y_pred, y_true, sent_ids)``: accuracy, macro-averaged
        F1, the predicted class indices, the gold labels and the sentence ids,
        the last three accumulated over all batches in iteration order.
    """
    model.eval()  # Switch to eval mode; turns off randomness such as dropout.

    y_true, y_pred, sent_ids = [], [], []
    for batch in tqdm(dataloader, desc='eval', disable=TQDM_DISABLE):
        b_ids = batch['token_ids']
        b_mask = batch['attention_mask']
        b_sent_ids = batch['sent_ids']
        labels = batch['labels'].flatten()

        b_ids = b_ids.to(device)
        b_mask = b_mask.to(device)

        logits = model(b_ids, b_mask).cpu().numpy()
        preds = np.argmax(logits, axis=1).flatten()

        # Store the labels as plain Python scalars. Extending the list with the
        # tensor itself would append one 0-d tensor object per label, which is
        # slow for sklearn to convert and breaks if the labels are not on CPU.
        y_true.extend(labels.tolist())
        y_pred.extend(preds)
        sent_ids.extend(b_sent_ids)

    f1 = f1_score(y_true, y_pred, average='macro')
    acc = accuracy_score(y_true, y_pred)

    return acc, f1, y_pred, y_true, sent_ids
@torch.no_grad()
def model_test_paraphrase(dataloader, model, device):
    """Predict paraphrase classes for a dataloader that has no labels.

    Args:
        dataloader: Iterable of batches. Each batch is a mapping with the keys
            'token_ids', 'attention_mask' and 'sent_ids'.
        model: Called as ``model(token_ids, attention_mask)``; its output is
            moved to CPU and arg-maxed over axis 1 to get the predicted class.
            The model is switched to eval mode.
        device: Device the token ids and attention mask are moved to before
            the forward pass.

    Returns:
        A tuple ``(y_pred, sent_ids)``: the predicted class indices and the
        matching sentence ids, accumulated over all batches in iteration order.
    """
    model.eval()  # Switch to eval mode; turns off randomness such as dropout.

    y_pred, sent_ids = [], []
    for batch in tqdm(dataloader, desc='eval', disable=TQDM_DISABLE):
        b_ids = batch['token_ids']
        b_mask = batch['attention_mask']
        b_sent_ids = batch['sent_ids']

        b_ids = b_ids.to(device)
        b_mask = b_mask.to(device)

        logits = model(b_ids, b_mask).cpu().numpy()
        preds = np.argmax(logits, axis=1).flatten()

        y_pred.extend(preds)
        sent_ids.extend(b_sent_ids)

    return y_pred, sent_ids
def test_sonnet(
    test_path='predictions/generated_sonnets.txt',
    gold_path='data/TRUE_sonnets_held_out.txt'
):
    """Score generated sonnets against gold sonnets with corpus-level chrF.

    Args:
        test_path: Path of the generated sonnets, loaded through SonnetsDataset.
        gold_path: Path of the reference sonnets, loaded through SonnetsDataset.

    Returns:
        The corpus-level chrF score as a float.
    """
    scorer = CHRF()

    # Element 1 of each dataset item is used as the sonnet text
    # (presumably element 0 is an id; verify against SonnetsDataset).
    hypotheses = [item[1] for item in SonnetsDataset(test_path)]
    references = [item[1] for item in SonnetsDataset(gold_path)]

    # The two files may hold different numbers of sonnets; only the leading
    # entries present in both are compared.
    shared = min(len(references), len(hypotheses))
    references = references[:shared]
    hypotheses = hypotheses[:shared]

    # corpus_score takes the hypotheses and a list of reference streams;
    # there is a single reference stream here.
    result = scorer.corpus_score(hypotheses, [references])
    return float(result.score)