 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import re
 from typing import Dict
 from typing import List
 
@@ -30,7 +29,6 @@ def __init__(self,
         self.zh_frontend = Frontend(
             phone_vocab_path=phone_vocab_path, tone_vocab_path=tone_vocab_path)
         self.en_frontend = English(phone_vocab_path=phone_vocab_path)
-        self.SENTENCE_SPLITOR = re.compile(r'([:、,;。?!,;?!][”’]?)')
         self.sp_id = self.zh_frontend.vocab_phones["sp"]
         self.sp_id_tensor = paddle.to_tensor([self.sp_id])
 
@@ -47,188 +45,56 @@ def is_alphabet(self, char):
         else:
             return False
 
-    def is_number(self, char):
-        if char >= '\u0030' and char <= '\u0039':
-            return True
-        else:
-            return False
-
     def is_other(self, char):
-        if not (self.is_chinese(char) or self.is_number(char) or
-                self.is_alphabet(char)):
+        if not (self.is_chinese(char) or self.is_alphabet(char)):
             return True
         else:
             return False
 
-    def is_end(self, before_char, after_char) -> bool:
-        flag = 0
-        for char in (before_char, after_char):
-            if self.is_alphabet(char) or char == " ":
-                flag += 1
-        if flag == 2:
-            return True
-        else:
-            return False
-
-    def _replace(self, text: str) -> str:
-        new_text = ""
-
-        # get "." indexs
-        point = "."
-        point_indexs = []
-        index = -1
-        for i in range(text.count(point)):
-            index = text.find(".", index + 1, len(text))
-            point_indexs.append(index)
-
-        # replace "." -> "。" when English sentence ending
-        if len(point_indexs) == 0:
-            new_text = text
-
-        elif len(point_indexs) == 1:
-            point_index = point_indexs[0]
-            if point_index == 0 or point_index == len(text) - 1:
-                new_text = text
-            else:
-                if not self.is_end(text[point_index - 1], text[point_index +
-                                                               1]):
-                    new_text = text
-                else:
-                    new_text = text[:point_index] + "。" + text[point_index + 1:]
-
-        elif len(point_indexs) == 2:
-            first_index = point_indexs[0]
-            end_index = point_indexs[1]
-
-            # first
-            if first_index != 0:
-                if not self.is_end(text[first_index - 1], text[first_index +
-                                                               1]):
-                    new_text += (text[:first_index] + ".")
-                else:
-                    new_text += (text[:first_index] + "。")
-            else:
-                new_text += "."
-            # last
-            if end_index != len(text) - 1:
-                if not self.is_end(text[end_index - 1], text[end_index + 1]):
-                    new_text += text[point_indexs[-2] + 1:]
-                else:
-                    new_text += (text[point_indexs[-2] + 1:end_index] + "。" +
-                                 text[end_index + 1:])
-            else:
-                new_text += "."
-
-        else:
-            first_index = point_indexs[0]
-            end_index = point_indexs[-1]
-            # first
-            if first_index != 0:
-                if not self.is_end(text[first_index - 1], text[first_index +
-                                                               1]):
-                    new_text += (text[:first_index] + ".")
-                else:
-                    new_text += (text[:first_index] + "。")
-            else:
-                new_text += "."
-            # middle
-            for j in range(1, len(point_indexs) - 1):
-                point_index = point_indexs[j]
-                if not self.is_end(text[point_index - 1], text[point_index +
-                                                               1]):
-                    new_text += (
-                        text[point_indexs[j - 1] + 1:point_index] + ".")
-                else:
-                    new_text += (
-                        text[point_indexs[j - 1] + 1:point_index] + "。")
-            # last
-            if end_index != len(text) - 1:
-                if not self.is_end(text[end_index - 1], text[end_index + 1]):
-                    new_text += text[point_indexs[-2] + 1:]
-                else:
-                    new_text += (text[point_indexs[-2] + 1:end_index] + "。" +
-                                 text[end_index + 1:])
-            else:
-                new_text += "."
-
-        return new_text
-
-    def _split(self, text: str) -> List[str]:
-        text = re.sub(r'[《》【】<=>{}()()#&@“”^_|…\\]', '', text)
-        # replace the full stop of an English sentence "." --> "。" for later splitting
-        text = self._replace(text)
-        text = self.SENTENCE_SPLITOR.sub(r'\1\n', text)
-        text = text.strip()
-        sentences = [sentence.strip() for sentence in re.split(r'\n+', text)]
-        return sentences
-
-    def _distinguish(self, text: str) -> List[str]:
+    def get_segment(self, text: str) -> List[str]:
         # sentence --> [ch_part, en_part, ch_part, ...]
-
         segments = []
         types = []
-
         flag = 0
         temp_seg = ""
         temp_lang = ""
 
         # Determine the type of each character. type: blank, chinese, alphabet, number, unk and point.
         for ch in text:
-            if ch == ".":
-                types.append("point")
-            elif self.is_chinese(ch):
+            if self.is_chinese(ch):
                 types.append("zh")
             elif self.is_alphabet(ch):
                 types.append("en")
-            elif ch == " ":
-                types.append("blank")
-            elif self.is_number(ch):
-                types.append("num")
             else:
-                types.append("unk")
+                types.append("other")
 
         assert len(types) == len(text)
 
         for i in range(len(types)):
-
             # find the first char of the seg
             if flag == 0:
-                # the first char is Chinese, English, or a digit
-                if types[i] == "zh" or types[i] == "en" or types[i] == "num":
-                    temp_seg += text[i]
-                    temp_lang = types[i]
-                    flag = 1
+                temp_seg += text[i]
+                temp_lang = types[i]
+                flag = 1
 
             else:
-                # digits and the decimal point are merged into the preceding char and take its type
-                if types[i] == temp_lang or types[i] == "num" or types[
-                        i] == "point":
-                    temp_seg += text[i]
-
-                # a digit is concatenated with whatever char follows it
-                elif temp_lang == "num":
-                    temp_seg += text[i]
-                    if types[i] == "zh" or types[i] == "en":
+                if temp_lang == "other":
+                    if types[i] == temp_lang:
+                        temp_seg += text[i]
+                    else:
+                        temp_seg += text[i]
                         temp_lang = types[i]
 
-                # a blank is concatenated with the preceding char
-                elif types[i] == "blank":
-                    temp_seg += text[i]
-
-                elif types[i] == "unk":
-                    pass
-
                 else:
-                    segments.append((temp_seg, temp_lang))
-
-                    if types[i] == "zh" or types[i] == "en":
+                    if types[i] == temp_lang:
+                        temp_seg += text[i]
+                    elif types[i] == "other":
+                        temp_seg += text[i]
+                    else:
+                        segments.append((temp_seg, temp_lang))
                         temp_seg = text[i]
                         temp_lang = types[i]
                         flag = 1
-                    else:
-                        flag = 0
-                        temp_seg = ""
-                        temp_lang = ""
 
         segments.append((temp_seg, temp_lang))
 
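For a quick sanity check of the new grouping, here is a standalone sketch that mirrors the `get_segment` logic added above. It is not the class itself: the `tag()` stand-in approximates `is_chinese`/`is_alphabet` with the usual Unicode ranges (the exact ranges are not all shown in this hunk), and everything else is a plain-Python re-implementation of the two branches in the diff.

```python
# Standalone sketch mirroring the new get_segment() grouping:
# each char collapses to "zh" / "en" / "other", and "other" chars
# (digits, spaces, punctuation) stick to the segment they appear in.
from typing import List, Tuple


def tag(ch: str) -> str:
    # stand-ins for MixFrontend.is_chinese / is_alphabet (assumed ranges)
    if '\u4e00' <= ch <= '\u9fa5':
        return "zh"
    if '\u0041' <= ch <= '\u005a' or '\u0061' <= ch <= '\u007a':
        return "en"
    return "other"


def get_segment(text: str) -> List[Tuple[str, str]]:
    segments = []
    temp_seg, temp_lang = "", ""
    for ch in text:
        t = tag(ch)
        if not temp_seg:
            # first char opens the first segment
            temp_seg, temp_lang = ch, t
        elif temp_lang == "other":
            # an "other"-only prefix adopts the first real language it meets
            temp_seg += ch
            if t != "other":
                temp_lang = t
        elif t == temp_lang or t == "other":
            # same language, or punctuation/digit/space: keep extending
            temp_seg += ch
        else:
            # language switch: close the current segment, start a new one
            segments.append((temp_seg, temp_lang))
            temp_seg, temp_lang = ch, t
    segments.append((temp_seg, temp_lang))
    return segments


print(get_segment("我爱 speech synthesis!"))
# [('我爱 ', 'zh'), ('speech synthesis!', 'en')]
```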
@@ -241,34 +107,30 @@ def get_input_ids(self,
                       add_sp: bool=True,
                       to_tensor: bool=True) -> Dict[str, List[paddle.Tensor]]:
 
-        sentences = self._split(sentence)
+        segments = self.get_segment(sentence)
+
         phones_list = []
         result = {}
-        for text in sentences:
-            phones_seg = []
-            segments = self._distinguish(text)
-            for seg in segments:
-                content = seg[0]
-                lang = seg[1]
-                if content != '':
-                    if lang == "en":
-                        input_ids = self.en_frontend.get_input_ids(
-                            content, merge_sentences=True, to_tensor=to_tensor)
-                    else:
-                        input_ids = self.zh_frontend.get_input_ids(
-                            content,
-                            merge_sentences=True,
-                            get_tone_ids=get_tone_ids,
-                            to_tensor=to_tensor)
 
-                    phones_seg.append(input_ids["phone_ids"][0])
-                    if add_sp:
-                        phones_seg.append(self.sp_id_tensor)
-
-            if phones_seg == []:
-                phones_seg.append(self.sp_id_tensor)
-            phones = paddle.concat(phones_seg)
-            phones_list.append(phones)
+        for seg in segments:
+            content = seg[0]
+            lang = seg[1]
+            if content != '':
+                if lang == "en":
+                    input_ids = self.en_frontend.get_input_ids(
+                        content, merge_sentences=False, to_tensor=to_tensor)
+                else:
+                    input_ids = self.zh_frontend.get_input_ids(
+                        content,
+                        merge_sentences=False,
+                        get_tone_ids=get_tone_ids,
+                        to_tensor=to_tensor)
+                if add_sp:
+                    input_ids["phone_ids"][-1] = paddle.concat(
+                        [input_ids["phone_ids"][-1], self.sp_id_tensor])
+
+                for phones in input_ids["phone_ids"]:
+                    phones_list.append(phones)
 
         if merge_sentences:
             merge_list = paddle.concat(phones_list)
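To make the new assembly order easier to follow, here is a toy walk-through of what `get_input_ids` now does per segment: the matching frontend is called with `merge_sentences=False`, the "sp" (pause) id is appended to the segment's last sub-sentence when `add_sp=True`, and everything is concatenated when `merge_sentences=True`. Plain lists stand in for `paddle.Tensor`, and the `fake_*` frontends and `SP_ID` are invented for illustration only, not the PaddleSpeech API.

```python
# Toy sketch of the new phone-id assembly in get_input_ids() (assumptions noted above).
from typing import List, Tuple

SP_ID = 0  # pretend id of the "sp" (pause) phone


def fake_zh_frontend(text: str) -> List[List[int]]:
    # pretend frontend: one phone id per character, one sub-sentence
    return [[ord(c) % 100 for c in text]]


def fake_en_frontend(text: str) -> List[List[int]]:
    return [[ord(c) % 100 for c in text]]


def get_input_ids(segments: List[Tuple[str, str]],
                  add_sp: bool = True,
                  merge_sentences: bool = True) -> List[List[int]]:
    phones_list = []
    for content, lang in segments:
        if content != '':
            frontend = fake_en_frontend if lang == "en" else fake_zh_frontend
            phone_ids = frontend(content)
            if add_sp:
                # append the pause phone to the segment's last sub-sentence
                phone_ids[-1] = phone_ids[-1] + [SP_ID]
            phones_list.extend(phone_ids)
    if merge_sentences:
        # one flat sequence, with pauses marking the zh/en boundaries
        return [sum(phones_list, [])]
    return phones_list


print(get_input_ids([("我爱 ", "zh"), ("speech synthesis!", "en")]))
```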