-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvalidate.py
More file actions
134 lines (112 loc) · 6.52 KB
/
validate.py
File metadata and controls
134 lines (112 loc) · 6.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field, field_validator, model_validator
import argparse
class TrainingConfig(BaseModel):
    """Validated configuration for a fine-tuning job.

    Unknown keys are rejected (``extra="forbid"``).  Besides per-field type
    checks, the validators on this class enforce:

    * ``training_file`` is a string or a list of strings; list entries must
      exist on disk, and a string that is not an existing path must start
      with ``"preference"`` when ``loss`` is ``"dpo"`` or ``"orpo"``;
    * ``finetuned_model_id`` has the form ``"<org>/<model>"`` and ``<org>``
      is not one of a few reserved names;
    * a numeric ``learning_rate`` is strictly positive;
    * ``lora_dropout`` lies in ``[0, 1]``;
    * ``optim`` and ``lr_scheduler_type`` come from fixed allow-lists.
    """

    # pydantic v2 configuration (replaces the deprecated nested ``class Config``).
    # Prevent extra fields not defined in the model.
    model_config = {"extra": "forbid"}

    # Required model and data paths
    model: str = Field(..., description="Hugging Face model ID")
    training_file: str | list[str] = Field(..., description="File ID of the training dataset")
    test_file: Optional[str] = Field(None, description="File ID of the test dataset")

    # Output model
    finetuned_model_id: str = Field(
        '{org_id}/{model_name}-{job_id}',
        description="File ID of the finetuned model",
    )

    # Model configuration
    max_seq_length: int = Field(2048, description="Maximum sequence length for training")
    load_in_4bit: bool = Field(False, description="Whether to load model in 4-bit quantization")

    # Training type configuration
    loss: Literal["dpo", "orpo", "sft"] = Field(..., description="Loss function / training type")

    # PEFT configuration
    is_peft: bool = Field(True, description="Whether to use PEFT for training")
    target_modules: Optional[List[str]] = Field(
        default=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        description="Target modules for LoRA"
    )
    lora_bias: Literal["all", "none"] = Field(
        "none",
        description="Value for FastLanguageModel.get_peft_model(bias=?)",
    )

    # LoRA specific arguments
    r: int = Field(16, description="LoRA attention dimension")
    lora_alpha: int = Field(16, description="LoRA alpha parameter")
    lora_dropout: float = Field(0.0, description="LoRA dropout rate")
    use_rslora: bool = Field(True, description="Whether to use RSLoRA")
    merge_before_push: bool = Field(
        True,
        description="Whether to merge model before pushing to Hub. Only merged models can be used as parent models for further finetunes. Only supported for bf16 models.",
    )
    push_to_private: bool = Field(True, description="Whether to push to private Hub")

    # Training hyperparameters
    epochs: int = Field(1, description="Number of training epochs")
    max_steps: Optional[int] = Field(None, description="Maximum number of training steps")
    per_device_train_batch_size: int = Field(2, description="Training batch size per device")
    gradient_accumulation_steps: int = Field(8, description="Number of gradient accumulation steps")
    warmup_steps: int = Field(5, description="Number of warmup steps")
    # A string value is accepted as-is; only numeric values are range-checked
    # (see validate_learning_rate).
    learning_rate: Union[float, str] = Field(1e-4, description="Learning rate or string expression")
    logging_steps: int = Field(1, description="Number of steps between logging")
    optim: str = Field("adamw_8bit", description="Optimizer to use for training")
    weight_decay: float = Field(0.01, description="Weight decay rate")
    lr_scheduler_type: str = Field("linear", description="Learning rate scheduler type")
    seed: int = Field(3407, description="Random seed for reproducibility")
    beta: float = Field(0.1, description="Beta parameter for DPO/ORPO training")
    save_steps: int = Field(5000, description="Save checkpoint every X steps")
    save_strategy: Literal["no", "epoch", "steps"] = Field(
        "epoch",
        description="Save strategy: 'no', 'epoch', or 'steps'",
    )
    save_total_limit: Optional[int] = Field(1, description="Maximum number of checkpoints to keep")
    output_dir: str = Field("./tmp", description="Output directory for training checkpoints")
    train_on_responses_only: bool = Field(True, description="Whether to train on responses only")
    no_test_split: bool = Field(False, description="Whether to not split the test set")

    # Steering configuration
    steering_config: Optional[Dict] = Field(
        None,
        description="Steering configuration for projection intervention",
    )
    enable_steering_during_training: bool = Field(
        False,
        description="Whether to enable steering during training",
    )

    @model_validator(mode="before")
    @classmethod
    def validate_training_file_prefixes(cls, values):
        """Check ``training_file`` against the filesystem and the chosen loss.

        Runs on the raw input before field validation.

        * A string naming an existing path is accepted as-is.
        * A list is accepted when every entry is a string naming an existing
          path.
        * Any other string must start with ``"preference"`` when ``loss`` is
          ``"dpo"`` or ``"orpo"`` (``"orpo"`` is assumed when ``loss`` is
          absent from the input).

        Raises:
            ValueError: if ``training_file`` is neither a string nor a list
                of strings, if a listed file does not exist, or if the
                prefix rule is violated.
        """
        # "before" validators receive whatever was passed to the model; only
        # mapping input can be inspected here, so let pydantic handle the rest.
        if not isinstance(values, dict):
            return values

        loss = values.get('loss', 'orpo')
        training_file = values.get('training_file')

        if isinstance(training_file, str):
            if os.path.exists(training_file):
                return values
        elif isinstance(training_file, list):
            for file in training_file:
                # Guard against non-string entries: os.path.exists would raise
                # TypeError (or probe a file descriptor for ints).
                if not isinstance(file, str):
                    raise ValueError(f"Training file must be a string or a list of strings, got: {training_file}")
                if not os.path.exists(file):
                    raise ValueError(f"Training file {file} does not exist")
            # Every listed file exists locally, so - as for a single existing
            # path - the prefix rule below does not apply.  Falling through
            # would call ``list.startswith`` and raise AttributeError.
            return values
        else:
            raise ValueError(f"Training file must be a string or a list of strings, got: {training_file}")

        # if loss == 'sft' and not training_file.startswith('conversations'):
        #     raise ValueError(f"For SFT training, dataset filename must start with 'conversations', got: {training_file}")
        if loss in ('dpo', 'orpo') and not training_file.startswith('preference'):
            raise ValueError(f"For DPO/ORPO training, dataset filename must start with 'preference', got: {training_file}")
        return values

    @field_validator("finetuned_model_id")
    @classmethod
    def validate_finetuned_model_id(cls, v):
        """Require ``"<org>/<model>"`` form and reject reserved org names.

        Raises:
            ValueError: if ``v`` does not contain exactly one ``"/"`` or the
                part before it is a reserved name.
        """
        # if v and model_exists(v):
        #     raise ValueError(f"Model {v} already exists")
        parts = v.split("/")
        if len(parts) != 2:
            raise ValueError("Model ID must be in the format 'user/model'")
        org = parts[0]
        if org in ["datasets", "models", "unsloth", "None"]:
            raise ValueError(f"You have set org={org}, but it must be an org you have access to")
        return v

    @field_validator("learning_rate", mode="before")
    @classmethod
    def validate_learning_rate(cls, v):
        """Reject non-positive numeric learning rates.

        Runs before type coercion, so integers such as ``0`` or ``-1`` are
        checked as well as floats.  Strings (and any other type) are returned
        untouched for the field's own type validation.

        Raises:
            ValueError: if ``v`` is an int or float that is ``<= 0``.
        """
        # bool is a subclass of int; leave it to pydantic's type validation.
        is_number = isinstance(v, (int, float)) and not isinstance(v, bool)
        if is_number and v <= 0:
            raise ValueError("Learning rate must be positive")
        return v

    @field_validator("lora_dropout")
    @classmethod
    def validate_dropout(cls, v):
        """Require ``0 <= lora_dropout <= 1``.

        Raises:
            ValueError: if ``v`` is outside that closed interval.
        """
        if not 0 <= v <= 1:
            raise ValueError("Dropout rate must be between 0 and 1")
        return v

    @field_validator("optim")
    @classmethod
    def validate_optimizer(cls, v):
        """Require ``optim`` to be one of the supported optimizer names.

        Raises:
            ValueError: if ``v`` is not in the allow-list.
        """
        allowed_optimizers = ["adamw_8bit", "adamw", "adam", "sgd"]
        if v not in allowed_optimizers:
            raise ValueError(f"Optimizer must be one of {allowed_optimizers}")
        return v

    @field_validator("lr_scheduler_type")
    @classmethod
    def validate_scheduler(cls, v):
        """Require ``lr_scheduler_type`` to be one of the supported schedulers.

        Raises:
            ValueError: if ``v`` is not in the allow-list.
        """
        allowed_schedulers = ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
        if v not in allowed_schedulers:
            raise ValueError(f"Scheduler must be one of {allowed_schedulers}")
        return v