Skip to content

Commit 522826d

Browse files
authored
Merge pull request #41 from keithito/tf-griffin-lim
Add TensorFlow implementation of Griffin-Lim
2 parents ab5dae1 + 2460969 commit 522826d

File tree

4 files changed

+83
-29
lines changed

4 files changed

+83
-29
lines changed

README.md

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,23 @@ Pull requests are welcome!
2828
## Quick Start
2929

3030
### Installing dependencies
31-
Make sure you have installed Python 3 and [TensorFlow](https://www.tensorflow.org/install/). Then:
32-
```
33-
pip install -r requirements.txt
34-
```
31+
32+
1. Install Python 3.
33+
34+
2. Install [TensorFlow 1.3](https://www.tensorflow.org/install/). Install with GPU support if it's
35+
available for your platform.
36+
37+
3. Install requirements:
38+
```
39+
pip install -r requirements.txt
40+
```
3541

3642

3743
### Using a pre-trained model
3844

3945
1. **Download and unpack a model**:
4046
```
41-
curl http://data.keithito.com/data/speech/tacotron-20170720.tar.bz2 | tar xj -C /tmp
47+
curl http://data.keithito.com/data/speech/tacotron-20170720.tar.bz2 | tar xjC /tmp
4248
```
4349

4450
2. **Run the demo server**:

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1+
# Note: this doesn't include tensorflow or tensorflow-gpu because the package you need to install
2+
# depends on your platform. It is assumed you have already installed tensorflow.
13
falcon==1.2.0
24
inflect==0.2.5
35
librosa==0.5.1
46
matplotlib==2.0.2
57
numpy==1.13.0
68
scipy==0.19.0
7-
tensorflow==1.2.0
8-
tensorflow-gpu==1.2.0
99
tqdm==4.11.2
1010
Unidecode==0.4.20

synthesizer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def load(self, checkpoint_path, model_name='tacotron'):
1515
with tf.variable_scope('model') as scope:
1616
self.model = create_model(model_name, hparams)
1717
self.model.initialize(inputs, input_lengths)
18+
self.wav_output = audio.inv_spectrogram_tensorflow(self.model.linear_outputs[0])
1819

1920
print('Loading checkpoint: %s' % checkpoint_path)
2021
self.session = tf.Session()
@@ -30,7 +31,7 @@ def synthesize(self, text):
3031
self.model.inputs: [np.asarray(seq, dtype=np.int32)],
3132
self.model.input_lengths: np.asarray([len(seq)], dtype=np.int32)
3233
}
33-
spec = self.session.run(self.model.linear_outputs[0], feed_dict=feed_dict)
34+
wav = self.session.run(self.wav_output, feed_dict=feed_dict)
3435
out = io.BytesIO()
35-
audio.save_wav(audio.inv_spectrogram(spec.T), out)
36+
audio.save_wav(audio.inv_preemphasis(wav), out)
3637
return out.getvalue()

util/audio.py

Lines changed: 67 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import librosa.filters
33
import math
44
import numpy as np
5+
import tensorflow as tf
56
from scipy import signal
67
from hparams import hparams
78

@@ -15,50 +16,96 @@ def save_wav(wav, path):
1516
librosa.output.write_wav(path, wav.astype(np.int16), hparams.sample_rate)
1617

1718

19+
def preemphasis(x):
20+
return signal.lfilter([1, -hparams.preemphasis], [1], x)
21+
22+
23+
def inv_preemphasis(x):
24+
return signal.lfilter([1], [1, -hparams.preemphasis], x)
25+
26+
1827
def spectrogram(y):
19-
D = _stft(_preemphasis(y))
28+
D = _stft(preemphasis(y))
2029
S = _amp_to_db(np.abs(D)) - hparams.ref_level_db
2130
return _normalize(S)
2231

2332

2433
def inv_spectrogram(spectrogram):
34+
'''Converts spectrogram to waveform using librosa'''
2535
S = _db_to_amp(_denormalize(spectrogram) + hparams.ref_level_db) # Convert back to linear
26-
return _inv_preemphasis(_griffin_lim(S ** hparams.power)) # Reconstruct phase
36+
return inv_preemphasis(_griffin_lim(S ** hparams.power)) # Reconstruct phase
37+
38+
39+
def inv_spectrogram_tensorflow(spectrogram):
40+
'''Builds computational graph to convert spectrogram to waveform using TensorFlow.
41+
42+
Unlike inv_spectrogram, this does NOT invert the preemphasis. The caller should call
43+
inv_preemphasis on the output after running the graph.
44+
'''
45+
S = _db_to_amp_tensorflow(_denormalize_tensorflow(spectrogram) + hparams.ref_level_db)
46+
return _griffin_lim_tensorflow(tf.pow(S, hparams.power))
2747

2848

2949
def melspectrogram(y):
30-
D = _stft(_preemphasis(y))
50+
D = _stft(preemphasis(y))
3151
S = _amp_to_db(_linear_to_mel(np.abs(D)))
3252
return _normalize(S)
3353

3454

35-
def inv_melspectrogram(melspectrogram):
36-
S = _mel_to_linear(_db_to_amp(_denormalize(melspectrogram))) # Convert back to linear
37-
return _inv_preemphasis(_griffin_lim(S ** hparams.power)) # Reconstruct phase
38-
39-
40-
# Based on https://github.com/librosa/librosa/issues/434
4155
def _griffin_lim(S):
56+
'''librosa implementation of Griffin-Lim
57+
Based on https://github.com/librosa/librosa/issues/434
58+
'''
4259
angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
4360
S_complex = np.abs(S).astype(np.complex)
61+
y = _istft(S_complex * angles)
4462
for i in range(hparams.griffin_lim_iters):
45-
if i > 0:
46-
angles = np.exp(1j * np.angle(_stft(y)))
63+
angles = np.exp(1j * np.angle(_stft(y)))
4764
y = _istft(S_complex * angles)
4865
return y
4966

5067

68+
def _griffin_lim_tensorflow(S):
69+
'''TensorFlow implementation of Griffin-Lim
70+
Based on https://github.com/Kyubyong/tensorflow-exercises/blob/master/Audio_Processing.ipynb
71+
'''
72+
with tf.variable_scope('griffinlim'):
73+
# TensorFlow's stft and istft operate on a batch of spectrograms; create batch of size 1
74+
S = tf.expand_dims(S, 0)
75+
S_complex = tf.identity(tf.cast(S, dtype=tf.complex64))
76+
y = _istft_tensorflow(S_complex)
77+
for i in range(hparams.griffin_lim_iters):
78+
est = _stft_tensorflow(y)
79+
angles = est / tf.cast(tf.maximum(1e-8, tf.abs(est)), tf.complex64)
80+
y = _istft_tensorflow(S_complex * angles)
81+
return tf.squeeze(y, 0)
82+
83+
5184
def _stft(y):
52-
n_fft = (hparams.num_freq - 1) * 2
53-
hop_length = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate)
54-
win_length = int(hparams.frame_length_ms / 1000 * hparams.sample_rate)
85+
n_fft, hop_length, win_length = _stft_parameters()
5586
return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
5687

5788

5889
def _istft(y):
90+
_, hop_length, win_length = _stft_parameters()
91+
return librosa.istft(y, hop_length=hop_length, win_length=win_length)
92+
93+
94+
def _stft_tensorflow(signals):
95+
n_fft, hop_length, win_length = _stft_parameters()
96+
return tf.contrib.signal.stft(signals, win_length, hop_length, n_fft, pad_end=False)
97+
98+
99+
def _istft_tensorflow(stfts):
100+
n_fft, hop_length, win_length = _stft_parameters()
101+
return tf.contrib.signal.inverse_stft(stfts, win_length, hop_length, n_fft)
102+
103+
104+
def _stft_parameters():
105+
n_fft = (hparams.num_freq - 1) * 2
59106
hop_length = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate)
60107
win_length = int(hparams.frame_length_ms / 1000 * hparams.sample_rate)
61-
return librosa.istft(y, hop_length=hop_length, win_length=win_length)
108+
return n_fft, hop_length, win_length
62109

63110

64111
# Conversions:
@@ -88,14 +135,14 @@ def _amp_to_db(x):
88135
def _db_to_amp(x):
89136
return np.power(10.0, x * 0.05)
90137

91-
def _preemphasis(x):
92-
return signal.lfilter([1, -hparams.preemphasis], [1], x)
93-
94-
def _inv_preemphasis(x):
95-
return signal.lfilter([1], [1, -hparams.preemphasis], x)
138+
def _db_to_amp_tensorflow(x):
139+
return tf.pow(tf.ones(tf.shape(x)) * 10.0, x * 0.05)
96140

97141
def _normalize(S):
98142
return np.clip((S - hparams.min_level_db) / -hparams.min_level_db, 0, 1)
99143

100144
def _denormalize(S):
101145
return (np.clip(S, 0, 1) * -hparams.min_level_db) + hparams.min_level_db
146+
147+
def _denormalize_tensorflow(S):
148+
return (tf.clip_by_value(S, 0, 1) * -hparams.min_level_db) + hparams.min_level_db

0 commit comments

Comments
 (0)