|
| 1 | +from pickle import load |
| 2 | +from typing import List, Iterable, Dict, Union |
| 3 | +from functools import cache |
| 4 | +from pathlib import Path |
| 5 | +import re |
| 6 | +from itertools import zip_longest |
| 7 | + |
| 8 | +import numpy as np |
| 9 | +import onnxruntime |
| 10 | +from enum import Enum |
| 11 | + |
| 12 | +class ModelConfig: |
| 13 | + MODEL_NAME: str = 'models/phomenizer_en.onnx' |
| 14 | + PHONEME_DICT_PATH: Path = Path('./models/lang_phoneme_dict.pkl') |
| 15 | + TOKEN_TO_IDX_PATH: Path = Path('./models/token_to_idx.pkl') |
| 16 | + IDX_TO_TOKEN_PATH: Path = Path('./models/idx_to_token.pkl') |
| 17 | + CHAR_REPEATS: int = 3 |
| 18 | + MODEL_INPUT_LENGTH: int = 64 |
| 19 | + EXPAND_ACRONYMS: bool = True |
| 20 | + USE_CUDA: bool = True |
| 21 | + |
class SpecialTokens(Enum):
    """Special vocabulary tokens used by the sequence model."""
    PAD = '_'            # padding token (index 0 in the model vocabulary)
    START = '<start>'    # sequence-start marker prepended during encoding
    END = '<end>'        # sequence-end marker appended during encoding
    EN_US = '<en_us>'    # language tag token, stripped from decoded output
| 27 | + |
class Punctuation(Enum):
    """Punctuation characters the phonemizer preserves verbatim."""
    PUNCTUATION = '().,:?!/–'
    HYPHEN = '-'
    SPACE = ' '

    @classmethod
    @cache
    def get_punc_set(cls) -> set[str]:
        """Return every punctuation, hyphen, and space character as a set."""
        chars = cls.PUNCTUATION.value + cls.HYPHEN.value + cls.SPACE.value
        return {c for c in chars}

    @classmethod
    @cache
    def get_punc_pattern(cls) -> re.Pattern:
        """Return a compiled pattern whose capture group keeps the
        punctuation/space delimiters when used with re.split."""
        splitters = cls.PUNCTUATION.value + cls.SPACE.value
        return re.compile('([' + splitters + '])')
| 42 | + |
class Phonemizer:
    """Grapheme-to-phoneme converter.

    Words are first resolved against a pickled pronunciation dictionary;
    anything still missing is split on hyphens (optionally expanding
    acronyms) and the remaining pieces are predicted in a single batch by
    an ONNX sequence model.
    """

    def __init__(self) -> None:
        # Pickled lookup tables produced alongside the model.
        # NOTE(review): pickle.load is unsafe on untrusted files; these are
        # local model assets, so this is acceptable here.
        self.phoneme_dict = self._load_pickle(ModelConfig.PHONEME_DICT_PATH)
        self.token_to_idx = self._load_pickle(ModelConfig.TOKEN_TO_IDX_PATH)
        self.idx_to_token = self._load_pickle(ModelConfig.IDX_TO_TOKEN_PATH)

        # Prefer CUDA when configured; onnxruntime falls back to the CPU
        # provider when the CUDA provider cannot be loaded.
        providers = ["CPUExecutionProvider"]
        if ModelConfig.USE_CUDA:
            providers = [
                ("CUDAExecutionProvider", {"cudnn_conv_algo_search": "HEURISTIC"}),
                "CPUExecutionProvider",
            ]

        self.ort_session = onnxruntime.InferenceSession(
            ModelConfig.MODEL_NAME,
            sess_options=onnxruntime.SessionOptions(),
            providers=providers,
        )

        # Tokens that must never appear in decoded phoneme strings.
        self.special_tokens: set[str] = {
            SpecialTokens.PAD.value,
            SpecialTokens.END.value,
            SpecialTokens.EN_US.value,
        }

    @staticmethod
    def _load_pickle(path: Path) -> dict:
        """Load a pickled dictionary from path."""
        with path.open('rb') as f:
            return load(f)

    @staticmethod
    def unique_consecutive(arr: np.ndarray) -> List[np.ndarray]:
        """
        Equivalent to torch.unique_consecutive for numpy arrays.

        :param arr: 2D array processed row by row.
        :return: List of rows with consecutive duplicates removed.
        """
        result = []
        for row in arr:
            if len(row) == 0:
                result.append(row)
            else:
                # Keep the first element plus every element that differs
                # from its immediate predecessor.
                mask = np.concatenate(([True], row[1:] != row[:-1]))
                result.append(row[mask])
        return result

    @staticmethod
    def remove_padding(arr: np.ndarray, padding_value: int = 0) -> List[np.ndarray]:
        """Drop every occurrence of ``padding_value`` from each row."""
        return [row[row != padding_value] for row in arr]

    @staticmethod
    def trim_to_stop(arr: np.ndarray, end_index: int = 2) -> List[np.ndarray]:
        """
        Trim each row at the first end token (inclusive).

        :param arr: Rows of token indices.
        :param end_index: Index of the end token.
        """
        result = []
        for row in arr:
            stop_index = np.where(row == end_index)[0]
            if len(stop_index) > 0:
                result.append(row[:stop_index[0] + 1])
            else:
                # No end token predicted (e.g. the input was truncated);
                # keep the whole row.
                result.append(row)
        return result

    def process_model_output(self, arr: List[np.ndarray]) -> List[np.ndarray]:
        """
        Convert raw model output into per-word phoneme index sequences.

        :param arr: Raw onnxruntime output; ``arr[0]`` holds the logits,
            argmaxed over the last axis (batch, time, vocab).
        :return: List of phoneme index arrays, one per input word.
        """
        indices = np.argmax(arr[0], axis=2)
        indices = self.unique_consecutive(indices)
        indices = self.remove_padding(indices)
        return self.trim_to_stop(indices)

    @staticmethod
    def expand_acronym(word: str) -> str:
        """
        Expand an acronym into hyphen-separated subwords: an uppercase
        letter starts a new subword, e.g. ``ABCd`` -> ``A-B-Cd``.
        Existing hyphens are preserved.

        :param word: Acronym to expand.
        :return: Expanded acronym.
        """
        subwords = []
        for subword in word.split(Punctuation.HYPHEN.value):
            expanded = []
            for a, b in zip_longest(subword, subword[1:]):
                expanded.append(a)
                if b is not None and b.isupper():
                    expanded.append(Punctuation.HYPHEN.value)
            subwords.append(''.join(expanded))
        return Punctuation.HYPHEN.value.join(subwords)

    def encode(self, sentence: Iterable[str]) -> List[int]:
        """
        Map a sequence of symbols to token indices, framed by the start
        and end tokens. Each symbol is lower-cased and repeated
        CHAR_REPEATS times; symbols missing from the vocabulary are
        silently dropped.

        :param sentence: Sentence (or word) as a sequence of symbols.
        :return: Sequence of token indices.
        """
        # Lower-case first, then repeat — same result as the reverse
        # order, but lower() runs once per symbol instead of three times.
        chars = [c.lower() for c in sentence]
        repeated = [c for c in chars for _ in range(ModelConfig.CHAR_REPEATS)]
        sequence = [self.token_to_idx[c] for c in repeated if c in self.token_to_idx]
        return ([self.token_to_idx[SpecialTokens.START.value]]
                + sequence
                + [self.token_to_idx[SpecialTokens.END.value]])

    def decode(self, sequence: np.ndarray) -> str:
        """
        Map a sequence of indices back to a phoneme string.

        :param sequence: Encoded sequence to be decoded.
        :return: Decoded string with special tokens removed.
        """
        decoded = [self.idx_to_token[int(t)] for t in sequence if int(t) in self.idx_to_token]
        return ''.join(d for d in decoded if d not in self.special_tokens)

    @staticmethod
    def pad_sequence_fixed(v: List[np.ndarray], target_length: int = ModelConfig.MODEL_INPUT_LENGTH) -> np.ndarray:
        """
        Pad or truncate a list of sequences to a fixed length.

        NOTE(review): truncation also drops the end token of over-long
        words; those rows rely on trim_to_stop's no-end-token fallback.

        :param v: List of index sequences.
        :param target_length: Target length to pad or truncate to.
        :return: (len(v), target_length) int64 array, zero-padded.
        """
        result = np.zeros((len(v), target_length), dtype=np.int64)
        for i, seq in enumerate(v):
            length = min(len(seq), target_length)  # handle short and long rows
            result[i, :length] = seq[:length]
        return result

    def _get_dict_entry(self, word: str, lang: str, punc_set: set[str]) -> str | None:
        """
        Look up the phoneme entry for a word in the dictionary.

        Punctuation and empty strings map to themselves. The exact,
        lower-case, and title-case spellings are tried in that order.

        :param word: Word to look up.
        :param lang: Language key into the phoneme dictionary.
        :param punc_set: Set of punctuation characters.
        :return: Phoneme entry, or None if the word is unknown.
        """
        if word in punc_set or len(word) == 0:
            return word
        if not self.phoneme_dict or lang not in self.phoneme_dict:
            return None
        phoneme_dict = self.phoneme_dict[lang]
        for candidate in (word, word.lower(), word.title()):
            if candidate in phoneme_dict:
                return phoneme_dict[candidate]
        return None

    @staticmethod
    def _get_phonemes(word: str,
                      word_phonemes: Dict[str, Union[str, None]],
                      word_splits: Dict[str, List[str]]) -> str:
        """
        Get the phonemes for a word, assembling them from its subwords
        when the word itself has no direct entry.

        :param word: Word to get phonemes for.
        :param word_phonemes: Mapping of word -> phonemes (or None).
        :param word_splits: Mapping of word -> subwords.
        """
        phons = word_phonemes[word]
        if phons is None:
            subwords = word_splits[word]
            # A subword may itself be unresolved (None); treat it as empty
            # rather than letting ''.join raise a TypeError.
            phons = ''.join(word_phonemes[w] or '' for w in subwords)
        return phons

    def _clean_and_split_texts(self, texts: List[str], punc_set: set[str], punc_pattern: re.Pattern) -> tuple[List[List[str]], set[str]]:
        """
        Strip unsupported characters and split each text on punctuation.

        :param texts: Raw input texts.
        :param punc_set: Characters to keep besides alphanumerics.
        :param punc_pattern: Capturing split pattern (delimiters are kept).
        :return: (per-text token lists, set of all unique tokens).
        """
        split_text, cleaned_words = [], set()
        for text in texts:
            cleaned_text = ''.join(t for t in text if t.isalnum() or t in punc_set)
            split = [s for s in re.split(punc_pattern, cleaned_text) if len(s) > 0]
            split_text.append(split)
            cleaned_words.update(split)
        return split_text, cleaned_words

    def convert_to_phonemes(self, texts: List[str], lang: str) -> List[str]:
        """
        Convert a list of texts to phoneme strings.

        Dictionary lookups are used where possible; unknown words are
        split on hyphens (optionally expanding acronyms) and any still
        unresolved pieces are predicted by the ONNX model in one batch.

        :param texts: List of texts to convert.
        :param lang: Language of the texts.
        :return: List of phoneme strings, one per input text.
        """
        punc_set = Punctuation.get_punc_set()
        punc_pattern = Punctuation.get_punc_pattern()

        # Step 1: Preprocess texts.
        split_text, cleaned_words = self._clean_and_split_texts(texts, punc_set, punc_pattern)

        # Step 2: Collect dictionary phonemes; punctuation maps to itself.
        # NOTE(review): this mutates self.phoneme_dict on every call
        # (idempotent), and the .get(word) below reads the lang-keyed top
        # level of the dict — confirm the pickle's layout supports that.
        for punct in punc_set:
            self.phoneme_dict[punct] = punct
        word_phonemes = {word: self.phoneme_dict.get(word) for word in cleaned_words}

        # Step 3: Split unknown words into hyphen-separated subwords and
        # try the dictionary again for each new piece.
        words_to_split = [w for w in cleaned_words if word_phonemes[w] is None]
        word_splits = {
            word: re.split(r'([-])', self.expand_acronym(word) if ModelConfig.EXPAND_ACRONYMS else word)
            for word in words_to_split
        }
        subwords = {w for values in word_splits.values() for w in values if w not in word_phonemes}
        for subword in subwords:
            word_phonemes[subword] = self._get_dict_entry(word=subword, lang=lang, punc_set=punc_set)

        # Step 4: Predict everything still missing. A word that was split
        # into several pieces is assembled from its pieces, not predicted.
        words_to_predict = [word for word, phons in word_phonemes.items()
                            if phons is None and len(word_splits.get(word, [])) <= 1]

        if words_to_predict:  # avoid running the model on an empty batch
            input_batch = self.pad_sequence_fixed([self.encode(word) for word in words_to_predict])
            ort_inputs = {self.ort_session.get_inputs()[0].name: input_batch}
            ort_outs = self.ort_session.run(None, ort_inputs)

            # Step 5: Add predictions to the lookup table.
            for indices, word in zip(self.process_model_output(ort_outs), words_to_predict):
                word_phonemes[word] = self.decode(indices)

        # Step 6: Reassemble phonemes in the original word order.
        phoneme_lists = [
            [self._get_phonemes(word=word, word_phonemes=word_phonemes, word_splits=word_splits)
             for word in text]
            for text in split_text
        ]
        return [''.join(phoneme_list) for phoneme_list in phoneme_lists]
| 295 | + |
0 commit comments