-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTrain.py
More file actions
163 lines (126 loc) · 5.01 KB
/
Copy pathTrain.py
File metadata and controls
163 lines (126 loc) · 5.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
"""
Model Training Module - Fine-tuning SmolLM3 with Object-Oriented Architecture
"""
# Import required libraries for fine-tuning
from transformers import AutoModelForCausalLM, AutoTokenizer
from config import config, training_config
from trl import SFTTrainer
from datasets import load_dataset
import torch
class ModelLoader:
    """Load the pre-trained causal language model and its tokenizer.

    Nothing is loaded at construction time; ``model`` and ``tokenizer`` stay
    ``None`` until the corresponding ``load_*`` method has been called.
    """

    def __init__(self, config):
        """Store the configuration without loading anything.

        Args:
            config: Configuration object. This class reads ``model_name``,
                ``model_dtype``, ``device_map`` and ``trust_remote_code``.
        """
        self.config = config
        self.model = None
        self.tokenizer = None

    @staticmethod
    def _resolve_dtype(dtype_spec):
        """Translate a configured dtype into a ``torch.dtype``.

        Args:
            dtype_spec: Either a ``torch.dtype`` or the name of one as a
                string (``"float16"``, ``"bfloat16"``, ``"float32"``, ...).

        Returns:
            The matching ``torch.dtype``. Anything that is not recognised
            falls back to ``torch.float32``, which is what earlier versions
            of this loader used for every value other than ``"float16"``.
        """
        if isinstance(dtype_spec, torch.dtype):
            return dtype_spec
        if isinstance(dtype_spec, str):
            candidate = getattr(torch, dtype_spec, None)
            # torch exposes many non-dtype attributes (e.g. ``torch.nn``);
            # only accept the lookup when it really is a dtype.
            if isinstance(candidate, torch.dtype):
                return candidate
        return torch.float32

    def load_model(self):
        """Load the pre-trained model described by the configuration.

        Returns:
            The loaded model, also stored on ``self.model``.
        """
        print(f"Loading {self.config.model_name}...")

        # Honour whatever dtype the configuration names instead of silently
        # downgrading everything except "float16" to float32.
        dtype = self._resolve_dtype(self.config.model_dtype)

        self.model = AutoModelForCausalLM.from_pretrained(
            self.config.model_name,
            dtype=dtype,
            device_map=self.config.device_map,
            trust_remote_code=self.config.trust_remote_code,
        )
        print(f"Model loaded! Parameters: {self.model.num_parameters():,}")
        return self.model

    def load_tokenizer(self):
        """Load the tokenizer and configure its padding.

        Returns:
            The tokenizer, also stored on ``self.tokenizer``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
        # Reuse the end-of-sequence token as the padding token.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        # Pad on the right: this tokenizer is prepared for fine-tuning here,
        # not for batched generation.
        self.tokenizer.padding_side = "right"
        return self.tokenizer

    def load_all(self):
        """Load the model first, then the tokenizer.

        Returns:
            A ``(model, tokenizer)`` tuple.
        """
        self.load_model()
        self.load_tokenizer()
        return self.model, self.tokenizer
class DatasetLoader:
    """Load a dataset and carve out the subset used for training.

    ``dataset`` holds the full loaded dataset and ``train_dataset`` the
    selected training subset; both are ``None`` until populated.
    """

    def __init__(self, config):
        """Store the configuration without loading anything.

        Args:
            config: Configuration object. This class reads ``dataset_name``,
                ``dataset_split``, ``training_split`` and
                ``dataset_subset_size``.
        """
        self.config = config
        self.dataset = None
        self.train_dataset = None

    def load_dataset(self):
        """Load the configured dataset via ``datasets.load_dataset``.

        Returns:
            The loaded dataset, also stored on ``self.dataset``.
        """
        print("=== PREPARING DATASET ===\n")

        # NOTE(review): the second positional argument of
        # ``datasets.load_dataset`` is the dataset configuration name, not a
        # split, so ``config.dataset_split`` presumably holds a configuration
        # name - confirm against the config module.
        self.dataset = load_dataset(
            self.config.dataset_name,
            self.config.dataset_split,
        )
        return self.dataset

    def prepare_training_split(self):
        """Select the first rows of the training split as the training set.

        The dataset is loaded on demand if it has not been loaded yet. At most
        ``config.dataset_subset_size`` rows are taken; when the split is
        smaller than that, every available row is used instead of raising an
        ``IndexError``. A subset size of ``None`` means "use the whole split".

        Returns:
            The selected training subset, also stored on
            ``self.train_dataset``.
        """
        if self.dataset is None:
            self.load_dataset()

        split = self.dataset[self.config.training_split]
        available = len(split)
        requested = self.config.dataset_subset_size

        # Never ask ``select`` for indices beyond the end of the split.
        row_count = available if requested is None else min(requested, available)

        self.train_dataset = split.select(range(row_count))
        print(f"Training dataset prepared with {len(self.train_dataset)} examples")
        return self.train_dataset
class ModelTrainer:
    """Drive supervised fine-tuning of a model with TRL's ``SFTTrainer``."""

    def __init__(self, model, train_dataset, training_config, tokenizer=None):
        """Store everything needed to build the trainer later.

        Args:
            model: The model to train.
            train_dataset: The prepared training dataset.
            training_config: Training configuration passed to ``SFTTrainer``
                as ``args``; its ``output_dir`` is reported after saving.
            tokenizer: Optional tokenizer to hand to ``SFTTrainer``. When
                omitted, the trainer is built without one, exactly as before
                this parameter existed.
        """
        self.model = model
        self.train_dataset = train_dataset
        self.training_config = training_config
        self.tokenizer = tokenizer
        self.trainer = None

    def setup_trainer(self):
        """Create the ``SFTTrainer`` and store it on ``self.trainer``.

        Returns:
            The newly created trainer.
        """
        trainer_kwargs = {
            "model": self.model,
            "args": self.training_config,
            "train_dataset": self.train_dataset,
        }
        # Only forward the tokenizer when one was supplied, so the call is
        # unchanged for callers that never pass it.
        if self.tokenizer is not None:
            trainer_kwargs["processing_class"] = self.tokenizer

        self.trainer = SFTTrainer(**trainer_kwargs)
        return self.trainer

    def train(self):
        """Run training, creating the trainer first if necessary."""
        if self.trainer is None:
            self.setup_trainer()

        print("\n=== STARTING TRAINING ===")
        self.trainer.train()

    def save_model(self):
        """Save the trained model through the trainer.

        Raises:
            RuntimeError: If no trainer has been set up yet.
        """
        if self.trainer is None:
            raise RuntimeError("Trainer not initialized. Cannot save model.")

        self.trainer.save_model()
        print(f"Model saved to {self.training_config.output_dir}")

    def train_and_save(self):
        """Run the full pipeline: build the trainer, train, then save."""
        self.setup_trainer()
        self.train()
        self.save_model()


def main():
    """Run the whole fine-tuning pipeline: load, prepare data, train, save."""
    # Load the model together with its configured tokenizer.
    model_loader = ModelLoader(config)
    model, tokenizer = model_loader.load_all()

    # Load the dataset and select the training subset.
    dataset_loader = DatasetLoader(config)
    train_dataset = dataset_loader.prepare_training_split()

    # Pass the tokenizer along so the padding configured by ModelLoader is
    # actually used during training instead of being discarded.
    trainer = ModelTrainer(
        model,
        train_dataset,
        training_config,
        tokenizer=tokenizer,
    )
    trainer.train_and_save()
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()