@@ -951,26 +951,24 @@ def forward(
951951 g = self .emb_g (sid ).unsqueeze (- 1 ) # [b, h, 1]
952952 else :
953953 g = self .ref_enc (y .transpose (1 , 2 )).unsqueeze (- 1 )
954- x , m_p , logs_p , x_mask = self .enc_p (
955- x , x_lengths , tone , language , bert , ja_bert , en_bert , g = g
956- )
957- z , m_q , logs_q , y_mask = self .enc_q (y , y_lengths , g = g )
958- z_p = self .flow (z , y_mask , g = g )
954+ z_p_text , m_p_text , logs_p_text , h_text , x_mask = self .enc_p (x , x_lengths , g = g )
955+ z_q_audio , m_q_audio , logs_q_audio , y_mask = self .enc_q (y , y_lengths , g = g )
956+ z_q_dur , m_q_dur , logs_q_dur = self .flow (z_q_audio , m_q_audio , logs_q_audio , y_mask , g = g )
959957
960958 with torch .no_grad ():
961959 # negative cross-entropy
962- s_p_sq_r = torch .exp (- 2 * logs_p ) # [b, d, t]
960+ s_p_sq_r = torch .exp (- 2 * logs_p_text ) # [b, d, t]
963961 neg_cent1 = torch .sum (
964- - 0.5 * math .log (2 * math .pi ) - logs_p , [1 ], keepdim = True
962+ - 0.5 * math .log (2 * math .pi ) - logs_p_text , [1 ], keepdim = True
965963 ) # [b, 1, t_s]
966964 neg_cent2 = torch .matmul (
967- - 0.5 * (z_p ** 2 ).transpose (1 , 2 ), s_p_sq_r
965+ - 0.5 * (z_q_dur ** 2 ).transpose (1 , 2 ), s_p_sq_r
968966 ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
969967 neg_cent3 = torch .matmul (
970- z_p .transpose (1 , 2 ), (m_p * s_p_sq_r )
968 z_q_dur .transpose (1 , 2 ), (m_p_text * s_p_sq_r )
971969 ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
972970 neg_cent4 = torch .sum (
973- - 0.5 * (m_p ** 2 ) * s_p_sq_r , [1 ], keepdim = True
971+ - 0.5 * (m_p_text ** 2 ) * s_p_sq_r , [1 ], keepdim = True
974972 ) # [b, 1, t_s]
975973 neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
976974 if self .use_noise_scaled_mas :
@@ -989,13 +987,16 @@ def forward(
989987 )
990988
991989 w = attn .sum (2 )
990+ attn_inv = attn .squeeze (1 ) * (1 / (w + 1e-9 ))
991+ m_q_text = torch .matmul (attn_inv .mT , m_q_dur .mT ).mT
992+ logs_q_text = torch .matmul (attn_inv .mT , logs_q_dur .mT ).mT
992993
993- l_length_sdp = self .sdp (x , x_mask , w , g = g )
994+ l_length_sdp = self .sdp (h_text , x_mask , w , g = g )
994995 l_length_sdp = l_length_sdp / torch .sum (x_mask )
995996
996997 logw_ = torch .log (w + 1e-6 ) * x_mask
997- logw = self .dp (x , x_mask , g = g )
998- logw_sdp = self .sdp (x , x_mask , g = g , reverse = True , noise_scale = 1.0 )
998+ logw = self .dp (h_text , x_mask , g = g )
999+ logw_sdp = self .sdp (h_text , x_mask , g = g , reverse = True , noise_scale = 1.0 )
9991000 l_length_dp = torch .sum ((logw - logw_ ) ** 2 , [1 , 2 ]) / torch .sum (
10001001 x_mask
10011002 ) # for averaging
@@ -1004,8 +1005,10 @@ def forward(
10041005 l_length = l_length_dp + l_length_sdp
10051006
10061007 # expand prior
1007- m_p = torch .matmul (attn .squeeze (1 ), m_p .transpose (1 , 2 )).transpose (1 , 2 )
1008- logs_p = torch .matmul (attn .squeeze (1 ), logs_p .transpose (1 , 2 )).transpose (1 , 2 )
1008+ m_p_dur = torch .matmul (attn .squeeze (1 ), m_p_text .mT ).mT
1009+ logs_p_dur = torch .matmul (attn .squeeze (1 ), logs_p_text .mT ).mT
1010+ z_p_dur = m_p_dur + torch .randn_like (m_p_dur ) * torch .exp (logs_p_dur ) * y_mask
1011+ z_p_audio , m_p_audio , logs_p_audio = self .flow (z_p_dur , m_p_dur , logs_p_dur , y_mask , g = g , reverse = True )
10091012
10101013 z_slice , ids_slice = commons .rand_slice_segments (
10111014 z_q_audio , y_lengths , self .segment_size
@@ -1018,8 +1021,9 @@ def forward(
10181021 ids_slice ,
10191022 x_mask ,
10201023 y_mask ,
1021- (z , z_p , m_p , logs_p , m_q , logs_q ),
1022- (x , logw , logw_ , logw_sdp ),
1024+ (m_p_text , logs_p_text ),
1025+ (m_p_dur , logs_p_dur , z_q_dur , logs_q_dur ),
1026+ (m_p_audio , logs_p_audio , m_q_audio , logs_q_audio ),
10231027 g ,
10241028 )
10251029
0 commit comments