NOTE(review): line breaks and indentation were reconstructed from a
whitespace-collapsed copy of this patch. Context-line indentation is a best
guess, and hunk sizes/offsets were recomputed for the documentation and blank
lines added to CausalDownsample1D (so the post-image hash on the "index" line
is stale). Verify against cosyvoice/flow/decoder.py, or apply with
`git apply --ignore-whitespace`.

diff --git a/cosyvoice/flow/decoder.py b/cosyvoice/flow/decoder.py
index 420a1bfc..06cd9a47 100644
--- a/cosyvoice/flow/decoder.py
+++ b/cosyvoice/flow/decoder.py
@@ -75,7 +75,6 @@ def __init__(
                                            groups=groups, bias=bias,
                                            padding_mode=padding_mode,
                                            device=device, dtype=dtype)
-        assert stride == 1
         self.causal_padding = (kernel_size - 1, 0)
 
     def forward(self, x: torch.Tensor):
@@ -83,6 +82,40 @@ def forward(self, x: torch.Tensor):
         x = super(CausalConv1d, self).forward(x)
         return x
 
+
+class CausalDownsample1D(nn.Module):
+    """Stride-2 downsampling layer built on ``CausalConv1d``.
+
+    Wraps a kernel-size-3, stride-2 ``CausalConv1d`` that keeps the channel
+    count at ``dim``. ``ConditionalDecoder`` uses it in place of
+    ``Downsample1D`` when it is built in causal mode.
+
+    Args:
+        dim: Number of input and output channels of the convolution.
+        channel_first: If true (default), ``x`` is fed to the convolution
+            unchanged. If false, dims 1 and 2 are swapped before and after
+            the convolution (presumably for (batch, time, channels)
+            inputs; confirm against callers).
+    """
+
+    def __init__(self, dim, channel_first=True):
+        super().__init__()
+        self.channel_first = channel_first
+        self.conv = CausalConv1d(dim, dim, 3, stride=2)
+
+    def forward(self, x):
+        """Downsample ``x`` and return it in the layout it arrived in."""
+        if not self.channel_first:
+            # Move dim 2 into the channel position (dim 1) for the convolution.
+            x = x.transpose(1, 2).contiguous()
+
+        out = self.conv(x)
+
+        if not self.channel_first:
+            # Swap back so the output layout matches the input layout.
+            out = out.transpose(1, 2).contiguous()
+        return out
+
 
 class ConditionalDecoder(nn.Module):
     def __init__(
@@ -138,8 +171,8 @@ def __init__(
                 ]
             )
             downsample = (
-                Downsample1D(output_channel) if not is_last else
-                CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+                (CausalDownsample1D(output_channel) if self.causal else Downsample1D(output_channel)) if not is_last else
+                (CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1))
             )
             self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
 