@@ -66,7 +66,12 @@ def __init__(
         self.stride = stride
         self.dilation = dilation
         self.padding = padding
-        self.padding_mode = padding_mode
+
+        # padding_mode is forcibly set to 'constant' on the NPU device,
+        # since NPU supports only mode='constant' right now.
+        if paddle.get_device().startswith('npu'):
+            self.padding_mode = 'constant'
+        else:
+            self.padding_mode = padding_mode
 
         self.conv = nn.Conv1D(
             in_channels,
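A minimal standalone sketch of the device guard above, assuming paddle.get_device() returns a placement string such as 'cpu', 'gpu:0', or 'npu:0'; the helper name resolve_padding_mode is hypothetical and only illustrates the check:

    import paddle

    def resolve_padding_mode(padding_mode: str) -> str:
        # paddle.get_device() returns a placement string such as 'cpu',
        # 'gpu:0', or 'npu:0', so startswith('npu') matches any NPU card.
        if paddle.get_device().startswith('npu'):
            return 'constant'  # NPU supports only mode='constant' right now
        return padding_mode

    print(resolve_padding_mode('reflect'))  # 'reflect' on CPU/GPU, 'constant' on NPU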
@@ -335,10 +340,16 @@ def _compute_statistics(x, m, axis=2, eps=self.eps):
         # Apply layers
         attn = self.conv(self.tanh(self.tdnn(attn)))
 
+        if paddle.get_device().startswith('npu'):
+            # Build the -inf tensor with paddle.full_like to avoid the
+            # 'Broadcast dimension mismatch' error that occurs on the NPU
+            # device when padding_mode is set to 'constant'.
+            inf_tensor = paddle.full_like(attn, float("-inf"))
+        else:
+            # the default way
+            inf_tensor = paddle.ones_like(attn) * float("-inf")
+
         # Filter out zero-paddings
-        attn = paddle.where(
-            mask.tile((1, C, 1)) == 0,
-            paddle.ones_like(attn) * float("-inf"), attn)
+        attn = paddle.where(mask.tile((1, C, 1)) == 0, inf_tensor, attn)
 
         attn = F.softmax(attn, axis=2)
         mean, std = _compute_statistics(x, attn)
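For context, a minimal sketch of the masking step with made-up shapes: paddle.full_like writes -inf into a fresh tensor directly, while paddle.ones_like(attn) * float("-inf") reaches the same values through a broadcasted multiply, which this patch reports failing on NPU. The shapes N, C, L below are illustrative only:

    import paddle
    import paddle.nn.functional as F

    N, C, L = 2, 4, 5                 # batch, channels, frames (illustrative)
    attn = paddle.randn([N, C, L])    # stand-in for the attention scores
    mask = paddle.ones([N, 1, L])     # 1 = real frame, 0 = zero-padding
    mask[:, :, 3:] = 0                # pretend the last two frames are padding

    inf_tensor = paddle.full_like(attn, float("-inf"))  # the NPU-safe variant

    # Padded positions become -inf, so softmax assigns them zero weight.
    attn = paddle.where(mask.tile((1, C, 1)) == 0, inf_tensor, attn)
    attn = F.softmax(attn, axis=2)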