paraphrase.py (forked from martiansideofthemoon/ai-detection-paraphrases)
import argparse
import json
import nltk
import time
import os
import tqdm
from nltk.tokenize import sent_tokenize
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
from utils import form_partitions
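
# sent_tokenize below needs NLTK's "punkt" sentence tokenizer models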
nltk.download('punkt')

parser = argparse.ArgumentParser()
parser.add_argument('--output_file', default="lfqa-data/gpt2_xl_strength_2.0_frac_0.5_300_len_top_p_0.9.jsonl",
                    help='JSONL file with the generations to paraphrase')
parser.add_argument('--model', type=str, default='kalpeshk2011/dipper-paraphraser-xxl',
                    help='paraphraser checkpoint to load')
parser.add_argument('--refresh', action='store_true',
                    help='re-paraphrase points that already have paraphrase_outputs')
parser.add_argument('--num_shards', type=int, default=1,
                    help='number of shards to split the data into')
parser.add_argument('--local_rank', type=int, default=0,
                    help='index of the shard processed by this worker')
parser.add_argument('--sent_interval', type=int, default=3,
                    help='number of sentences paraphrased per model call')
args = parser.parse_args()

# read args.output_file
with open(args.output_file, "r") as f:
    data = [json.loads(x) for x in f.read().strip().split("\n")]
if "no-ctx" in args.model:
args.output_file = args.output_file + "_no_ctx"
if args.sent_interval == 1:
args.output_file = args.output_file + "_sent"
if args.num_shards > 1:
partitions = form_partitions(data, args.num_shards)
data = partitions[args.local_rank]
output_file = f'{args.output_file}_pp.shard_{args.local_rank}'
else:
output_file = args.output_file + "_pp"

# count points already present in the output file so the run can resume
if os.path.exists(output_file):
    with open(output_file, "r") as f:
        num_output_points = len([json.loads(x) for x in f.read().strip().split("\n")])
else:
    num_output_points = 0
print(f"Skipping {num_output_points} points")
time1 = time.time()
tokenizer = T5Tokenizer.from_pretrained('google/t5-v1_1-xxl')
model = T5ForConditionalGeneration.from_pretrained(args.model)
print("Model loaded in ", time.time() - time1)
model.cuda()
model.eval()
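
# Each completion is paraphrased at four (lexical, order) diversity settings;
# DIPPER's control codes are on a 0-100 scale, where higher values request a
# more aggressive paraphrase.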
# iterate over the data and paraphrase each instance
for idx, dd in tqdm.tqdm(enumerate(data), total=len(data)):
    # skip points already written to the output file (resume support)
    if idx < num_output_points:
        continue
    if args.refresh or "paraphrase_outputs" not in dd:
        paraphrase_outputs = {}
        for lex, order in [(40, 40), (40, 100), (60, 100), (80, 100)]:
            # gen_completion can be a raw string or a list of generations
            if isinstance(dd['gen_completion'], str):
                input_gen = dd['gen_completion'].strip()
            else:
                input_gen = dd['gen_completion'][0].strip()
            # remove spurious newlines
            input_gen = " ".join(input_gen.split())
            sentences = sent_tokenize(input_gen)
            prefix = " ".join(dd['prefix'].replace("\n", " ").split())
            output_text = ""
            final_input_text = ""
            # paraphrase sent_interval sentences at a time, appending each
            # paraphrased window to the prefix so later windows see it as context
            for sent_idx in range(0, len(sentences), args.sent_interval):
                curr_sent_window = " ".join(sentences[sent_idx:sent_idx + args.sent_interval])
                if "no-context" in args.model:
                    final_input_text = f"lexical = {lex}, order = {order} {curr_sent_window}"
                else:
                    final_input_text = f"lexical = {lex}, order = {order} {prefix} <sent> {curr_sent_window} </sent>"
                # print one sample input for sanity checking
                if idx == 0 and lex == 40 and order == 40:
                    print(final_input_text)
                final_input = tokenizer([final_input_text], return_tensors="pt")
                final_input = {k: v.cuda() for k, v in final_input.items()}
                with torch.inference_mode():
                    outputs = model.generate(**final_input, do_sample=True, top_p=0.75, top_k=None, max_length=512)
                outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
                prefix += " " + outputs[0]
                output_text += " " + outputs[0]
paraphrase_outputs[f"lex_{lex}_order_{order}"] = {
"final_input": final_input_text,
"output": [output_text.strip()],
"lex": lex,
"order": order
}
dd["paraphrase_outputs"] = paraphrase_outputs
with open(output_file, "a") as f:
f.write(json.dumps(dd) + "\n")
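
# Example invocations (illustrative; the data path is just the script default):
#   python paraphrase.py --output_file lfqa-data/gpt2_xl_strength_2.0_frac_0.5_300_len_top_p_0.9.jsonl
# Sharded across two workers, one process per shard:
#   python paraphrase.py --num_shards 2 --local_rank 0
#   python paraphrase.py --num_shards 2 --local_rank 1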