Skip to content

Commit a36635d

Browse files
committed
figured out the quality issue!
1 parent 1305f57 commit a36635d

File tree

3 files changed

+14
-8
lines changed

3 files changed

+14
-8
lines changed

torch_pitch_shift/main.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,9 @@ def pitch_shift(
110110
input: torch.Tensor,
111111
shift: Union[float, Fraction],
112112
sample_rate: int,
113-
n_fft: Optional[int] = 0,
114113
bins_per_octave: Optional[int] = 12,
114+
n_fft: Optional[int] = 0,
115+
hop_length: Optional[int] = 0,
115116
) -> torch.Tensor:
116117
"""
117118
Shift the pitch of a batch of waveforms by a given amount.
@@ -125,30 +126,35 @@ def pitch_shift(
125126
`Fraction`: A `fractions.Fraction` object indicating the shift ratio. Usually an element in `get_fast_shifts()`.
126127
sample_rate: int
127128
The sample rate of the input audio clips.
128-
n_fft: int [optional]
129-
Size of FFT. Default is `sample_rate // 64`. Smaller is faster.
130129
bins_per_octave: int [optional]
131130
Number of bins per octave. Default is 12.
131+
n_fft: int [optional]
132+
Size of FFT. Default is `sample_rate // 64`.
133+
hop_length: int [optional]
134+
Size of hop length. Default is `n_fft // 32`.
132135
133136
Returns
134137
-------
135138
output: torch.Tensor [shape=(batch_size, channels, samples)]
136139
The pitch-shifted batch of audio clips
137140
"""
141+
138142
if not n_fft:
139143
n_fft = sample_rate // 64
144+
if not hop_length:
145+
hop_length = n_fft // 32
140146
batch_size, channels, samples = input.shape
141147
if not isinstance(shift, Fraction):
142148
shift = 2.0 ** (float(shift) / bins_per_octave)
143149
resampler = T.Resample(sample_rate, int(sample_rate / shift)).to(input.device)
144150
output = input
145151
output = output.reshape(batch_size * channels, samples)
146-
output = torch.stft(output, n_fft)[None, ...]
147-
stretcher = T.TimeStretch(fixed_rate=float(1 / shift), n_freq=output.shape[2]).to(
148-
input.device
149-
)
152+
output = torch.stft(output, n_fft, hop_length)[None, ...]
153+
stretcher = T.TimeStretch(
154+
fixed_rate=float(1 / shift), n_freq=output.shape[2], hop_length=hop_length
155+
).to(input.device)
150156
output = stretcher(output)
151-
output = torch.istft(output[0], n_fft)
157+
output = torch.istft(output[0], n_fft, hop_length)
152158
output = resampler(output)
153159
del resampler, stretcher
154160
if output.shape[1] >= input.shape[2]:

wavs/test.wav

-7.38 MB
Binary file not shown.

wavs/test_unused.wav

1.53 MB
Binary file not shown.

0 commit comments

Comments
 (0)