-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconfig.py
More file actions
107 lines (86 loc) · 4.39 KB
/
Copy pathconfig.py
File metadata and controls
107 lines (86 loc) · 4.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# Configure training parameters
import os
from trl import SFTConfig
from dotenv import load_dotenv
# Populate os.environ from a local .env file before Config reads it.
# override=False (python-dotenv's default) keeps variables that are already
# exported in the shell instead of replacing them with .env values.
load_dotenv(override=False)
class Config:
    """Configuration class to manage all training and model parameters.

    Every attribute is read from an environment variable when the instance is
    constructed, falling back to the hard-coded default next to it. Malformed
    numeric values raise ``ValueError`` naming the offending variable.
    """

    # Strings accepted as "true" for boolean variables (compared after
    # stripping whitespace and lower-casing); any other value means False.
    _TRUTHY = frozenset({"1", "true", "yes", "on"})

    def __init__(self):
        # Model Configuration
        self.model_name = os.getenv("MODEL_NAME", "HuggingFaceTB/SmolLM3-3B")
        self.new_model_name = os.getenv("NEW_MODEL_NAME", "SmolLM3-Custom-SFT")
        self.model_dtype = os.getenv("MODEL_DTYPE", "float16")
        self.device_map = os.getenv("DEVICE_MAP", "auto")
        # NOTE(review): defaults to True. It is not used in this class; presumably
        # it is passed to `from_pretrained` elsewhere, where True allows running
        # code shipped in the Hub repo. Consider defaulting to False if the model
        # is natively supported by the installed transformers version.
        self.trust_remote_code = self._env_bool("TRUST_REMOTE_CODE", "True")

        # Dataset Configuration
        self.dataset_name = os.getenv("DATASET_NAME", "HuggingFaceTB/smoltalk2")
        self.dataset_split = os.getenv("DATASET_SPLIT", "SFT")
        self.training_split = os.getenv(
            "TRAINING_SPLIT", "smoltalk_everyday_convs_reasoning_Qwen3_32B_think"
        )
        self.dataset_subset_size = self._env_number("DATASET_SUBSET_SIZE", "1000", int)
        self.dataset_text_field = os.getenv("DATASET_TEXT_FIELD", "text")
        self.max_length = self._env_number("MAX_LENGTH", "2048", int)

        # Training Hyperparameters
        self.per_device_train_batch_size = self._env_number(
            "PER_DEVICE_TRAIN_BATCH_SIZE", "2", int
        )
        self.gradient_accumulation_steps = self._env_number(
            "GRADIENT_ACCUMULATION_STEPS", "2", int
        )
        self.learning_rate = self._env_number("LEARNING_RATE", "5e-5", float)
        # Parsed as float so fractional epochs (e.g. "0.5") are accepted;
        # TrainingArguments declares num_train_epochs as a float.
        self.num_train_epochs = self._env_number("NUM_TRAIN_EPOCHS", "1", float)
        # A positive max_steps overrides num_train_epochs in TrainingArguments,
        # so with the default of 500 the epoch count has no effect.
        self.max_steps = self._env_number("MAX_STEPS", "500", int)

        # Optimization
        self.warmup_steps = self._env_number("WARMUP_STEPS", "50", int)
        self.weight_decay = self._env_number("WEIGHT_DECAY", "0.01", float)
        self.optim = os.getenv("OPTIM", "adamw_torch")

        # Logging and Saving
        self.logging_steps = self._env_number("LOGGING_STEPS", "10", int)
        self.save_steps = self._env_number("SAVE_STEPS", "100", int)
        # Only honored by TrainingArguments when eval_strategy="steps", which
        # this class does not set.
        self.eval_steps = self._env_number("EVAL_STEPS", "100", int)
        self.save_total_limit = self._env_number("SAVE_TOTAL_LIMIT", "2", int)

        # Memory Optimization
        self.dataloader_num_workers = self._env_number("DATALOADER_NUM_WORKERS", "0", int)
        self.group_by_length = self._env_bool("GROUP_BY_LENGTH", "True")

        # Hugging Face Hub Integration
        self.push_to_hub = self._env_bool("PUSH_TO_HUB", "False")
        self.hub_username = os.getenv("HUB_USERNAME", "your-username")

        # Experiment Tracking: comma-separated integration names. Whitespace is
        # stripped and empty entries are dropped, so "" yields [] (no reporting).
        self.report_to = self._env_list("REPORT_TO", "trackio")

    @staticmethod
    def _env_number(name, default, cast):
        """Read env var ``name`` (or ``default``) and convert it with ``cast``.

        Raises:
            ValueError: if the value cannot be converted; the message names the
                variable so a bad ``.env`` entry is easy to locate.
        """
        raw = os.getenv(name, default)
        try:
            return cast(raw)
        except ValueError as err:
            raise ValueError(
                f"Environment variable {name}={raw!r} is not a valid {cast.__name__}"
            ) from err

    @classmethod
    def _env_bool(cls, name, default):
        """Return True if env var ``name`` (or ``default``) is a truthy string."""
        return os.getenv(name, default).strip().lower() in cls._TRUTHY

    @staticmethod
    def _env_list(name, default):
        """Split env var ``name`` (or ``default``) on commas into trimmed, non-empty items."""
        raw = os.getenv(name, default)
        return [item.strip() for item in raw.split(",") if item.strip()]

    def get_training_config(self):
        """Create and return SFTConfig with loaded parameters."""
        return SFTConfig(
            # Model and data
            output_dir=f"./{self.new_model_name}",
            dataset_text_field=self.dataset_text_field,
            max_length=self.max_length,
            # Training hyperparameters
            per_device_train_batch_size=self.per_device_train_batch_size,
            gradient_accumulation_steps=self.gradient_accumulation_steps,
            learning_rate=self.learning_rate,
            num_train_epochs=self.num_train_epochs,
            max_steps=self.max_steps,
            # Optimization
            warmup_steps=self.warmup_steps,
            weight_decay=self.weight_decay,
            optim=self.optim,
            # Logging and saving
            logging_steps=self.logging_steps,
            save_steps=self.save_steps,
            eval_steps=self.eval_steps,
            save_total_limit=self.save_total_limit,
            # Memory optimization
            dataloader_num_workers=self.dataloader_num_workers,
            group_by_length=self.group_by_length,
            # Hugging Face Hub integration (hub_model_id is set even when
            # push_to_hub is False)
            push_to_hub=self.push_to_hub,
            hub_model_id=f"{self.hub_username}/{self.new_model_name}",
            # Experiment tracking
            report_to=self.report_to,
            run_name=f"{self.new_model_name}-training",
        )

    def print_config_summary(self):
        """Print configuration summary."""
        print("Training configuration set!")
        # Samples contributing to each optimizer step on a single device.
        effective_batch = self.per_device_train_batch_size * self.gradient_accumulation_steps
        print(f"Effective batch size: {effective_batch}")
# Create global config instance
# Runs at import time: importing this module reads the environment once,
# builds the SFTConfig and prints the summary.
# NOTE(review): presumably other modules import `config` / `training_config`
# from here; verify before moving this behind an `if __name__ == "__main__"`.
config: Config = Config()
training_config: SFTConfig = config.get_training_config()
config.print_config_summary()