Commit 28e0128

👙 Merge branch 'hifigan'

2 parents: 042abd0 + e1ff1ec

20 files changed (+1233 -19 lines)

.github/workflows/ci.yaml (+1 -1)

```diff
@@ -17,7 +17,7 @@ jobs:
       max-parallel: 10
       matrix:
         python-version: [3.7]
-        tensorflow-version: [2.3.0]
+        tensorflow-version: [2.3.1]
     steps:
       - uses: actions/checkout@master
       - uses: actions/setup-python@v1
```

.gitignore (-1)

```diff
@@ -42,5 +42,4 @@ dump_baker/
 dump_ljspeech/
 dump_kss/
 dump_libritts/
-/examples/*/*
 /notebooks/test_saved/
```

README.md (+3)

```diff
@@ -19,6 +19,7 @@
 :zany_face: TensorFlowTTS provides real-time state-of-the-art speech synthesis architectures such as Tacotron-2, Melgan, Multiband-Melgan, FastSpeech, and FastSpeech2, based on TensorFlow 2. With TensorFlow 2, we can speed up training/inference, optimize further using [fake-quantize aware](https://www.tensorflow.org/model_optimization/guide/quantization/training_comprehensive_guide) training and [pruning](https://www.tensorflow.org/model_optimization/guide/pruning/pruning_with_keras), and make TTS models run faster than real-time and deployable on mobile devices or embedded systems.
 
 ## What's new
+- 2020/11/24 **(NEW!)** Add HiFi-GAN vocoder. See [here](https://github.com/TensorSpeech/TensorFlowTTS/tree/master/examples/hifigan)
 - 2020/11/19 **(NEW!)** Add Multi-GPU gradient accumulator. See [here](https://github.com/TensorSpeech/TensorFlowTTS/pull/377)
 - 2020/08/23 Add Parallel WaveGAN tensorflow implementation. See [here](https://github.com/TensorSpeech/TensorFlowTTS/tree/master/examples/parallel_wavegan)
 - 2020/08/23 Add MBMelGAN G + ParallelWaveGAN G example. See [here](https://github.com/TensorSpeech/TensorFlowTTS/tree/master/examples/multiband_pwgan)
@@ -85,6 +86,7 @@ TensorFlowTTS currently provides the following architectures:
 4. **Multi-band MelGAN** released with the paper [Multi-band MelGAN: Faster Waveform Generation for High-Quality Text-to-Speech](https://arxiv.org/abs/2005.05106) by Geng Yang, Shan Yang, Kai Liu, Peng Fang, Wei Chen, Lei Xie.
 5. **FastSpeech2** released with the paper [FastSpeech 2: Fast and High-Quality End-to-End Text to Speech](https://arxiv.org/abs/2006.04558) by Yi Ren, Chenxu Hu, Xu Tan, Tao Qin, Sheng Zhao, Zhou Zhao, Tie-Yan Liu.
 6. **Parallel WaveGAN** released with the paper [Parallel WaveGAN: A fast waveform generation model based on generative adversarial networks with multi-resolution spectrogram](https://arxiv.org/abs/1910.11480) by Ryuichi Yamamoto, Eunwoo Song, Jae-Min Kim.
+7. **HiFi-GAN** released with the paper [HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis](https://arxiv.org/abs/2010.05646) by Jungil Kong, Jaehyeon Kim, Jaekyoung Bae.
 
 We are also implementing some techniques to improve quality and convergence speed from the following papers:
 
@@ -217,6 +219,7 @@ To know how to train model from scratch or fine-tune with other datasets/languages:
 - For Multiband-MelGAN tutorial, pls see [examples/multiband_melgan](https://github.com/tensorspeech/TensorFlowTTS/tree/master/examples/multiband_melgan)
 - For Parallel WaveGAN tutorial, pls see [examples/parallel_wavegan](https://github.com/tensorspeech/TensorFlowTTS/tree/master/examples/parallel_wavegan)
 - For Multiband-MelGAN Generator + Parallel WaveGAN Discriminator tutorial, pls see [examples/multiband_pwgan](https://github.com/tensorspeech/TensorFlowTTS/tree/master/examples/multiband_pwgan)
+- For HiFi-GAN tutorial, pls see [examples/hifigan](https://github.com/tensorspeech/TensorFlowTTS/tree/master/examples/hifigan)
 # Abstract Class Explanation
 
 ## Abstract DataLoader Tensorflow-based dataset
```

examples/hifigan/README.md (new file, +65 lines)

# HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis

Based on the script [`train_hifigan.py`](https://github.com/tensorspeech/TensorFlowTTS/tree/master/examples/hifigan/train_hifigan.py).
## Training HiFi-GAN from scratch with the LJSpeech dataset

This example shows how to train HiFi-GAN from scratch with TensorFlow 2, using a custom training loop and `tf.function`. The data used in this example is LJSpeech; you can download the dataset [here](https://keithito.com/LJ-Speech-Dataset/).

### Step 1: Create a TensorFlow-based Dataloader (tf.dataset)

First, you need to define a data loader based on the AbstractDataset class (see [`abstract_dataset.py`](https://github.com/tensorspeech/TensorFlowTTS/tree/master/tensorflow_tts/datasets/abstract_dataset.py)). In this example, the dataloader reads the dataset from a path and uses filename suffixes to tell audio files apart from mel-spectrogram files (see [`audio_mel_dataset.py`](https://github.com/tensorspeech/TensorFlowTTS/tree/master/examples/melgan/audio_mel_dataset.py)). If you already have a preprocessed version of your target dataset, you don't need this example dataloader; just use it as a reference and modify the **generator function** to fit your case. Normally, a generator function should return [audio, mel]; a minimal sketch is shown below.
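For illustration, here is a minimal, hypothetical generator wired into `tf.data.Dataset.from_generator`. The `*-wave.npy`/`*-raw-feats.npy` suffixes and the 80-bin mel shape are assumptions modeled on the preprocessing conventions; adapt them to your own dump layout.

```python
import glob
import os

import numpy as np
import tensorflow as tf

def audio_mel_generator(dump_dir):
    """Yield (audio, mel) pairs from a preprocessed dump directory.

    Assumes each utterance was saved as two .npy files sharing an id:
    <id>-wave.npy (raw waveform) and <id>-raw-feats.npy (mel-spectrogram).
    """
    for wave_path in sorted(glob.glob(os.path.join(dump_dir, "*-wave.npy"))):
        mel_path = wave_path.replace("-wave.npy", "-raw-feats.npy")
        audio = np.load(wave_path).astype(np.float32)  # shape [T]
        mel = np.load(mel_path).astype(np.float32)     # shape [frames, num_mels]
        yield audio, mel

dataset = tf.data.Dataset.from_generator(
    lambda: audio_mel_generator("./dump/train/"),
    output_types=(tf.float32, tf.float32),
    output_shapes=(tf.TensorShape([None]), tf.TensorShape([None, 80])),
)
```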
### Step 2: Training from scratch

After you re-define your dataloader, please modify the input arguments, `train_dataset` and `valid_dataset`, in [`train_hifigan.py`](https://github.com/tensorspeech/TensorFlowTTS/tree/master/examples/hifigan/train_hifigan.py). Below are example command lines for training HiFi-GAN from scratch.

First, train the generator with only the STFT loss:
```bash
CUDA_VISIBLE_DEVICES=0 python examples/hifigan/train_hifigan.py \
  --train-dir ./dump/train/ \
  --dev-dir ./dump/valid/ \
  --outdir ./examples/hifigan/exp/train.hifigan.v1/ \
  --config ./examples/hifigan/conf/hifigan.v1.yaml \
  --use-norm 1 \
  --generator_mixed_precision 1 \
  --resume ""
```
Then resume and train the generator + discriminator jointly:
```bash
CUDA_VISIBLE_DEVICES=0 python examples/hifigan/train_hifigan.py \
  --train-dir ./dump/train/ \
  --dev-dir ./dump/valid/ \
  --outdir ./examples/hifigan/exp/train.hifigan.v1/ \
  --config ./examples/hifigan/conf/hifigan.v1.yaml \
  --use-norm 1 \
  --resume ./examples/hifigan/exp/train.hifigan.v1/checkpoints/ckpt-100000
```
If you want to train on multiple GPUs, replace `CUDA_VISIBLE_DEVICES=0` with, for example, `CUDA_VISIBLE_DEVICES=0,1,2,3`. You also need to tune the `batch_size` for each GPU (in the config file) yourself to maximize performance. Note that multi-GPU is currently supported for training but not yet for decoding.

In case you want to resume training, pass the checkpoint path with `--resume`, for example:

```bash
--resume ./examples/hifigan/exp/train.hifigan.v1/checkpoints/ckpt-100000
```

If you want to fine-tune a model, pass the filename of the pretrained generator with `--pretrained`:

```bash
--pretrained ptgenerator.h5
```
**IMPORTANT NOTES**:

- When training the generator only, we enable mixed precision to speed up training.
- We don't apply mixed precision when training both the generator and the discriminator (the discriminator includes group convolutions, which become slower when mixed precision is enabled).
- The 100k in `ckpt-100000` above is the *discriminator_train_start_steps* parameter from [hifigan.v1.yaml](https://github.com/tensorspeech/TensorflowTTS/tree/master/examples/hifigan/conf/hifigan.v1.yaml).
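For reference, here is a minimal sketch of how mixed precision is typically enabled in TF 2.3; the `--generator_mixed_precision 1` flag presumably toggles something like this inside `train_hifigan.py` (illustrative, not the trainer's actual code):

```python
import tensorflow as tf

# Compute in float16 while keeping variables in float32 (TF 2.3 experimental API).
tf.keras.mixed_precision.experimental.set_policy("mixed_float16")

# Wrap the optimizer with dynamic loss scaling so small float16
# gradients don't underflow to zero.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0005)
optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
    optimizer, loss_scale="dynamic"
)
```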
## Reference

1. https://github.com/descriptinc/melgan-neurips
2. https://github.com/kan-bayashi/ParallelWaveGAN
3. https://github.com/tensorflow/addons
4. [HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis](https://arxiv.org/abs/2010.05646)
5. [MelGAN: Generative Adversarial Networks for Conditional Waveform Synthesis](https://arxiv.org/abs/1910.06711)
6. [Parallel WaveGAN: A fast waveform generation model based on generative adversarial networks with multi-resolution spectrogram](https://arxiv.org/abs/1910.11480)

examples/hifigan/conf/hifigan.v1.yaml (new file, +116 lines)

```yaml
# This is the hyperparameter configuration file for HiFi-GAN.
# Please make sure this is adjusted for the LJSpeech dataset. If you want to
# apply it to another dataset, you might need to carefully change some parameters.
# This configuration performs 4000k iterations.

###########################################################
#                FEATURE EXTRACTION SETTING               #
###########################################################
sampling_rate: 22050   # Sampling rate of dataset.
hop_size: 256          # Hop size.
format: "npy"


###########################################################
#         GENERATOR NETWORK ARCHITECTURE SETTING          #
###########################################################
model_type: "hifigan_generator"

hifigan_generator_params:
    out_channels: 1
    kernel_size: 7
    filters: 512
    use_bias: true
    upsample_scales: [8, 8, 2, 2]
    stacks: 3
    stack_kernel_size: [3, 7, 11]
    stack_dilation_rate: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
    use_final_nolinear_activation: true
    is_weight_norm: false

###########################################################
#       DISCRIMINATOR NETWORK ARCHITECTURE SETTING        #
###########################################################
hifigan_discriminator_params:
    out_channels: 1                  # Number of output channels (number of subbands).
    period_scales: [2, 3, 5, 7, 11]  # List of period scales.
    n_layers: 5                      # Number of layers in each period discriminator.
    kernel_size: 5                   # Kernel size.
    strides: 3                       # Strides.
    filters: 8                       # Initial conv filters of each period discriminator.
    filter_scales: 4                 # Filter scales.
    max_filters: 1024                # Maximum filters of the period discriminator's convs.
    is_weight_norm: false            # Use weight-norm or not.

melgan_discriminator_params:
    out_channels: 1                         # Number of output channels.
    scales: 3                               # Number of multi-scales.
    downsample_pooling: "AveragePooling1D"  # Pooling type for the input downsampling.
    downsample_pooling_params:              # Parameters of the above pooling function.
        pool_size: 4
        strides: 2
    kernel_sizes: [5, 3]                 # List of kernel sizes.
    filters: 16                          # Number of channels of the initial conv layer.
    max_downsample_filters: 1024         # Maximum number of channels of downsampling layers.
    downsample_scales: [4, 4, 4, 4]      # List of downsampling scales.
    nonlinear_activation: "LeakyReLU"    # Nonlinear activation function.
    nonlinear_activation_params:         # Parameters of the nonlinear activation function.
        alpha: 0.2
    is_weight_norm: false                # Use weight-norm or not.

###########################################################
#                   STFT LOSS SETTING                     #
###########################################################
stft_loss_params:
    fft_lengths: [1024, 2048, 512]    # List of FFT sizes for STFT-based loss.
    frame_steps: [120, 240, 50]       # List of hop sizes for STFT-based loss.
    frame_lengths: [600, 1200, 240]   # List of window lengths for STFT-based loss.

###########################################################
#                ADVERSARIAL LOSS SETTING                 #
###########################################################
lambda_feat_match: 10.0
lambda_adv: 4.0

###########################################################
#                  DATA LOADER SETTING                    #
###########################################################
batch_size: 16                # Batch size for each GPU, assuming gradient_accumulation_steps == 1.
batch_max_steps: 8192         # Length of each audio in the batch for training. Must be divisible by hop_size.
batch_max_steps_valid: 81920  # Length of each audio for validation. Must be divisible by hop_size.
remove_short_samples: true    # Whether to remove samples shorter than batch_max_steps.
allow_cache: true             # Whether to allow caching in the dataset. If true, it requires CPU memory.
is_shuffle: true              # Shuffle the dataset after each epoch.

###########################################################
#             OPTIMIZER & SCHEDULER SETTING               #
###########################################################
generator_optimizer_params:
    lr_fn: "PiecewiseConstantDecay"
    lr_params:
        boundaries: [100000, 200000, 300000, 400000, 500000, 600000, 700000]
        values: [0.0005, 0.0005, 0.00025, 0.000125, 0.0000625, 0.00003125, 0.000015625, 0.000001]
    amsgrad: false

discriminator_optimizer_params:
    lr_fn: "PiecewiseConstantDecay"
    lr_params:
        boundaries: [100000, 200000, 300000, 400000, 500000]
        values: [0.00025, 0.000125, 0.0000625, 0.00003125, 0.000015625, 0.000001]
    amsgrad: false

gradient_accumulation_steps: 1  # Should be an even number or 1.

###########################################################
#                    INTERVAL SETTING                     #
###########################################################
discriminator_train_start_steps: 100000  # Step at which discriminator training begins.
train_max_steps: 4000000                 # Number of training steps.
save_interval_steps: 20000               # Interval steps to save checkpoint.
eval_interval_steps: 5000                # Interval steps to evaluate the network.
log_interval_steps: 200                  # Interval steps to record the training log.

###########################################################
#                     OTHER SETTING                       #
###########################################################
num_save_intermediate_results: 1  # Number of batches to be saved as intermediate results.
```
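The `lr_fn`/`lr_params` pair maps onto `tf.keras.optimizers.schedules.PiecewiseConstantDecay` (note that `values` must contain exactly one more entry than `boundaries`, as it does here). Below is a small sketch of reading this config and building the generator's schedule; the variable names are illustrative, not the trainer's actual wiring:

```python
import tensorflow as tf
import yaml

# Load the training config (path assumed relative to the repo root).
with open("examples/hifigan/conf/hifigan.v1.yaml") as f:
    config = yaml.safe_load(f)

# PiecewiseConstantDecay holds each value until its boundary step is passed:
# lr = values[0] for step <= boundaries[0], values[-1] for step > boundaries[-1].
gen_params = config["generator_optimizer_params"]
lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=gen_params["lr_params"]["boundaries"],
    values=gen_params["lr_params"]["values"],
)
gen_optimizer = tf.keras.optimizers.Adam(
    learning_rate=lr_schedule, amsgrad=gen_params["amsgrad"]
)

print(lr_schedule(0).numpy())       # 0.0005  (before the first boundary)
print(lr_schedule(250000).numpy())  # 0.00025 (between the 200k and 300k boundaries)
```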

examples/hifigan/conf/hifigan.v2.yaml (new file, +116 lines)

```yaml
# This is the hyperparameter configuration file for HiFi-GAN.
# Please make sure this is adjusted for the LJSpeech dataset. If you want to
# apply it to another dataset, you might need to carefully change some parameters.
# This configuration performs 4000k iterations.

###########################################################
#                FEATURE EXTRACTION SETTING               #
###########################################################
sampling_rate: 22050   # Sampling rate of dataset.
hop_size: 256          # Hop size.
format: "npy"


###########################################################
#         GENERATOR NETWORK ARCHITECTURE SETTING          #
###########################################################
model_type: "hifigan_generator"

hifigan_generator_params:
    out_channels: 1
    kernel_size: 7
    filters: 128
    use_bias: true
    upsample_scales: [8, 8, 2, 2]
    stacks: 3
    stack_kernel_size: [3, 7, 11]
    stack_dilation_rate: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
    use_final_nolinear_activation: true
    is_weight_norm: false

###########################################################
#       DISCRIMINATOR NETWORK ARCHITECTURE SETTING        #
###########################################################
hifigan_discriminator_params:
    out_channels: 1                  # Number of output channels (number of subbands).
    period_scales: [2, 3, 5, 7, 11]  # List of period scales.
    n_layers: 5                      # Number of layers in each period discriminator.
    kernel_size: 5                   # Kernel size.
    strides: 3                       # Strides.
    filters: 8                       # Initial conv filters of each period discriminator.
    filter_scales: 4                 # Filter scales.
    max_filters: 512                 # Maximum filters of the period discriminator's convs.
    is_weight_norm: false            # Use weight-norm or not.

melgan_discriminator_params:
    out_channels: 1                         # Number of output channels.
    scales: 3                               # Number of multi-scales.
    downsample_pooling: "AveragePooling1D"  # Pooling type for the input downsampling.
    downsample_pooling_params:              # Parameters of the above pooling function.
        pool_size: 4
        strides: 2
    kernel_sizes: [5, 3]                 # List of kernel sizes.
    filters: 16                          # Number of channels of the initial conv layer.
    max_downsample_filters: 512          # Maximum number of channels of downsampling layers.
    downsample_scales: [4, 4, 4, 4]      # List of downsampling scales.
    nonlinear_activation: "LeakyReLU"    # Nonlinear activation function.
    nonlinear_activation_params:         # Parameters of the nonlinear activation function.
        alpha: 0.2
    is_weight_norm: false                # Use weight-norm or not.

###########################################################
#                   STFT LOSS SETTING                     #
###########################################################
stft_loss_params:
    fft_lengths: [1024, 2048, 512]    # List of FFT sizes for STFT-based loss.
    frame_steps: [120, 240, 50]       # List of hop sizes for STFT-based loss.
    frame_lengths: [600, 1200, 240]   # List of window lengths for STFT-based loss.

###########################################################
#                ADVERSARIAL LOSS SETTING                 #
###########################################################
lambda_feat_match: 10.0
lambda_adv: 4.0

###########################################################
#                  DATA LOADER SETTING                    #
###########################################################
batch_size: 16                # Batch size for each GPU, assuming gradient_accumulation_steps == 1.
batch_max_steps: 8192         # Length of each audio in the batch for training. Must be divisible by hop_size.
batch_max_steps_valid: 81920  # Length of each audio for validation. Must be divisible by hop_size.
remove_short_samples: true    # Whether to remove samples shorter than batch_max_steps.
allow_cache: true             # Whether to allow caching in the dataset. If true, it requires CPU memory.
is_shuffle: true              # Shuffle the dataset after each epoch.

###########################################################
#             OPTIMIZER & SCHEDULER SETTING               #
###########################################################
generator_optimizer_params:
    lr_fn: "PiecewiseConstantDecay"
    lr_params:
        boundaries: [100000, 200000, 300000, 400000, 500000, 600000, 700000]
        values: [0.0005, 0.0005, 0.00025, 0.000125, 0.0000625, 0.00003125, 0.000015625, 0.000001]
    amsgrad: false

discriminator_optimizer_params:
    lr_fn: "PiecewiseConstantDecay"
    lr_params:
        boundaries: [100000, 200000, 300000, 400000, 500000]
        values: [0.00025, 0.000125, 0.0000625, 0.00003125, 0.000015625, 0.000001]
    amsgrad: false

gradient_accumulation_steps: 1  # Should be an even number or 1.

###########################################################
#                    INTERVAL SETTING                     #
###########################################################
discriminator_train_start_steps: 100000  # Step at which discriminator training begins.
train_max_steps: 4000000                 # Number of training steps.
save_interval_steps: 20000               # Interval steps to save checkpoint.
eval_interval_steps: 5000                # Interval steps to evaluate the network.
log_interval_steps: 200                  # Interval steps to record the training log.

###########################################################
#                     OTHER SETTING                       #
###########################################################
num_save_intermediate_results: 1  # Number of batches to be saved as intermediate results.
```
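Compared with v1, this v2 config only shrinks the model: the generator uses `filters: 128` instead of 512, and the discriminators cap at 512 filters instead of 1024, giving a lighter, faster variant.

Both configs share the same `stft_loss_params`, whose `(fft_length, frame_step, frame_length)` triples define a multi-resolution STFT loss. Below is a rough sketch of what such a loss computes, assuming the spectral-convergence plus log-magnitude formulation from the Parallel WaveGAN paper (reference 6 in the example README); the repo's actual loss implementation may differ in detail:

```python
import tensorflow as tf

def stft_magnitude(x, frame_length, frame_step, fft_length):
    """Magnitude spectrogram of a batch of waveforms, shape [batch, time]."""
    return tf.abs(tf.signal.stft(
        x, frame_length=frame_length, frame_step=frame_step, fft_length=fft_length))

def multi_resolution_stft_loss(y, y_hat, params):
    """Average spectral-convergence + log-magnitude loss over all resolutions."""
    loss = 0.0
    resolutions = zip(
        params["frame_lengths"], params["frame_steps"], params["fft_lengths"])
    for frame_length, frame_step, fft_length in resolutions:
        mag = stft_magnitude(y, frame_length, frame_step, fft_length)
        mag_hat = stft_magnitude(y_hat, frame_length, frame_step, fft_length)
        # Spectral convergence: relative Frobenius error of the magnitudes.
        sc_loss = tf.norm(mag - mag_hat) / (tf.norm(mag) + 1e-6)
        # L1 distance between log-magnitude spectrograms.
        log_mag_loss = tf.reduce_mean(
            tf.abs(tf.math.log(mag + 1e-6) - tf.math.log(mag_hat + 1e-6)))
        loss += sc_loss + log_mag_loss
    return loss / len(params["fft_lengths"])
```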
