|
| 1 | +import warnings |
1 | 2 | from collections import Counter |
2 | 3 | from fractions import Fraction |
3 | 4 | from functools import reduce |
4 | 5 | from itertools import chain, count, islice, repeat |
5 | | -from typing import Union, Callable, List, Optional |
6 | | -from torch.nn.functional import pad |
| 6 | +from math import log2 |
| 7 | +from typing import Callable, List, Optional, Union |
| 8 | + |
7 | 9 | import torch |
| 10 | +import torchaudio |
8 | 11 | import torchaudio.transforms as T |
| 12 | +from packaging import version |
9 | 13 | from primePy import primes |
10 | | -from math import log2 |
11 | | -import warnings |
| 14 | +from torch.nn.functional import pad |
12 | 15 |
|
13 | 16 | warnings.simplefilter("ignore") |
14 | 17 |
|
@@ -149,7 +152,8 @@ def pitch_shift( |
149 | 152 | resampler = T.Resample(sample_rate, int(sample_rate / shift)).to(input.device) |
150 | 153 | output = input |
151 | 154 | output = output.reshape(batch_size * channels, samples) |
152 | | - output = torch.stft(output, n_fft, hop_length, return_complex=True)[None, ...] |
| 155 | + v011 = version.parse(torchaudio.__version__) >= version.parse("0.11.0") |
| 156 | + output = torch.stft(output, n_fft, hop_length, return_complex=v011)[None, ...] |
153 | 157 | stretcher = T.TimeStretch( |
154 | 158 | fixed_rate=float(1 / shift), n_freq=output.shape[2], hop_length=hop_length |
155 | 159 | ).to(input.device) |
|
0 commit comments