-
Notifications
You must be signed in to change notification settings - Fork 104
Expand file tree
/
Copy pathdpo_ultrafeedback_llama3_8b.yaml
More file actions
89 lines (81 loc) · 2.73 KB
/
dpo_ultrafeedback_llama3_8b.yaml
File metadata and controls
89 lines (81 loc) · 2.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# Example DPO config for train_dpo.py using the Ultrafeedback preference dataset.
# Update the GCS paths to match your transformed preference dataset and desired cache location.
# python infra/launch.py --zone us-central1-a --tpu_name debug --tpu_type v5p-8 -- python src/levanter/main/train_dpo.py --config_path config/dpo_ultrafeedback_llama3_8b.yaml
#
# NOTE(review): indentation was lost in the copied source; the nesting below is
# reconstructed from the key names and the conventional Levanter config schema
# (data.components.<name>.{source,format,cache_dir}) — verify against
# src/levanter/main/train_dpo.py's config dataclasses before use.
---
data:
  tokenizer: marin-community/marin-tokenizer
  shuffle: true
  # Feistel permutation gives a deterministic, seed-driven shuffle.
  permutation_type: feistel
  components:
    ultrafeedback_train_prefs:
      source:
        type: url
        train_urls:
          # Replace with the output path from transform_preference_dataset_step (train_prefs split).
          - gs://marin-us-central1/preference/HuggingFaceH4--ultrafeedback_binarized-3949bf5-69e206/train_prefs/*.jsonl.gz
      format:
        type: preference_chat
        # Raise instead of truncating when an example exceeds the sequence length.
        slice_strategy: raise
      cache_dir: gs://marin-us-central1/tokenized/ultrafeedback_binarized_train_prefs_marin_tokenizer-7040eb
    ultrafeedback_test_prefs:
      source:
        type: url
        validation_urls:
          # Replace with the output path from transform_preference_dataset_step (test_prefs split).
          - gs://marin-us-central1/preference/HuggingFaceH4--ultrafeedback_binarized-3949bf5-69e206/test_prefs/*.jsonl.gz
      format:
        type: preference_chat
        slice_strategy: raise
      cache_dir: gs://marin-us-central1/tokenized/ultrafeedback_binarized_test_prefs_marin_tokenizer-29ecfb
  # test_prefs has weight 0.0: it is registered (and cached) but never sampled for training.
  train_weights:
    ultrafeedback_train_prefs: 1.0
    ultrafeedback_test_prefs: 0.0
# Llama-3-8B architecture dimensions.
model:
  type: llama
  max_seq_len: 4096
  hidden_dim: 4096
  intermediate_dim: 14336
  num_layers: 32
  num_heads: 32
  num_kv_heads: 8
  flash_attention_block_size: 512
  use_bias: false
  use_layer_norm_weight: true
  initializer_range: 0.02
  rope:
    type: "llama3"
train_seq_len: 4096
trainer:
  seed: 0
  # -1 lets the trainer pick per-device parallelism automatically.
  per_device_parallelism: -1
  per_device_eval_parallelism: -1
  tracker:
    type: wandb
    project: "dpo"
    tags: ["dpo", "ultrafeedback", "llama3", "simpo"]
  # Mixed precision: f32 params, bfloat16 compute.
  mp: p=f32,c=bfloat16
  train_batch_size: 128
  num_train_steps: 2150
  steps_per_eval: 200
  model_averaging: null
  checkpointer:
    # Time-based checkpointing interval (every 10 minutes).
    save_interval: 10m
    base_path: gs://marin-us-central1/checkpoints/dpo/ultrafeedback_llama3_8b
optimizer:
  learning_rate: 5e-7
  weight_decay: 0.0
  min_lr_ratio: 0.0
  lr_schedule: "cosine"
  # Fraction of total steps spent warming up.
  warmup: 0.1
  max_grad_norm: 1.0
# No PEFT adapter — full-parameter DPO.
adapter:
  type: none
# Frozen reference policy loaded separately from the trained model.
reference:
  type: separate
  model_path: gs://marin-us-central1/gcsfuse_mount/models/meta-llama--Llama-3-1-8B--main
  is_hf: true
# DPO temperature; 0.01 matches the SimPO-style low-beta setting tagged above.
beta: 0.01
validation_split_fraction: null
initialize_from_hf: gs://marin-us-central1/gcsfuse_mount/models/meta-llama--Llama-3-1-8B--main
use_hf_model_config: false
hf_save_steps: 1000
hf_save_path: gs://marin-us-central1/checkpoints/dpo/ultrafeedback_llama3_8b/hf