@@ -110,8 +110,9 @@ def pitch_shift(
110110 input : torch .Tensor ,
111111 shift : Union [float , Fraction ],
112112 sample_rate : int ,
113- n_fft : Optional [int ] = 0 ,
114113 bins_per_octave : Optional [int ] = 12 ,
114+ n_fft : Optional [int ] = 0 ,
115+ hop_length : Optional [int ] = 0 ,
115116) -> torch .Tensor :
116117 """
117118 Shift the pitch of a batch of waveforms by a given amount.
@@ -125,30 +126,35 @@ def pitch_shift(
125126 `Fraction`: A `fractions.Fraction` object indicating the shift ratio. Usually an element in `get_fast_shifts()`.
126127 sample_rate: int
127128 The sample rate of the input audio clips.
128- n_fft: int [optional]
129- Size of FFT. Default is `sample_rate // 64`. Smaller is faster.
130129 bins_per_octave: int [optional]
131130 Number of bins per octave. Default is 12.
131+ n_fft: int [optional]
132+ Size of FFT. Default is `sample_rate // 64`.
133+ hop_length: int [optional]
134+ Size of hop length. Default is `n_fft // 32`.
132135
133136 Returns
134137 -------
135138 output: torch.Tensor [shape=(batch_size, channels, samples)]
136139 The pitch-shifted batch of audio clips
137140 """
141+
138142 if not n_fft :
139143 n_fft = sample_rate // 64
144+ if not hop_length :
145+ hop_length = n_fft // 32
140146 batch_size , channels , samples = input .shape
141147 if not isinstance (shift , Fraction ):
142148 shift = 2.0 ** (float (shift ) / bins_per_octave )
143149 resampler = T .Resample (sample_rate , int (sample_rate / shift )).to (input .device )
144150 output = input
145151 output = output .reshape (batch_size * channels , samples )
146- output = torch .stft (output , n_fft )[None , ...]
147- stretcher = T .TimeStretch (fixed_rate = float ( 1 / shift ), n_freq = output . shape [ 2 ]). to (
148- input . device
149- )
152+ output = torch .stft (output , n_fft , hop_length )[None , ...]
153+ stretcher = T .TimeStretch (
154+ fixed_rate = float ( 1 / shift ), n_freq = output . shape [ 2 ], hop_length = hop_length
155+ ). to ( input . device )
150156 output = stretcher (output )
151- output = torch .istft (output [0 ], n_fft )
157+ output = torch .istft (output [0 ], n_fft , hop_length )
152158 output = resampler (output )
153159 del resampler , stretcher
154160 if output .shape [1 ] >= input .shape [2 ]:
0 commit comments