-
Notifications
You must be signed in to change notification settings - Fork 77
Expand file tree
/
Copy pathsft_tp_ddp_gsm8k_config.yaml
More file actions
52 lines (46 loc) · 1.92 KB
/
sft_tp_ddp_gsm8k_config.yaml
File metadata and controls
52 lines (46 loc) · 1.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
# NOTE(review): indentation was lost in transit; nesting below is reconstructed
# from the stanza comments and standard PEFT/SFT config schemas. In particular,
# verify the nesting of `scheduler` and `callbacks` against the consuming
# loader's schema.
# Model configuration
model:
  model_type: "hf"  # Hugging Face model
  auto_class_name: "AutoModelForCausalLM"  # Auto class to load the model with
  model_name: "HuggingFaceTB/SmolLM-135M"  # Pretrained model name
  use_peft: true  # Enable PEFT (Parameter Efficient Fine-Tuning)
  peft_config:
    lora_r: 8  # LoRA rank
    lora_alpha: 16
    lora_dropout: 0
    target_modules: ["k_proj", "gate_proj", "q_proj", "up_proj", "v_proj", "down_proj"]  # Target modules for LoRA
    task_type: "CAUSAL_LM"  # Options: CAUSAL_LM, SEQ_2_SEQ_LM, etc.
    peft_type: "LORA"  # Options: LORA, IA3, etc.

# Dataset configuration
dataset:
  dataset_type: "sft_dataset"
  dataset_name: "openai/gsm8k"  # Dataset name from Hugging Face Hub
  prompt_template: "Solve the following math problem step by step.\n\n### Question:\n{question}\n\n### Answer:\n"  # Template to create prompt from dataset fields
  completion_template: "{answer}"  # Model will be trained on this part.
  config_name: "main"  # Config name for the dataset
  data_seed: 42  # Random seed for dataset shuffling

# Training configuration
training:
  type: "sft"
  gradient_accumulation_steps: 1  # Number of steps to accumulate gradients
  per_device_train_batch_size: 1  # Batch size per device during training
  num_train_epochs: 1
  torch_compile: false  # Whether to use torch.compile (canonical lowercase boolean)
  tp_degree: 2  # Tensor-parallel degree
  ddp_degree: 2  # Data-parallel (DDP) degree

# Optimizer configuration
optimizers:
  optimizer_name: "adamw"
  # Written as 1.0e-4 (not 1e-4): YAML 1.1 loaders such as PyYAML resolve
  # bare `1e-4` as the STRING "1e-4", not a float — the 1.1 float regex
  # requires a decimal point before the exponent.
  lr: 1.0e-4

scheduler:
  scheduler_name: "cosine"

callbacks:
  early_stopping:
    early_stopping_patience: 3  # Number of epochs to wait before stopping training
    early_stopping_threshold: 0.001  # Minimum change in metric to qualify as improvement