Skip to content

Commit cec0f34

Browse files
committed
Remove espeak
1 parent 95b0b78 commit cec0f34

File tree

7 files changed

+322
-31
lines changed

7 files changed

+322
-31
lines changed

demo.ipynb

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,11 @@
2323
"metadata": {},
2424
"outputs": [],
2525
"source": [
26-
"# Instantiate the TTS engine\n",
27-
"glados_tts = tts.TTSEngine()"
26+
"glados_tts = tts.Synthesizer(\n",
27+
" model_path=str(\"models/glados.onnx\"),\n",
28+
" use_cuda=True,\n",
29+
" speaker_id=0,\n",
30+
" )"
2831
]
2932
},
3033
{
@@ -34,11 +37,10 @@
3437
"outputs": [],
3538
"source": [
3639
"# Generate the audio.\n",
37-
"# Glados is spelt incorrectly on purpose to make the pronunciation more accurate.\n",
38-
"audio = glados_tts.generate_speech_audio(\"Hello, my name is Gladohs. I am an AI created by Aperture Science.\")\n",
40+
"audio = glados_tts.generate_speech_audio(\"Hello, my name is Glados. I am an AI created by Aperture Science.\")\n",
3941
"\n",
4042
"# Play the audio\n",
41-
"sd.play(audio, tts.RATE)"
43+
"sd.play(audio, glados_tts.rate)"
4244
]
4345
},
4446
{
@@ -88,7 +90,7 @@
8890
],
8991
"metadata": {
9092
"kernelspec": {
91-
"display_name": "GLaDOS",
93+
"display_name": "glados",
9294
"language": "python",
9395
"name": "python3"
9496
},
@@ -102,7 +104,7 @@
102104
"name": "python",
103105
"nbconvert_exporter": "python",
104106
"pygments_lexer": "ipython3",
105-
"version": "3.11.6"
107+
"version": "3.10.15"
106108
}
107109
},
108110
"nbformat": 4,

glados/phonemizer.py

Lines changed: 295 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,295 @@
1+
from pickle import load
2+
from typing import List, Iterable, Dict, Union
3+
from functools import cache
4+
from pathlib import Path
5+
import re
6+
from itertools import zip_longest
7+
8+
import numpy as np
9+
import onnxruntime
10+
from enum import Enum
11+
12+
class ModelConfig:
    """Static configuration for the phonemizer model and its on-disk assets."""

    # ONNX grapheme-to-phoneme model. NOTE(review): "phomenizer" appears to be
    # the shipped model's actual file name — do not "fix" the spelling without
    # renaming the file.
    MODEL_NAME: str = 'models/phomenizer_en.onnx'
    # Pickled phoneme dictionary, keyed by language (see _get_dict_entry).
    PHONEME_DICT_PATH: Path = Path('./models/lang_phoneme_dict.pkl')
    # Pickled vocabularies mapping characters -> model indices and back.
    TOKEN_TO_IDX_PATH: Path = Path('./models/token_to_idx.pkl')
    IDX_TO_TOKEN_PATH: Path = Path('./models/idx_to_token.pkl')
    # Each input character is repeated this many times during encoding
    # (presumably matching how the model was trained — TODO confirm).
    CHAR_REPEATS: int = 3
    # Fixed model input width; encoded words are padded/truncated to this.
    MODEL_INPUT_LENGTH: int = 64
    # Insert hyphens into interior capitals so acronyms are spelt out.
    EXPAND_ACRONYMS: bool = True
    # Prefer the CUDA execution provider, with CPU fallback.
    USE_CUDA: bool = True
21+
22+
class SpecialTokens(Enum):
    """Special vocabulary tokens used by the phonemizer model."""

    PAD = '_'          # padding token (assumed to map to index 0 — TODO confirm)
    START = '<start>'  # start-of-sequence marker prepended in encode()
    END = '<end>'      # end-of-sequence marker appended in encode()
    EN_US = '<en_us>'  # language tag for US English
27+
28+
class Punctuation(Enum):
    """Punctuation characters the phonemizer passes through verbatim."""

    PUNCTUATION = '().,:?!/–'
    HYPHEN = '-'
    SPACE = ' '

    @classmethod
    @cache
    def get_punc_set(cls) -> set[str]:
        """Return the set of every pass-through character (cached)."""
        all_chars = cls.PUNCTUATION.value + cls.HYPHEN.value + cls.SPACE.value
        return {char for char in all_chars}

    @classmethod
    @cache
    def get_punc_pattern(cls) -> re.Pattern:
        """Return a compiled pattern that splits on punctuation or space
        while keeping the separators (cached). The hyphen is not part of
        this character class, so hyphenated words survive the split."""
        char_class = '[' + cls.PUNCTUATION.value + cls.SPACE.value + ']'
        return re.compile('(' + char_class + ')')
42+
43+
class Phonemizer:
    """Grapheme-to-phoneme converter.

    Known words are resolved through a pickled phoneme dictionary; words that
    are missing (and the subwords of hyphenated words/acronyms) are predicted
    with an ONNX sequence model and decoded back into phoneme strings.
    """

    def __init__(self) -> None:
        self.phoneme_dict = self._load_pickle(ModelConfig.PHONEME_DICT_PATH)
        self.token_to_idx = self._load_pickle(ModelConfig.TOKEN_TO_IDX_PATH)
        self.idx_to_token = self._load_pickle(ModelConfig.IDX_TO_TOKEN_PATH)

        providers = ["CPUExecutionProvider"]
        if ModelConfig.USE_CUDA:
            # CUDA first with CPU fallback; onnxruntime skips providers that
            # are unavailable on this machine.
            providers = [
                ("CUDAExecutionProvider", {"cudnn_conv_algo_search": "HEURISTIC"}),
                "CPUExecutionProvider",
            ]

        self.ort_session = onnxruntime.InferenceSession(
            ModelConfig.MODEL_NAME,
            sess_options=onnxruntime.SessionOptions(),
            providers=providers,
        )

        # Tokens stripped from decoded model output.
        self.special_tokens: set[str] = {
            SpecialTokens.PAD.value,
            SpecialTokens.END.value,
            SpecialTokens.EN_US.value,
        }

    @staticmethod
    def _load_pickle(path: Path) -> dict:
        """Load a pickled dictionary from path."""
        with path.open('rb') as f:
            return load(f)

    @staticmethod
    def unique_consecutive(arr: np.ndarray) -> List[np.ndarray]:
        """
        Remove consecutive duplicate values from each row.

        Equivalent to torch.unique_consecutive applied row-wise to a numpy
        array.

        :param arr: 2D array (or sequence of 1D arrays) to process.
        :return: List of rows with consecutive duplicates removed.
        """
        result = []
        for row in arr:
            if len(row) == 0:
                result.append(row)
            else:
                # Keep the first element plus every element differing from
                # its predecessor.
                mask = np.concatenate(([True], row[1:] != row[:-1]))
                result.append(row[mask])
        return result

    @staticmethod
    def remove_padding(arr: np.ndarray, padding_value: int = 0) -> List[np.ndarray]:
        """Drop every padding token from each row."""
        return [row[row != padding_value] for row in arr]

    @staticmethod
    def trim_to_stop(arr: np.ndarray, end_index: int = 2) -> List[np.ndarray]:
        """
        Trim each row at the first stop token (inclusive).

        :param arr: Rows of token indices to trim.
        :param end_index: Index of the stop token (default 2 — assumed to be
            the <end> token's vocabulary index; confirm against token_to_idx).
        :return: List of trimmed rows.
        """
        result = []
        for row in arr:
            stop_index = np.where(row == end_index)[0]
            if len(stop_index) > 0:
                result.append(row[:stop_index[0] + 1])
            else:
                result.append(row)
        return result

    def process_model_output(self, arr: np.ndarray) -> List[np.ndarray]:
        """
        Convert raw model output into per-word phoneme index sequences.

        :param arr: Model output; arr[0] is indexed as (batch, time, vocab).
        :return: List of phoneme-index arrays, one per input word.
        """
        # Rebind under a fresh name: after argmax this is a 2D index array,
        # then a list of 1D arrays — not the original model output.
        indices = np.argmax(arr[0], axis=2)
        rows = self.unique_consecutive(indices)
        rows = self.remove_padding(rows)
        return self.trim_to_stop(rows)

    @staticmethod
    def expand_acronym(word: str) -> str:
        """
        Expand an acronym into hyphen-separated subwords.

        A hyphen is inserted before every interior uppercase letter so the
        model spells the acronym out letter by letter.

        :param word: Acronym to expand.
        :return: Expanded acronym.
        """
        subwords = []
        for subword in word.split(Punctuation.HYPHEN.value):
            expanded = []
            for current, following in zip_longest(subword, subword[1:]):
                expanded.append(current)
                if following is not None and following.isupper():
                    expanded.append(Punctuation.HYPHEN.value)
            subwords.append(''.join(expanded))
        return Punctuation.HYPHEN.value.join(subwords)

    def encode(self, sentence: Iterable[str]) -> List[int]:
        """
        Map a sequence of symbols to a sequence of model token indices.

        Each character is repeated CHAR_REPEATS times, lowercased, and the
        result is wrapped in start/end tokens. Characters missing from the
        vocabulary are silently dropped.

        :param sentence: Sentence (or word) as a sequence of symbols.
        :return: Sequence of token indices.
        """
        repeated = [item for item in sentence for _ in range(ModelConfig.CHAR_REPEATS)]
        lowered = [s.lower() for s in repeated]
        sequence = [self.token_to_idx[c] for c in lowered if c in self.token_to_idx]
        return ([self.token_to_idx[SpecialTokens.START.value]]
                + sequence
                + [self.token_to_idx[SpecialTokens.END.value]])

    def decode(self, sequence: np.ndarray) -> str:
        """
        Map a sequence of token indices back to a phoneme string.

        :param sequence: Encoded sequence to be decoded.
        :return: Decoded symbols with special tokens removed.
        """
        decoded = [self.idx_to_token[int(t)] for t in sequence if int(t) in self.idx_to_token]
        return ''.join(d for d in decoded if d not in self.special_tokens)

    @staticmethod
    def pad_sequence_fixed(v: List[np.ndarray],
                           target_length: int = ModelConfig.MODEL_INPUT_LENGTH) -> np.ndarray:
        """
        Pad or truncate a list of sequences to a fixed length.

        :param v: List of index sequences.
        :param target_length: Target length to pad or truncate to.
        :return: (len(v), target_length) int64 array, zero-padded.
        """
        result = np.zeros((len(v), target_length), dtype=np.int64)
        for i, seq in enumerate(v):
            # Handles both shorter and longer sequences.
            length = min(len(seq), target_length)
            result[i, :length] = seq[:length]
        return result

    def _get_dict_entry(self, word: str, lang: str, punc_set: set[str]) -> str | None:
        """
        Look up the phoneme entry for a word in the dictionary.

        Punctuation and empty strings map to themselves. Otherwise the word
        is tried verbatim, lowercased, and title-cased, in that order.

        :param word: Word to get the phoneme entry for.
        :param lang: Language key into the phoneme dictionary.
        :param punc_set: Set of punctuation characters.
        :return: Phoneme entry, or None when the word is unknown.
        """
        if word in punc_set or len(word) == 0:
            return word
        if not self.phoneme_dict or lang not in self.phoneme_dict:
            return None
        phoneme_dict = self.phoneme_dict[lang]
        for candidate in (word, word.lower(), word.title()):
            if candidate in phoneme_dict:
                return phoneme_dict[candidate]
        return None

    @staticmethod
    def _get_phonemes(word: str,
                      word_phonemes: Dict[str, Union[str, None]],
                      word_splits: Dict[str, List[str]]) -> str:
        """
        Get the phonemes for a word, falling back to its subword pieces.

        :param word: Word to get phonemes for.
        :param word_phonemes: Mapping word -> phonemes (None when unknown).
        :param word_splits: Mapping word -> its subword pieces.
        :return: Concatenated phoneme string for the word.
        """
        phons = word_phonemes[word]
        if phons is None:
            subwords = word_splits[word]
            phons = ''.join(word_phonemes[w] for w in subwords)
        return phons

    def _clean_and_split_texts(self, texts: List[str], punc_set: set[str],
                               punc_pattern: re.Pattern) -> tuple[List[List[str]], set[str]]:
        """
        Strip unsupported characters and split each text on punctuation.

        :param texts: Texts to clean and split.
        :param punc_set: Characters kept in addition to alphanumerics.
        :param punc_pattern: Compiled pattern splitting on punctuation while
            keeping the separators.
        :return: Per-text token lists, and the set of all unique tokens.
        """
        split_text: List[List[str]] = []
        cleaned_words: set[str] = set()
        for text in texts:
            cleaned_text = ''.join(t for t in text if t.isalnum() or t in punc_set)
            split = [s for s in re.split(punc_pattern, cleaned_text) if len(s) > 0]
            split_text.append(split)
            cleaned_words.update(split)
        return split_text, cleaned_words

    def convert_to_phonemes(self, texts: List[str], lang: str) -> List[str]:
        """
        Convert a list of texts to phoneme strings.

        :param texts: List of texts to convert.
        :param lang: Language of the texts.
        :return: One phoneme string per input text.
        """
        punc_set = Punctuation.get_punc_set()
        punc_pattern = Punctuation.get_punc_pattern()

        # Step 1: Preprocess texts.
        split_text, cleaned_words = self._clean_and_split_texts(texts, punc_set, punc_pattern)

        # Step 2: Collect dictionary phonemes; punctuation maps to itself.
        # NOTE(review): this writes into the shared phoneme dict on every
        # call — the writes are idempotent, but it is instance-level state.
        for punct in punc_set:
            self.phoneme_dict[punct] = punct
        word_phonemes = {word: self.phoneme_dict.get(word) for word in cleaned_words}

        # Step 3: Split unknown words into hyphen-separated subwords
        # (optionally expanding acronyms first) and look those up.
        words_to_split = [w for w in cleaned_words if word_phonemes[w] is None]
        word_splits = {
            word: re.split(r'([-])',
                           self.expand_acronym(word) if ModelConfig.EXPAND_ACRONYMS else word)
            for word in words_to_split
        }
        subwords = {w for values in word_splits.values() for w in values if w not in word_phonemes}
        for subword in subwords:
            word_phonemes[subword] = self._get_dict_entry(word=subword, lang=lang, punc_set=punc_set)

        # Step 4: Predict every word still missing from the phoneme dict that
        # has no multi-piece split to fall back on.
        words_to_predict = [word for word, phons in word_phonemes.items()
                            if phons is None and len(word_splits.get(word, [])) <= 1]
        input_batch = [self.encode(word) for word in words_to_predict]
        input_batch = self.pad_sequence_fixed(input_batch)

        ort_inputs = {self.ort_session.get_inputs()[0].name: input_batch}
        ort_outs = self.ort_session.run(None, ort_inputs)
        predicted_ids = self.process_model_output(ort_outs)

        # Step 5: Add the predictions to the lookup table.
        for prediction, word in zip(predicted_ids, words_to_predict):
            word_phonemes[word] = self.decode(prediction)

        # Step 6: Assemble the phoneme string for each input text.
        phoneme_lists = []
        for text in split_text:
            text_phons = [
                self._get_phonemes(word=word, word_phonemes=word_phonemes,
                                   word_splits=word_splits)
                for word in text
            ]
            phoneme_lists.append(text_phons)

        return [''.join(phoneme_list) for phoneme_list in phoneme_lists]
295+

0 commit comments

Comments
 (0)