Skip to content

Commit 6733e63

Browse files
committed
Fix multilang generation
1 parent f1a1b4c commit 6733e63

File tree

1 file changed

+24
-24
lines changed

1 file changed

+24
-24
lines changed

infer.py

+24-24
Original file line numberDiff line numberDiff line change
@@ -199,19 +199,19 @@ def infer(
199199
)
200200
emo = get_emo_(reference_audio, emotion)
201201
if skip_start:
202-
phones = phones[1:]
203-
tones = tones[1:]
204-
lang_ids = lang_ids[1:]
205-
bert = bert[:, 1:]
206-
ja_bert = ja_bert[:, 1:]
207-
en_bert = en_bert[:, 1:]
202+
phones = phones[3:]
203+
tones = tones[3:]
204+
lang_ids = lang_ids[3:]
205+
bert = bert[:, 3:]
206+
ja_bert = ja_bert[:, 3:]
207+
en_bert = en_bert[:, 3:]
208208
if skip_end:
209-
phones = phones[:-1]
210-
tones = tones[:-1]
211-
lang_ids = lang_ids[:-1]
212-
bert = bert[:, :-1]
213-
ja_bert = ja_bert[:, :-1]
214-
en_bert = en_bert[:, :-1]
209+
phones = phones[:-2]
210+
tones = tones[:-2]
211+
lang_ids = lang_ids[:-2]
212+
bert = bert[:, :-2]
213+
ja_bert = ja_bert[:, :-2]
214+
en_bert = en_bert[:, :-2]
215215
with torch.no_grad():
216216
x_tst = phones.to(device).unsqueeze(0)
217217
tones = tones.to(device).unsqueeze(0)
@@ -279,19 +279,19 @@ def infer_multilang(
279279
temp_lang_ids,
280280
) = get_text(txt, lang, hps, device)
281281
if skip_start:
282-
temp_bert = temp_bert[:, 1:]
283-
temp_ja_bert = temp_ja_bert[:, 1:]
284-
temp_en_bert = temp_en_bert[:, 1:]
285-
temp_phones = temp_phones[1:]
286-
temp_tones = temp_tones[1:]
287-
temp_lang_ids = temp_lang_ids[1:]
282+
temp_bert = temp_bert[:, 3:]
283+
temp_ja_bert = temp_ja_bert[:, 3:]
284+
temp_en_bert = temp_en_bert[:, 3:]
285+
temp_phones = temp_phones[3:]
286+
temp_tones = temp_tones[3:]
287+
temp_lang_ids = temp_lang_ids[3:]
288288
if skip_end:
289-
temp_bert = temp_bert[:, :-1]
290-
temp_ja_bert = temp_ja_bert[:, :-1]
291-
temp_en_bert = temp_en_bert[:, :-1]
292-
temp_phones = temp_phones[:-1]
293-
temp_tones = temp_tones[:-1]
294-
temp_lang_ids = temp_lang_ids[:-1]
289+
temp_bert = temp_bert[:, :-2]
290+
temp_ja_bert = temp_ja_bert[:, :-2]
291+
temp_en_bert = temp_en_bert[:, :-2]
292+
temp_phones = temp_phones[:-2]
293+
temp_tones = temp_tones[:-2]
294+
temp_lang_ids = temp_lang_ids[:-2]
295295
bert.append(temp_bert)
296296
ja_bert.append(temp_ja_bert)
297297
en_bert.append(temp_en_bert)

0 commit comments

Comments
 (0)