# ########################################
# Model: LLAMA2-chat + NLL
# Authors:
# Pooneh Mousavi 2023
# ########################################
# Seed needs to be set at top of yaml, before objects with parameters are made
seed: 1995
__set_seed: !apply:speechbrain.utils.seed_everything [!ref <seed>]
# Dataset will be downloaded to the `data_folder`
data_folder: !PLACEHOLDER
output_folder: !ref results/train_with_llama2/<seed>
replacements_path: ../mapping.pair
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt
bleu_4_test_file: !ref <output_folder>/bleu_4_test.txt
bleu_4_valid_file: !ref <output_folder>/bleu_4_valid.txt
# HuggingFace hub identifier of the LLAMA2-chat model
model_hub: meta-llama/Llama-2-7b-chat-hf
llama2_folder: !ref <save_folder>/llama2_checkpoint
# Path where data manifest files will be stored
train_annotation: !ref <output_folder>/train.json
valid_annotation: !ref <output_folder>/dev.json
test_annotation: !ref <output_folder>/test.json
skip_prep: False
# The train logger writes training statistics to a file, as well as stdout.
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>

ckpt_interval_minutes: 30 # save checkpoint every N min
# History window, i.e. how many previous user-system exchanges are considered as context.
max_history: 2
ignore_index: -100
label_smoothing: 0
####################### Training Parameters ####################################
number_of_epochs: 4
batch_size: 1
test_batch_size: 1
lr: 2e-4
# Whether to freeze the LLAMA2 model parameters
freeze_model: False
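
# Generation (decoding) parameters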
num_beams: 3
max_new_tokens: 50
top_k: 45
top_p: 0.9
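
# Dataloader options for training and evaluation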
train_dataloader_options:
    batch_size: !ref <batch_size>
    shuffle: True
    num_workers: 2
    drop_last: False

test_dataloader_options:
    batch_size: !ref <test_batch_size>
    shuffle: True
    num_workers: 2
    drop_last: True
# Masks
padding_mask: !name:speechbrain.lobes.models.transformer.Transformer.get_key_padding_mask

# LLAMA2 model
llama2_model: !new:speechbrain.lobes.models.huggingface_transformers.llama2.LLAMA2
    source: !ref <model_hub>
    freeze: !ref <freeze_model>
    save_path: !ref <llama2_folder>
    max_new_tokens: !ref <max_new_tokens>
    num_beams: !ref <num_beams>
    top_k: !ref <top_k>
    top_p: !ref <top_p>
    with_peft: True
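
# Epoch counter keeping track of the number of completed epochs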
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>
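
# Modules exposed to the Brain class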
modules:
    llama2_model: !ref <llama2_model>

model: !new:torch.nn.ModuleList
    - [!ref <llama2_model>]
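
# Cross-entropy (NLL) loss; targets equal to `ignore_index` (padding) are excluded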
ce_loss: !new:torch.nn.CrossEntropyLoss
    ignore_index: !ref <ignore_index>
    label_smoothing: !ref <label_smoothing>
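
# Paged 32-bit Adam optimizer from bitsandbytes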
opt_class: !name:bitsandbytes.optim.PagedAdam32bit
    lr: !ref <lr>
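
# New-Bob scheduler: anneals the learning rate by `annealing_factor`
# when the improvement falls below `improvement_threshold`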
lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr>
    improvement_threshold: 0.0025
    annealing_factor: 0.9
    patient: 0
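
# Checkpointer saving and recovering the model, the LR scheduler, and the epoch counter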
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        llama2_model: !ref <llama2_model>
        lr_annealing_output: !ref <lr_annealing>
        counter: !ref <epoch_counter>
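
# BLEU-4 and BLEU-2 statistics for evaluating the generated responses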
bleu_4_computer: !name:speechbrain.utils.bleu.BLEUStats
    max_ngram_order: 4

bleu_2_computer: !name:speechbrain.utils.bleu.BLEUStats
    max_ngram_order: 2