forked from speechbrain/speechbrain
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_dar_with_wav2vec.yaml
212 lines (176 loc) · 6.41 KB
/
train_dar_with_wav2vec.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
# ################################
# Model: wav2vec2 + DNN + CTC
# Augmentation: time-domain (speed perturbation, frequency drop, chunk drop)
# Authors: Abdou Mohamed Naira 2022
# ################################
# Seed needs to be set at top of yaml, before objects with parameters are made
seed: 1234
# Calls speechbrain.utils.seed_everything with the seed above (`!apply` tag).
__set_seed: !apply:speechbrain.utils.seed_everything [!ref <seed>]
# Every output path below is derived from this per-seed results folder.
output_folder: !ref results/wav2vec2_ctc_DAR/<seed>
test_wer_file: !ref <output_folder>/wer_test.txt
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt
# Identifier of the pretrained XLSR wav2vec2 model; it is passed as `source`
# to the `wav2vec2` object defined further down in this file.
wav2vec2_hub: facebook/wav2vec2-large-xlsr-53
# Folder passed to the `wav2vec2` object as its `save_path`.
wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
# Data files
# Root folder of the dataset; the three *_csv_file paths below are derived
# from it. It has no usable default, so it is declared as a !PLACEHOLDER that
# must be overridden (e.g. `--data_folder /dataset/`). Leaving the key empty
# would make it null and the derived paths would silently become invalid.
data_folder: !PLACEHOLDER  # e.g., /dataset/
train_csv_file: !ref <data_folder>/texts/train.csv
dev_csv_file: !ref <data_folder>/texts/dev.csv
test_csv_file: !ref <data_folder>/texts/test.csv
accented_letters: True
language: darija
# CSV paths located inside <save_folder>.
train_csv: !ref <save_folder>/train.csv
valid_csv: !ref <save_folder>/dev.csv
test_csv: !ref <save_folder>/test.csv
skip_prep: False  # Skip data preparation
# We remove utterances longer than this many seconds (15 s here) from the
# train/dev/test sets, as longer sentences certainly correspond to
# "open microphones".
avoid_if_longer_than: 15.0
####################### Training Parameters ####################################
# Upper bound given to `epoch_counter` as its `limit`.
number_of_epochs: 30
# Initial learning rate referenced by `model_opt_class` and `lr_annealing_model`.
lr: 1.0
# Initial learning rate referenced by `wav2vec_opt_class` and `lr_annealing_wav2vec`.
lr_wav2vec: 0.0001
sorting: ascending
precision: fp32 # bf16, fp16 or fp32
# Also used as `orig_freq` of the speed perturbation augmentation.
sample_rate: 16000
ckpt_interval_minutes: 30 # save checkpoint every N min
# With data_parallel batch_size is split into N jobs
# With DDP batch_size is multiplied by N jobs
# NOTE(review): a previous note here said "Must be 6 per GPU to fit 16GB of
# VRAM", but the configured value is 4 -- confirm which one is intended.
batch_size: 4
test_batch_size: 4

# Dataloader settings tied to <batch_size>. The nested keys must be indented
# under their parent; at column 0 they would become duplicate top-level keys.
dataloader_options:
    batch_size: !ref <batch_size>
    num_workers: 2

# Dataloader settings tied to <test_batch_size>.
test_dataloader_options:
    batch_size: !ref <test_batch_size>
    num_workers: 2
# Tokenizer parameters (character-level tokens are selected here)
token_type: char # ["unigram", "bpe", "char"]
character_coverage: 1.0
####################### Model Parameters #######################################
# Size of the last dimension of `enc.input_shape`.
wav2vec_output_dim: 1024
# Width of each linear layer in `enc` and `input_size` of `ctc_lin`.
dnn_neurons: 1024
# Forwarded to the `freeze` argument of the `wav2vec2` object.
freeze_wav2vec: False
# Outputs
output_neurons: 36 # number of output units of ctc_lin (vocabulary size); blank is index 0
# Decoding parameters
# Be sure that the bos and eos index match with the tokenizer's ones
blank_index: 0
bos_index: 1
eos_index: 2
#
# Functions and classes
#
# Epoch counter bounded by <number_of_epochs>; `limit` is an argument of the
# EpochCounter object and therefore has to be nested under it.
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>
############################## Augmentations ###################################
# Speed perturbation
# NOTE(review): `speeds` are presumably percentages of the original speed
# (95 %, 100 %, 105 %) -- confirm against the SpeedPerturb documentation.
speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
    orig_freq: !ref <sample_rate>
    speeds: [95, 100, 105]
# Frequency drop: randomly drops a number of frequency bands to zero.
# All five settings are constructor arguments of DropFreq and are nested
# under it accordingly.
drop_freq: !new:speechbrain.augment.time_domain.DropFreq
    drop_freq_low: 0
    drop_freq_high: 1
    drop_freq_count_low: 1
    drop_freq_count_high: 3
    drop_freq_width: 0.05
# Time drop: randomly drops a number of temporal chunks.
# The length and count bounds are constructor arguments of DropChunk and are
# nested under it accordingly.
drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
    drop_length_low: 1000
    drop_length_high: 2000
    drop_count_low: 1
    drop_count_high: 5
# Augmenter: Combines previously defined augmentations to perform data augmentation
# Only three augmentations are listed below, so the min/max counts are set to
# 3 to match (the previous value of 4 exceeded the length of the list).
wav_augment: !new:speechbrain.augment.augmenter.Augmenter
    concat_original: True
    min_augmentations: 3
    max_augmentations: 3
    augment_prob: 1.0
    augmentations: [
        !ref <speed_perturb>,
        !ref <drop_freq>,
        !ref <drop_chunk>]
############################## Models ##########################################
# DNN encoder placed on top of the wav2vec2 features: three
# Linear -> BatchNorm1d -> LeakyReLU stages, with Dropout (p=0.15) after the
# first two. Layer arguments are nested under the layer they belong to.
enc: !new:speechbrain.nnet.containers.Sequential
    input_shape: [null, null, !ref <wav2vec_output_dim>]
    linear1: !name:speechbrain.nnet.linear.Linear
        n_neurons: !ref <dnn_neurons>
        bias: True
    bn1: !name:speechbrain.nnet.normalization.BatchNorm1d
    activation: !new:torch.nn.LeakyReLU
    drop: !new:torch.nn.Dropout
        p: 0.15
    linear2: !name:speechbrain.nnet.linear.Linear
        n_neurons: !ref <dnn_neurons>
        bias: True
    bn2: !name:speechbrain.nnet.normalization.BatchNorm1d
    activation2: !new:torch.nn.LeakyReLU
    drop2: !new:torch.nn.Dropout
        p: 0.15
    linear3: !name:speechbrain.nnet.linear.Linear
        n_neurons: !ref <dnn_neurons>
        bias: True
    bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
    activation3: !new:torch.nn.LeakyReLU
# wav2vec2 encoder loaded from <wav2vec2_hub>; whether it is frozen is
# controlled by <freeze_wav2vec>, and <wav2vec2_folder> is its save path.
wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
    source: !ref <wav2vec2_hub>
    output_norm: True
    freeze: !ref <freeze_wav2vec>
    save_path: !ref <wav2vec2_folder>
#####
# Uncomment this block if you prefer to use a Fairseq pretrained model instead
# of a HuggingFace one. Here, we provide a URL that is obtained from the
# Fairseq github for the multilingual XLSR.
#
#wav2vec2_url: https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr_53_56k.pt
#wav2vec2: !new:speechbrain.lobes.models.fairseq_wav2vec.FairseqWav2Vec2
#    pretrained_path: !ref <wav2vec2_url>
#    output_norm: True
#    freeze: False
#    save_path: !ref <save_folder>/wav2vec2_checkpoint/model.pt
#####
# Output projection from the encoder width (<dnn_neurons>) to the
# <output_neurons> output units.
ctc_lin: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <dnn_neurons>
    n_neurons: !ref <output_neurons>
# Softmax activation configured with apply_log enabled.
log_softmax: !new:speechbrain.nnet.activations.Softmax
    apply_log: True
# CTC loss function with the blank index pre-bound to <blank_index>.
ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
    blank_index: !ref <blank_index>
# Mapping of module names to the objects defined in this file. The entries
# must be nested under `modules`; at column 0 they would instead redefine the
# top-level `wav2vec2`, `enc` and `ctc_lin` keys.
modules:
    wav2vec2: !ref <wav2vec2>
    enc: !ref <enc>
    ctc_lin: !ref <ctc_lin>
# ModuleList grouping the encoder and the CTC projection (wav2vec2 is kept
# separate and is registered on its own in the checkpointer).
model: !new:torch.nn.ModuleList
    - [!ref <enc>, !ref <ctc_lin>]
# Adadelta optimizer factory using the learning rate <lr>.
model_opt_class: !name:torch.optim.Adadelta
    lr: !ref <lr>
    rho: 0.95
    eps: 1.e-8
# Adam optimizer factory using the learning rate <lr_wav2vec>.
wav2vec_opt_class: !name:torch.optim.Adam
    lr: !ref <lr_wav2vec>
# NewBob learning-rate scheduler starting from <lr>, with annealing factor 0.8.
lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr>
    improvement_threshold: 0.0025
    annealing_factor: 0.8
    patient: 0
# NewBob learning-rate scheduler starting from <lr_wav2vec>, with annealing
# factor 0.9.
lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr_wav2vec>
    improvement_threshold: 0.0025
    annealing_factor: 0.9
    patient: 0
# Checkpointer writing to <save_folder>. `recoverables` lists the objects
# registered with it: both networks, both schedulers and the epoch counter.
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        wav2vec2: !ref <wav2vec2>
        model: !ref <model>
        scheduler_model: !ref <lr_annealing_model>
        scheduler_wav2vec: !ref <lr_annealing_wav2vec>
        counter: !ref <epoch_counter>
# File-based training logger writing to <train_log>.
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>
# ErrorRateStats factory with default arguments (no token splitting requested).
error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
# ErrorRateStats factory with split_tokens enabled (presumably yielding a
# character-level error rate, as the key name suggests -- confirm).
cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
    split_tokens: True