Skip to content

Commit e6e095f

Browse files
Update models.py
1 parent a01562a commit e6e095f

File tree

1 file changed

+21
-17
lines changed

1 file changed

+21
-17
lines changed

models.py

Lines changed: 21 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -951,26 +951,24 @@ def forward(
951951
g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
952952
else:
953953
g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
954-
x, m_p, logs_p, x_mask = self.enc_p(
955-
x, x_lengths, tone, language, bert, ja_bert, en_bert, g=g
956-
)
957-
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
958-
z_p = self.flow(z, y_mask, g=g)
954+
z_p_text, m_p_text, logs_p_text, h_text, x_mask = self.enc_p(x, x_lengths, g=g)
955+
z_q_audio, m_q_audio, logs_q_audio, y_mask = self.enc_q(y, y_lengths, g=g)
956+
z_q_dur, m_q_dur, logs_q_dur = self.flow(z_q_audio, m_q_audio, logs_q_audio, y_mask, g=g)
959957

960958
with torch.no_grad():
961959
# negative cross-entropy
962-
s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
960+
s_p_sq_r = torch.exp(-2 * logs_p_text) # [b, d, t]
963961
neg_cent1 = torch.sum(
964-
-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
962+
-0.5 * math.log(2 * math.pi) - logs_p_text, [1], keepdim=True
965963
) # [b, 1, t_s]
966964
neg_cent2 = torch.matmul(
967-
-0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
965+
-0.5 * (z_q_dur**2).transpose(1, 2), s_p_sq_r
968966
) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
969967
neg_cent3 = torch.matmul(
970-
z_p.transpose(1, 2), (m_p * s_p_sq_r)
968+
z_p.transpose(1, 2), (m_p_text * s_p_sq_r)
971969
) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
972970
neg_cent4 = torch.sum(
973-
-0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True
971+
-0.5 * (m_p_text**2) * s_p_sq_r, [1], keepdim=True
974972
) # [b, 1, t_s]
975973
neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
976974
if self.use_noise_scaled_mas:
@@ -989,13 +987,16 @@ def forward(
989987
)
990988

991989
w = attn.sum(2)
990+
attn_inv = attn.squeeze(1) * (1 / (w + 1e-9))
991+
m_q_text = torch.matmul(attn_inv.mT, m_q_dur.mT).mT
992+
logs_q_text = torch.matmul(attn_inv.mT, logs_q_dur.mT).mT
992993

993-
l_length_sdp = self.sdp(x, x_mask, w, g=g)
994+
l_length_sdp = self.sdp(h_text, x_mask, w, g=g)
994995
l_length_sdp = l_length_sdp / torch.sum(x_mask)
995996

996997
logw_ = torch.log(w + 1e-6) * x_mask
997-
logw = self.dp(x, x_mask, g=g)
998-
logw_sdp = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=1.0)
998+
logw = self.dp(h_text, x_mask, g=g)
999+
logw_sdp = self.sdp(h_text, x_mask, g=g, reverse=True, noise_scale=1.0)
9991000
l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
10001001
x_mask
10011002
) # for averaging
@@ -1004,8 +1005,10 @@ def forward(
10041005
l_length = l_length_dp + l_length_sdp
10051006

10061007
# expand prior
1007-
m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
1008-
logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
1008+
m_p_dur = torch.matmul(attn.squeeze(1), m_p_text.mT).mT
1009+
logs_p_dur = torch.matmul(attn.squeeze(1), logs_p_text.mT).mT
1010+
z_p_dur = m_p_dur + torch.randn_like(m_p_dur) * torch.exp(logs_p_dur) * y_mask
1011+
z_p_audio, m_p_audio, logs_p_audio = self.flow(z_p_dur, m_p_dur, logs_p_dur, y_mask, g=g, reverse=True)
10091012

10101013
z_slice, ids_slice = commons.rand_slice_segments(
10111014
z, y_lengths, self.segment_size
@@ -1018,8 +1021,9 @@ def forward(
10181021
ids_slice,
10191022
x_mask,
10201023
y_mask,
1021-
(z, z_p, m_p, logs_p, m_q, logs_q),
1022-
(x, logw, logw_, logw_sdp),
1024+
(m_p_text, logs_p_text),
1025+
(m_p_dur, logs_p_dur, z_q_dur, logs_q_dur),
1026+
(m_p_audio, logs_p_audio, m_q_audio, logs_q_audio),
10231027
g,
10241028
)
10251029

0 commit comments

Comments (0)