# #################################
# Basic training parameters
# To train a different model, change the "!include:" statement to a new model file
# To compute loss in the time domain, switch "waveform_target" to True
# Authors:
# * Szu-Wei Fu 2020
# * Peter Plantinga 2020, 2021
# #################################
# Seed needs to be set at top of yaml, before objects with parameters are made
seed: 17234
__set_seed: !apply:speechbrain.utils.seed_everything [!ref <seed>]
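# (seed_everything seeds the Python, NumPy, and PyTorch RNGs in one call,
# so results are reproducible across the whole pipeline.)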
data_folder: !PLACEHOLDER # e.g., /data/member1/user_jasonfu/noisy-vctk-16k
output_folder: !ref ./results/spectral_mask/<seed>
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt
enhanced_folder: !ref <output_folder>/enhanced_wavs
# Basic parameters
use_tensorboard: False
tensorboard_logs: !ref <output_folder>/logs/
# FFT parameters
Sample_rate: 16000
Win_length: 32
Hop_length: 16
N_fft: 512
window_fn: !name:torch.hamming_window
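# Note: Win_length and Hop_length are given in milliseconds. At a 16 kHz
# sample rate, the 32 ms window covers 32 * 16 = 512 samples, matching N_fft.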
# Data files
train_annotation: !ref <output_folder>/train_revb.json
valid_annotation: !ref <output_folder>/valid_revb.json
test_annotation: !ref <output_folder>/test_revb.json
skip_prep: False
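# (Set skip_prep to True if the annotation files above were already created
# by a previous run.)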
# Training Parameters
number_of_epochs: 150
N_batch: 8
lr: 0.00013
n_jobs: 1 # Number of jobs for metric evaluation (increase it for a speedup)
sorting: ascending
dataloader_options:
    batch_size: !ref <N_batch>
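# dataloader_options are forwarded to the PyTorch DataLoader, so other
# DataLoader arguments could be added here as well, e.g. (example value,
# not part of the original recipe):
#     num_workers: 4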
waveform_target: False # Switch to True to compute the loss in the time domain
# Change this import to use a different model
models: !include:models/BLSTM.yaml
    N_fft: !ref <N_fft>
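# To swap in a different architecture, point the include at another model
# definition, keeping the same override, e.g. (hypothetical filename):
# models: !include:models/MyOtherModel.yaml
#     N_fft: !ref <N_fft>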
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>
modules:
    generator: !ref <models[generator]>
g_opt_class: !name:torch.optim.Adam
    lr: !ref <lr>
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        generator: !ref <models[generator]>
        counter: !ref <epoch_counter>
compute_cost: !name:speechbrain.nnet.losses.mse_loss
# To use STOI loss, switch "waveform_target" to True
# compute_cost: !name:speechbrain.nnet.loss.stoi_loss.stoi_loss
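# (STOI is defined on time-domain signals, so the enhanced spectrum must be
# resynthesized to a waveform first; hence the waveform_target switch.)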
compute_STFT: !new:speechbrain.processing.features.STFT
    sample_rate: !ref <Sample_rate>
    win_length: !ref <Win_length>
    hop_length: !ref <Hop_length>
    n_fft: !ref <N_fft>
    window_fn: !ref <window_fn>
compute_ISTFT: !new:speechbrain.processing.features.ISTFT
    sample_rate: !ref <Sample_rate>
    win_length: !ref <Win_length>
    hop_length: !ref <Hop_length>
    window_fn: !ref <window_fn>
resynth: !name:speechbrain.processing.signal_processing.resynthesize
    stft: !ref <compute_STFT>
    istft: !ref <compute_ISTFT>
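# (resynthesize rebuilds a time-domain waveform from the enhanced magnitudes
# by reusing the phase of the input signal and applying the inverse STFT
# defined above.)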
# mean_var_norm: !new:speechbrain.processing.features.InputNormalization
#     norm_type: sentence
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>
# TensorBoard logger (optional)
tensorboard_train_logger: !new:speechbrain.utils.train_logger.TensorboardLogger
    save_dir: !ref <tensorboard_logs>
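# (This logger is typically only used by the training script when
# use_tensorboard above is set to True.)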