-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
25 lines (22 loc) · 997 Bytes
/
test.py
File metadata and controls
25 lines (22 loc) · 997 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from sentence_transformers import SentenceTransformer, util
# Sentence-BERT checkpoint used by filter_gt_sbert. The encoder is built once,
# at import time, and shared by every call.
# NOTE(review): loading by name may fetch weights if they are not cached locally.
_MODEL_NAME = 'all-MiniLM-L6-v2'
model = SentenceTransformer(_MODEL_NAME)
def filter_gt_sbert(generated_text, topic):
    """Keep the half of ``generated_text`` that is most similar to ``topic``.

    Every candidate is embedded with the module-level SBERT ``model`` and
    scored by cosine similarity against the topic embedding. The
    ``ceil(n / 2)`` highest-scoring candidates are returned, best first.
    Ties keep their input order because ``sorted`` is stable.

    Args:
        generated_text: Sequence of candidate strings.
        topic: Either a single topic string shared by all candidates, or a
            sequence of topic strings. The sequence may hold one string
            (broadcast to every candidate) or exactly one string per
            candidate (candidate ``i`` is scored against ``topic[i]``).

    Returns:
        A new list with the top ``ceil(len(generated_text) / 2)`` strings,
        ordered by descending similarity. Empty input yields ``[]``.

    Raises:
        ValueError: If ``topic`` is a sequence whose length is neither 1 nor
            ``len(generated_text)``.
    """
    texts = list(generated_text)
    num = len(texts)
    if num == 0:
        return []

    # Accept a bare string as well as a list. Wrapping a list again
    # (``[topic] * num``) would hand the encoder nested lists.
    topics = [topic] if isinstance(topic, str) else list(topic)

    text_embeddings = model.encode(texts, convert_to_tensor=True)
    if len(topics) == 1:
        # Shared topic: encode it once. This is an (n, 1) similarity matrix
        # rather than an (n, n) one built from n copies of the same topic.
        topic_embeddings = model.encode(topics, convert_to_tensor=True)
        scores = util.pytorch_cos_sim(text_embeddings, topic_embeddings)[:, 0]
    elif len(topics) == num:
        # Per-candidate topics: candidate i is compared only with topic i,
        # i.e. the diagonal of the pairwise similarity matrix.
        topic_embeddings = model.encode(topics, convert_to_tensor=True)
        scores = util.pytorch_cos_sim(text_embeddings, topic_embeddings).diagonal()
    else:
        raise ValueError(
            f"topic must be a string or a sequence of 1 or {num} strings, "
            f"got {len(topics)}"
        )

    # Convert to plain floats so ranking compares numbers, not 0-d tensors.
    score_values = scores.tolist()
    order = sorted(range(num), key=lambda i: score_values[i], reverse=True)
    keep = (num + 1) // 2  # same cutoff as ``count < num / 2``: ceil(num / 2)
    return [texts[i] for i in order[:keep]]
if __name__ == "__main__":
    # Demo: keep the half of the sentences closest to a single topic string.
    print(filter_gt_sbert(
        ['The cat is perfect.', 'Columbia got bombing threat this morning.',
         'Black lives matter.', 'I love you'],
        'Bombing attack',
    ))
    # Output recorded from an earlier run that passed the topic as
    # ['Bombing attack'] * 4; re-run to confirm for a plain string topic:
    # ['Columbia got bombing threat this morning.', 'I love you']