# train_dereverb.yaml (forked from speechbrain/speechbrain)
# #################################
# Basic training parameters
# To train a different model, change "!include:" statement to new model file
# To compute loss in the time domain, switch "waveform_target" to True
# Authors:
# * Szu-Wei Fu 2021
# * Peter Plantinga 2020, 2021
# #################################
# Seed needs to be set at top of yaml, before objects with parameters are made
seed: 12234
__set_seed: !apply:speechbrain.utils.seed_everything [!ref <seed>]
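# (hyperpyyaml's !apply: runs the given function at load time, so the RNGs
# should already be seeded by the time the objects below are instantiated.)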
data_folder: !PLACEHOLDER # e.g., /data/member1/user_jasonfu/noisy-vctk-16k
MetricGAN_folder: !ref <output_folder>/enhanced_wavs
output_folder: !ref ./results/<seed>
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt
enhanced_folder: !ref <output_folder>/enhanced_wavs
historical_file: !ref <output_folder>/historical.txt
# Basic parameters
use_tensorboard: False
tensorboard_logs: !ref <output_folder>/logs/
# FFT parameters
Sample_rate: 16000
Win_length: 32
Hop_length: 16
N_fft: 512
window_fn: !name:torch.hamming_window
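# Note: the STFT/ISTFT objects below take win_length and hop_length in
# milliseconds, so at a 16 kHz sample rate the 32 ms window spans 512 samples
# (matching N_fft) and the 16 ms hop gives 50% overlap.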
# Data files
train_annotation: !ref <data_folder>/train_revb.json
valid_annotation: !ref <data_folder>/valid_revb.json
test_annotation: !ref <data_folder>/test_revb.json
skip_prep: False
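# (Set skip_prep: True to skip the data-preparation step that creates the
# JSON manifests above, e.g. when a previous run has already generated them.)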
# The target metric that you want to optimize.
# Right now we only support 'dnsmos' and 'srmr'.
# (Of course, it can be any arbitrary metric.)
target_metric: srmr
calculate_dnsmos_on_validation_set: False
target_score: 1
n_jobs: 1 # Number of jobs to compute metrics (increase it for a speed up)
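# Both supported metrics are non-intrusive (computed without a clean
# reference), which is what makes the unsupervised MetricGAN-U setup
# possible: 'srmr' is the speech-to-reverberation modulation energy ratio and
# 'dnsmos' is a neural estimator of perceptual quality.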
# Training Parameters
number_of_epochs: 250
number_of_samples: 100
min_mask: 0.2
train_N_batch: 1
valid_N_batch: 20
history_portion: 0.2
G_lr: 0.000002
D_lr: 0.0005
mse_weight: 0.6
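# history_portion is the fraction of discriminator updates drawn from a pool
# of previously enhanced ("historical") samples, the replay trick used in
# MetricGAN training; note that the generator learning rate (G_lr) is kept
# orders of magnitude below the discriminator's (D_lr).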
dataloader_options:
    batch_size: !ref <train_N_batch>
valid_dataloader_options:
    batch_size: !ref <valid_N_batch>
# Change this import to use a different model
models: !include:models/MetricGAN_U.yaml
    N_fft: !ref <N_fft>
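# (hyperpyyaml loads the file named by !include: with the indented keys
# passed in as overrides, so the N_fft above is forwarded into
# models/MetricGAN_U.yaml.)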
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>
modules:
    generator: !ref <models[generator]>
    discriminator: !ref <models[discriminator]>
g_opt_class: !name:torch.optim.Adam
    lr: !ref <G_lr>
d_opt_class: !name:torch.optim.Adam
    lr: !ref <D_lr>
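# (!name: wraps a constructor without calling it; the training script is
# expected to instantiate one Adam optimizer per sub-module itself, whereas
# !new: objects, like the epoch counter above, are built at YAML load time.)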
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        generator: !ref <models[generator]>
        discriminator: !ref <models[discriminator]>
        counter: !ref <epoch_counter>
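# (Everything under recoverables is saved and restored as a unit, so an
# interrupted run can resume from the latest checkpoint with the epoch count
# intact.)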
lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <G_lr>
    annealing_factor: 1
    improvement_threshold: 0.0
    patient: 0
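# With annealing_factor: 1, this NewBob scheduler never actually lowers the
# learning rate (new_lr = old_lr * annealing_factor on no improvement);
# presumably it is kept so the training loop's scheduler hook stays in place.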
compute_cost: !name:speechbrain.nnet.losses.mse_loss
compute_si_snr: !name:speechbrain.nnet.loss.si_snr_loss.si_snr_loss
compute_STFT: !new:speechbrain.processing.features.STFT
    sample_rate: !ref <Sample_rate>
    win_length: !ref <Win_length>
    hop_length: !ref <Hop_length>
    n_fft: !ref <N_fft>
    window_fn: !ref <window_fn>
compute_ISTFT: !new:speechbrain.processing.features.ISTFT
    sample_rate: !ref <Sample_rate>
    win_length: !ref <Win_length>
    hop_length: !ref <Hop_length>
    window_fn: !ref <window_fn>
resynth: !name:speechbrain.processing.signal_processing.resynthesize
    stft: !ref <compute_STFT>
    istft: !ref <compute_ISTFT>
    normalize_wavs: False
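# (resynthesize combines the enhanced magnitude with the phase of the
# original input, obtained via compute_STFT, and inverts with compute_ISTFT,
# so the waveform is reconstructed without any phase estimation.)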
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>
# torch.Tensorboard logger (optional)
tensorboard_train_logger: !new:speechbrain.utils.train_logger.TensorboardLogger
    save_dir: !ref <tensorboard_logs>
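# Typical invocation (assuming the recipe's usual train.py entry point):
#   python train.py hparams/train_dereverb.yaml --data_folder /path/to/noisy-vctk-16k
# Any key in this file can be overridden from the command line the same way.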