@@ -199,19 +199,19 @@ def infer(
199
199
)
200
200
emo = get_emo_ (reference_audio , emotion )
201
201
if skip_start :
202
- phones = phones [1 :]
203
- tones = tones [1 :]
204
- lang_ids = lang_ids [1 :]
205
- bert = bert [:, 1 :]
206
- ja_bert = ja_bert [:, 1 :]
207
- en_bert = en_bert [:, 1 :]
202
+ phones = phones [3 :]
203
+ tones = tones [3 :]
204
+ lang_ids = lang_ids [3 :]
205
+ bert = bert [:, 3 :]
206
+ ja_bert = ja_bert [:, 3 :]
207
+ en_bert = en_bert [:, 3 :]
208
208
if skip_end :
209
- phones = phones [:- 1 ]
210
- tones = tones [:- 1 ]
211
- lang_ids = lang_ids [:- 1 ]
212
- bert = bert [:, :- 1 ]
213
- ja_bert = ja_bert [:, :- 1 ]
214
- en_bert = en_bert [:, :- 1 ]
209
+ phones = phones [:- 2 ]
210
+ tones = tones [:- 2 ]
211
+ lang_ids = lang_ids [:- 2 ]
212
+ bert = bert [:, :- 2 ]
213
+ ja_bert = ja_bert [:, :- 2 ]
214
+ en_bert = en_bert [:, :- 2 ]
215
215
with torch .no_grad ():
216
216
x_tst = phones .to (device ).unsqueeze (0 )
217
217
tones = tones .to (device ).unsqueeze (0 )
@@ -279,19 +279,19 @@ def infer_multilang(
279
279
temp_lang_ids ,
280
280
) = get_text (txt , lang , hps , device )
281
281
if skip_start :
282
- temp_bert = temp_bert [:, 1 :]
283
- temp_ja_bert = temp_ja_bert [:, 1 :]
284
- temp_en_bert = temp_en_bert [:, 1 :]
285
- temp_phones = temp_phones [1 :]
286
- temp_tones = temp_tones [1 :]
287
- temp_lang_ids = temp_lang_ids [1 :]
282
+ temp_bert = temp_bert [:, 3 :]
283
+ temp_ja_bert = temp_ja_bert [:, 3 :]
284
+ temp_en_bert = temp_en_bert [:, 3 :]
285
+ temp_phones = temp_phones [3 :]
286
+ temp_tones = temp_tones [3 :]
287
+ temp_lang_ids = temp_lang_ids [3 :]
288
288
if skip_end :
289
- temp_bert = temp_bert [:, :- 1 ]
290
- temp_ja_bert = temp_ja_bert [:, :- 1 ]
291
- temp_en_bert = temp_en_bert [:, :- 1 ]
292
- temp_phones = temp_phones [:- 1 ]
293
- temp_tones = temp_tones [:- 1 ]
294
- temp_lang_ids = temp_lang_ids [:- 1 ]
289
+ temp_bert = temp_bert [:, :- 2 ]
290
+ temp_ja_bert = temp_ja_bert [:, :- 2 ]
291
+ temp_en_bert = temp_en_bert [:, :- 2 ]
292
+ temp_phones = temp_phones [:- 2 ]
293
+ temp_tones = temp_tones [:- 2 ]
294
+ temp_lang_ids = temp_lang_ids [:- 2 ]
295
295
bert .append (temp_bert )
296
296
ja_bert .append (temp_ja_bert )
297
297
en_bert .append (temp_en_bert )
0 commit comments