# ################################
# Model: Whisper (Encoder-Decoder) + NLL
# Augmentation: TimeDomainSpecAugment
# Authors: Pooneh Mousavi 2022
# ################################
# Seed needs to be set at top of yaml, before objects with parameters are made
seed: 1986
__set_seed: !apply:speechbrain.utils.seed_everything [!ref <seed>]
output_folder: !ref results/train_whisper/<seed>/<language>
test_wer_file: !ref <output_folder>/wer_test.txt
valid_wer_file: !ref <output_folder>/wer_valid.txt
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt
# HuggingFace hub ID of the pretrained multilingual Whisper "medium" model.
whisper_hub: openai/whisper-medium
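# Other checkpoints from the same family can be swapped in here; a couple of
# common alternatives (untested with this recipe) are:
# whisper_hub: openai/whisper-small     # smaller/faster, usually higher WER
# whisper_hub: openai/whisper-large-v2  # bigger/slower, usually lower WER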
# Normalize the transcripts with the same text normalization used in the Whisper paper (https://cdn.openai.com/papers/whisper.pdf); see Appendix C for details.
normalized_transcripts: True
# Data files
language: fr # CommonVoice language code, e.g. 'it' for Italian, 'fr' for French, 'en' for English.
data_folder: !PLACEHOLDER
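# !PLACEHOLDER must be overridden when the recipe is launched, e.g. on the
# command line. A usage sketch, assuming the standard SpeechBrain entry point
# (script name and paths may differ in your checkout):
#   python train_with_whisper.py hparams/train_hf_whisper.yaml \
#       --data_folder=/path/to/CommonVoice/fr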
train_tsv_file: !ref <data_folder>/train.tsv # Standard CommonVoice .tsv files
dev_tsv_file: !ref <data_folder>/dev.tsv # Standard CommonVoice .tsv files
test_tsv_file: !ref <data_folder>/test.tsv # Standard CommonVoice .tsv files
accented_letters: True
train_csv: !ref <save_folder>/train.csv
valid_csv: !ref <save_folder>/dev.csv
test_csv: !ref <save_folder>/test.csv
skip_prep: False # Skip data preparation
# We remove utterances longer than 10s in the train/dev/test sets, as
# longer sentences certainly correspond to "open microphones".
avoid_if_longer_than: 10.0
ckpt_interval_minutes: 30 # save checkpoint every N min
####################### Training Parameters ####################################
freeze_whisper: False
freeze_encoder: True
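# With freeze_whisper: False and freeze_encoder: True, only the Whisper
# decoder is fine-tuned; the encoder keeps its pretrained weights.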
number_of_epochs: 1
weight_decay: 0.01
lr_whisper: 1e-5
warmup_steps: 500
max_grad_norm: 2.0
sorting: ascending
precision: fp16 # bf16, fp16 or fp32
eval_precision: fp16
sample_rate: 16000
# With data_parallel, batch_size is split into N jobs.
batch_size: 8
test_batch_size: 8
grad_accumulation_factor: 2
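# Note: the effective training batch size is batch_size * grad_accumulation_factor
# (8 * 2 = 16 utterances per optimizer step with the values above).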
# Decoding parameters
min_decode_ratio: 0.0
max_decode_ratio: 1.0
test_beam_size: 8
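# min/max_decode_ratio bound the hypothesis length as a fraction of the encoder
# output length (the usual SpeechBrain seq2seq searcher convention), so 0.0/1.0
# allows anything from an empty hypothesis up to one token per encoder frame.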
####################### Dataloader Options #####################################
train_loader_kwargs:
    batch_size: !ref <batch_size>
valid_loader_kwargs:
    batch_size: !ref <batch_size>
test_loader_kwargs:
    batch_size: !ref <test_batch_size>
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>
############################## Augmentations ###################################
# UNCOMMENT THIS SECTION TO ADD AUGMENTATIONS
# speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
#     orig_freq: !ref <sample_rate>
#     speeds: [95, 100, 105]
# # Frequency drop: randomly drops a number of frequency bands to zero.
# drop_freq: !new:speechbrain.augment.time_domain.DropFreq
#     drop_freq_low: 0 # Lowest droppable frequency, as a fraction of Nyquist
#     drop_freq_high: 1 # Highest droppable frequency, as a fraction of Nyquist
#     drop_freq_count_low: 1 # Min number of frequency bands to drop
#     drop_freq_count_high: 3 # Max number of frequency bands to drop
#     drop_freq_width: 0.05 # Width of the dropped frequency bands
# # Time drop: randomly drops a number of temporal chunks.
# drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
#     drop_length_low: 1000 # Min chunk length, in samples
#     drop_length_high: 2000 # Max chunk length, in samples
#     drop_count_low: 1 # Min number of chunks to drop
#     drop_count_high: 5 # Max number of chunks to drop
# # Augmenter: combines the augmentations defined above to perform data augmentation.
# wav_augment: !new:speechbrain.augment.augmenter.Augmenter
#     concat_original: True
#     min_augmentations: 3
#     max_augmentations: 3
#     augment_prob: 1.0
#     augmentations: [
#         !ref <speed_perturb>,
#         !ref <drop_freq>,
#         !ref <drop_chunk>]
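# If this section is uncommented, the training script must also apply
# <wav_augment> to the training batches for it to have any effect. With
# min_augmentations = max_augmentations = 3 and augment_prob = 1.0, all three
# augmentations run on every batch, and concat_original: True keeps the clean
# waveforms in the batch alongside the augmented copies.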
############################## Models ##########################################
whisper: !new:speechbrain.lobes.models.huggingface_transformers.whisper.Whisper
    source: !ref <whisper_hub>
    freeze: !ref <freeze_whisper>
    freeze_encoder: !ref <freeze_encoder>
    save_path: !ref <save_folder>/whisper_checkpoint
    language: !ref <language>
    task: "transcribe"
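# The language and task entries select the Whisper decoder prompt tokens
# (e.g. <|fr|> and <|transcribe|> with the defaults above), steering the
# decoder toward transcription in the configured language.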
log_softmax: !new:speechbrain.nnet.activations.Softmax
    apply_log: True
nll_loss: !name:speechbrain.nnet.losses.nll_loss
modules:
    whisper: !ref <whisper>
############################## Decoding & optimiser ############################
whisper_opt_class: !name:torch.optim.AdamW
    lr: !ref <lr_whisper>
    weight_decay: !ref <weight_decay>
valid_search: !new:speechbrain.decoders.seq2seq.S2SWhisperGreedySearcher
    model: !ref <whisper>
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>
test_search: !new:speechbrain.decoders.seq2seq.S2SWhisperBeamSearcher
    module: [!ref <whisper>]
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>
    beam_size: !ref <test_beam_size>
lr_annealing_whisper: !new:speechbrain.nnet.schedulers.NoamScheduler
    lr_initial: !ref <lr_whisper>
    n_warmup_steps: !ref <warmup_steps>
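# A sketch of the schedule, assuming SpeechBrain's NoamScheduler follows the
# Noam formula from "Attention Is All You Need" (verify against
# speechbrain.nnet.schedulers.NoamScheduler):
#   lr(step) = lr_initial * n_warmup_steps**0.5
#                         * min(step**-0.5, step * n_warmup_steps**-1.5)
# i.e. a linear warmup over warmup_steps updates, then a 1/sqrt(step) decay.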
############################## Logging and Pretrainer ##########################
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        whisper: !ref <whisper>
        scheduler_whisper: !ref <lr_annealing_whisper>
        counter: !ref <epoch_counter>
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>
error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
    split_tokens: True