-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbot.py
More file actions
107 lines (85 loc) · 3.45 KB
/
bot.py
File metadata and controls
107 lines (85 loc) · 3.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from os import getenv
from tempfile import NamedTemporaryFile
import soundfile
import torch
from librosa import load
from telebot import TeleBot
from telebot.types import Message
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from sys import path
path.append("vits")
import commons
import utils
from models import SynthesizerTrn
# Pinned Hugging Face checkpoint of an MMS-300M Wav2Vec2 model fine-tuned for
# Sakha (Yakut) speech recognition. The commit hash pins the exact weights so
# behaviour cannot drift if the upstream repo is updated.
repo_name = "volodya-leveryev/mms-300m-sah"
revision = "9b2efc815c0f3085e2902d8647c4c35d122e8b82"
model = Wav2Vec2ForCTC.from_pretrained(repo_name, revision=revision)
processor = Wav2Vec2Processor.from_pretrained(repo_name, revision=revision)
# --- VITS text-to-speech configuration ---
# vocab.txt holds one synthesizer symbol per line; the line index is the
# symbol id the VITS model expects (configuration in ./sah/config.json).
vocab_file = "./sah/vocab.txt"
config_file = "./sah/config.json"
hps = utils.get_hparams_from_file(config_file)

# Map each vocabulary symbol to its line index. Iterating the file directly
# avoids readlines() plus a throwaway generator; rstrip("\n") is equivalent
# to the old replace("\n", "") for single lines.
with open(vocab_file, encoding="utf-8") as f:
    symbols_to_id = {line.rstrip("\n"): i for i, line in enumerate(f)}

# Run inference on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build the VITS synthesizer with the dimensions derived from the hparams
# file and load the pinned generator checkpoint.
net_g = SynthesizerTrn(
    len(symbols_to_id),
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    **hps.model,
)
net_g.to(device)
# Inference only: disable dropout / training-mode behaviour.
net_g.eval()
# Plain string — the original used an f-string with no placeholders.
g_pth = "./sah/G_100000.pth"
utils.load_checkpoint(g_pth, net_g, None)
# Telegram bot token is read from the environment.
# NOTE(review): with the empty-string default, TeleBot is still constructed
# when TG_TOKEN is unset — presumably it fails on the first API call rather
# than at import time; confirm against pyTelegramBotAPI behaviour.
token = getenv("TG_TOKEN", "")
bot = TeleBot(token)
def filter_oov(text):
    """Drop characters the VITS synthesizer has no symbol for.

    Keeps only the characters present in the module-level ``symbols_to_id``
    vocabulary, preserving their original order.
    """
    # str.join accepts any iterable — no need to materialize a list, and a
    # generator expression is clearer than filter() with a lambda.
    return "".join(ch for ch in text if ch in symbols_to_id)
def text_to_sequence(text):
    """Convert a text string into the list of VITS symbol ids."""
    # Surrounding whitespace is not part of the vocabulary.
    stripped = text.strip()
    # Look each character up in the module-level symbol table. Callers are
    # expected to have removed out-of-vocabulary characters beforehand
    # (see filter_oov); an unknown character raises KeyError.
    return [symbols_to_id[ch] for ch in stripped]
@bot.message_handler(commands=["start", "help"])
def send_welcome(msg: Message):
    """Reply to /start and /help with a short bot description (in Sakha)."""
    greeting = (
        "👋🏻 Бу — киһи сахалыы саҥатын тиэкискэ кубулутар тэрил (bot, робот). "
        "Микрофоҥҥа саҥарыаххын эбэтэр аудио ыытыаххын сөп."
    )
    bot.reply_to(msg, greeting)
@bot.message_handler(content_types=["audio", "voice", "text"])
def speech_to_text(msg: Message):
    """Dispatch a Telegram message to ASR (audio/voice) or TTS (text).

    Audio/voice messages are transcribed with the Wav2Vec2 model and the text
    is sent back as a reply. Text messages are synthesized with the VITS model
    and the resulting WAV is sent back as an audio reply.
    """
    if msg.content_type in ("audio", "voice"):
        # Download the recording from Telegram.
        record = msg.audio or msg.voice
        file = bot.get_file(record.file_id)
        content = bot.download_file(file.file_path)
        with NamedTemporaryFile() as tmp_file:
            tmp_file.write(content)
            # BUG FIX: flush so librosa (which reopens the file by name) sees
            # the full payload instead of whatever left the write buffer.
            tmp_file.flush()
            # Wav2Vec2 expects 16 kHz input; librosa resamples on load.
            audio, _ = load(tmp_file.name, sr=16_000)
        # Speech recognition. no_grad() skips building an autograd graph for
        # pure inference — the logits are identical.
        input_dict = processor(audio, return_tensors="pt", padding=True)
        with torch.no_grad():
            output = model(input_dict.input_values).logits
        predictions = torch.argmax(output, dim=-1)[0]
        # Reply with the decoded transcription.
        bot.reply_to(msg, processor.decode(predictions))
    elif msg.content_type in ("text",):
        # Text-to-speech: keep only characters the synthesizer knows.
        text = filter_oov(msg.text.lower())
        text_norm = text_to_sequence(text)
        if hps.data.add_blank:
            # Models trained with add_blank expect the blank id (0)
            # interspersed between every symbol.
            text_norm = commons.intersperse(text_norm, 0)
        text_norm = torch.LongTensor(text_norm)
        with torch.no_grad():
            x_tst = text_norm.unsqueeze(0).to(device)
            x_tst_lengths = torch.LongTensor([text_norm.size(0)]).to(device)
            hyp = net_g.infer(
                x_tst, x_tst_lengths, noise_scale=.667,
                noise_scale_w=0.8, length_scale=1.0,
            )[0][0, 0].cpu().float().numpy()
        with NamedTemporaryFile() as tmp_file:
            # soundfile writes via the file *name*; the still-open handle at
            # position 0 is then handed to Telegram for upload.
            soundfile.write(tmp_file.name, hyp, 16_000, format="WAV")
            bot.send_audio(msg.chat.id, tmp_file, reply_to_message_id=msg.id)
# Block forever handling updates; infinity_polling keeps polling through
# transient network errors rather than exiting.
bot.infinity_polling()