# ################################
# Model: Whisper (Encoder-Decoder) + NLL
# Augmentation: TimeDomainSpecAugment
# Authors: Adel Moumen 2022 & 2024, Titouan Parcollet 2022
# ################################
# Seed needs to be set at top of yaml, before objects with parameters are made
seed: 1986
__set_seed: !apply:speechbrain.utils.seed_everything [!ref <seed>]
output_folder: !ref results/whisper/<seed>
output_wer_folder: !ref <output_folder>/
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt
# HuggingFace hub ID of the English Whisper model to fine-tune.
whisper_hub: openai/whisper-medium.en
whisper_folder: !ref <save_folder>/whisper_checkpoint
# Normalize the English transcripts with
# the same text normalization used in the Whisper paper
normalized_transcripts: True
# Data files
data_folder: !PLACEHOLDER # e.g., /path/to/LibriSpeech
train_splits: ["train-clean-100", "train-clean-360", "train-other-500"]
dev_splits: ["dev-clean"]
test_splits: ["dev-clean", "test-clean", "test-other"]
skip_prep: False
train_csv: !ref <output_folder>/train.csv
valid_csv: !ref <output_folder>/dev-clean.csv
test_csv:
    - !ref <output_folder>/dev-clean.csv
    - !ref <output_folder>/test-clean.csv
    - !ref <output_folder>/test-other.csv
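# Note: the CSV manifests above are created on the first run by the recipe's
# LibriSpeech preparation step (librispeech_prepare.py in the upstream recipe);
# set skip_prep: True to reuse manifests that already exist.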
ckpt_interval_minutes: 10 # save checkpoint every N min
############################## Training Parameters #############################
freeze_encoder: True
number_of_epochs: 1
weight_decay: 0.01
lr_whisper: 1e-5
warmup_steps: 500
max_grad_norm: 2.0
sorting: ascending
precision: fp16 # bf16, fp16 or fp32
eval_precision: fp16
sampling_rate: 16_000
# With data_parallel batch_size is split into N jobs
# With DDP batch_size is multiplied by N jobs
# This setup works well with 1x 32GB GPU
batch_size: 16
test_batch_size: 16
grad_accumulation_factor: 1
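# Effective batch per optimizer step = batch_size * grad_accumulation_factor
# (here 16 * 1 = 16 utterances on a single GPU; multiply by the number of
# processes when launching with DDP).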
# Decoding parameters
min_decode_ratio: 0.0
max_decode_ratio: 1.0
test_beam_size: 8
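# In SpeechBrain's seq2seq searchers, the decode ratios bound the number of
# emitted tokens as a fraction of the encoder output length (0.0 to 1.0 here).
# Validation below uses greedy decoding for speed; testing uses beam search
# with beam_size 8.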
####################### Model Parameters #######################################
train_loader_kwargs:
    batch_size: !ref <batch_size>
valid_loader_kwargs:
    batch_size: !ref <test_batch_size>
test_loader_kwargs:
    batch_size: !ref <test_batch_size>
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>
############################## Augmentations ###################################
# UNCOMMENT THIS SECTION TO ADD AUGMENTATIONS
# speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
#     orig_freq: !ref <sampling_rate>
#     speeds: [95, 100, 105]
# # Frequency drop: randomly drops a number of frequency bands to zero.
# drop_freq: !new:speechbrain.augment.time_domain.DropFreq
#     drop_freq_low: 0 # Lowest frequency that can be dropped (fraction of the Nyquist frequency)
#     drop_freq_high: 1 # Highest frequency that can be dropped (fraction of the Nyquist frequency)
#     drop_freq_count_low: 1 # Min number of frequency bands to drop
#     drop_freq_count_high: 3 # Max number of frequency bands to drop
#     drop_freq_width: 0.05 # Width of the dropped frequency bands
# # Time drop: randomly drops a number of temporal chunks.
# drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
#     drop_length_low: 1000 # Min length (in samples) of each dropped chunk
#     drop_length_high: 2000 # Max length (in samples) of each dropped chunk
#     drop_count_low: 1 # Min number of chunks to drop
#     drop_count_high: 5 # Max number of chunks to drop
# # Augmenter: Combines previously defined augmentations to perform data augmentation
# wav_augment: !new:speechbrain.augment.augmenter.Augmenter
#     concat_original: True
#     min_augmentations: 3
#     max_augmentations: 3
#     augment_prob: 1.0
#     augmentations: [
#         !ref <speed_perturb>,
#         !ref <drop_freq>,
#         !ref <drop_chunk>]
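# When uncommented, the Augmenter applies all three augmentations to every
# batch (min_augmentations = max_augmentations = 3) and, because
# concat_original is True, concatenates the clean waveforms with the augmented
# ones, enlarging the effective training batch.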
############################## Models ##########################################
whisper: !new:speechbrain.lobes.models.huggingface_transformers.whisper.Whisper
    source: !ref <whisper_hub>
    freeze_encoder: !ref <freeze_encoder>
    save_path: !ref <whisper_folder>
    language: "english"
    task: "transcribe"
    sampling_rate: !ref <sampling_rate>
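# With freeze_encoder: True the Whisper encoder weights stay fixed and only the
# decoder is fine-tuned, which noticeably reduces the memory and compute needed
# per training step.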
log_softmax: !new:speechbrain.nnet.activations.Softmax
    apply_log: True
nll_loss: !name:speechbrain.nnet.losses.nll_loss
modules:
    whisper: !ref <whisper>
############################## Decoding & optimiser ############################
whisper_opt_class: !name:torch.optim.AdamW
    lr: !ref <lr_whisper>
    weight_decay: !ref <weight_decay>
valid_search: !new:speechbrain.decoders.seq2seq.S2SWhisperGreedySearcher
    model: !ref <whisper>
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>
test_search: !new:speechbrain.decoders.seq2seq.S2SWhisperBeamSearcher
    module: [!ref <whisper>]
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>
    beam_size: !ref <test_beam_size>
lr_annealing_whisper: !new:speechbrain.nnet.schedulers.NoamScheduler
    lr_initial: !ref <lr_whisper>
    n_warmup_steps: !ref <warmup_steps>
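# The Noam schedule ramps the learning rate up to roughly lr_whisper over the
# first warmup_steps optimizer steps and then decays it proportionally to
# 1/sqrt(step).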
############################## Logging and Pretrainer ##########################
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        whisper: !ref <whisper>
        scheduler_whisper: !ref <lr_annealing_whisper>
        counter: !ref <epoch_counter>
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>
error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
    split_tokens: True
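# Typical usage, assuming the recipe's accompanying training script (named
# train_with_whisper.py in the upstream LibriSpeech recipe):
#   python train_with_whisper.py hparams/train_hf_whisper.yaml --data_folder=/path/to/LibriSpeech
# Multi-GPU training with DDP:
#   torchrun --nproc_per_node=4 train_with_whisper.py hparams/train_hf_whisper.yaml --data_folder=/path/to/LibriSpeech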