-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathtranslate.py
More file actions
124 lines (102 loc) · 3.84 KB
/
Copy pathtranslate.py
File metadata and controls
124 lines (102 loc) · 3.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import argparse
import json
import torch
from IndicTransToolkit import IndicProcessor
from tqdm import tqdm
from transformers import (
AutoModelForSeq2SeqLM,
AutoTokenizer,
M2M100ForConditionalGeneration,
M2M100Tokenizer,
)
# Command-line interface for the translation script.
#   --eng_to_hin : boolean switch selecting the translation direction.
#   --m2m        : string option; only the exact value "true" selects the
#                  M2M100 backend further down in this script.
parser = argparse.ArgumentParser(
    description="Translate the text entries of tmp/para_info.json between English and Hindi.",
)
parser.add_argument(
    "--eng_to_hin",
    action="store_true",
    help="Translate English to Hindi (default: Hindi to English).",
)
parser.add_argument(
    "--m2m",
    help='Pass the literal value "true" to use the M2M100 model; '
         "any other value, or omitting the option, uses IndicTrans2.",
)
args = parser.parse_args()

# True -> English to Hindi, False -> Hindi to English.
mode = args.eng_to_hin
# Raw option string (None when the option is absent); compared with "true" later.
trans_model = args.m2m
# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Translate the 'txt' field of every entry in tmp/para_info.json and store the
# result in the same entry under 'trans_txt', then write the file back.
# Uses `mode` (direction), `trans_model` (backend switch) and `DEVICE`, which
# are set earlier in this script.
PARA_INFO_PATH = "tmp/para_info.json"

# Read the input before loading any model so that a missing or malformed file
# fails immediately; the context manager guarantees the handle is closed.
with open(PARA_INFO_PATH, encoding="utf-8") as para_file:
    img2info = json.load(para_file)

if trans_model != "true":
    # Default backend: IndicTrans2. Checkpoint and language tags both depend
    # on the translation direction.
    if mode:
        model_name = "ai4bharat/indictrans2-en-indic-1B"
        src_lang, tgt_lang = "eng_Latn", "hin_Deva"
    else:
        model_name = "ai4bharat/indictrans2-indic-en-1B"
        src_lang, tgt_lang = "hin_Deva", "eng_Latn"

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
    ip = IndicProcessor(inference=True)
    model = model.to(DEVICE)
    model.eval()

    # The source language is the same for every entry, so set it once.
    tokenizer.src_lang = src_lang

    for img_info in tqdm(img2info.values()):
        # Each entry is translated on its own as a batch of size one.
        batch = ip.preprocess_batch(
            [img_info['txt']],
            src_lang=src_lang,
            tgt_lang=tgt_lang,
        )

        # Tokenize and move the encoded input to the model's device.
        inputs = tokenizer(
            batch,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        # Beam-search generation; gradients are not needed for inference.
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        # Decode the generated ids with the tokenizer in target mode.
        with tokenizer.as_target_tokenizer():
            decoded = tokenizer.batch_decode(
                generated_tokens.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

        # Post-process the decoded text and keep the single result.
        img_info['trans_txt'] = ip.postprocess_batch(decoded, lang=tgt_lang)[0]
else:
    # Alternative backend: M2M100, selected with `--m2m true`.
    model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M").to(DEVICE)
    tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")

    if mode:
        src_lang, tgt_lang = "en", "hi"
    else:
        src_lang, tgt_lang = "hi", "en"

    # Both values are the same for every entry, so compute them once.
    tokenizer.src_lang = src_lang
    target_lang_id = tokenizer.get_lang_id(tgt_lang)

    for img_info in tqdm(img2info.values()):
        # The text is lower-cased and stripped before encoding.
        encoded_src = tokenizer(img_info['txt'].lower().strip(), return_tensors="pt").to(DEVICE)
        # Force the first generated token to be the target-language id.
        generated_tokens = model.generate(**encoded_src, forced_bos_token_id=target_lang_id).to("cpu")
        img_info['trans_txt'] = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

# Write the augmented entries back to the same file. Closing the handle before
# reporting success ensures the data has actually been flushed to disk.
with open(PARA_INFO_PATH, "w", encoding="utf-8") as para_file:
    json.dump(img2info, para_file, indent=4)
print("Translation completed.")