-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsft.py
More file actions
110 lines (99 loc) · 4.47 KB
/
sft.py
File metadata and controls
110 lines (99 loc) · 4.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
from datasets import Dataset
from transformers import TrainingArguments
from trl import SFTTrainer
from unsloth import is_bfloat16_supported
from transformers import TrainingArguments, DataCollatorForSeq2Seq
import torch
from unsloth.chat_templates import train_on_responses_only
def get_instruct_response_part(tokenizer):
    """Return the ``(instruction_part, response_part)`` markers of a chat template.

    A short example conversation is rendered with
    ``tokenizer.apply_chat_template`` and the rendered text is searched for a
    table of known turn markers. The first pair whose two markers both occur
    in the text is returned. If no known pair matches, the markers are guessed
    from the rendered text itself and a warning is printed.

    Args:
        tokenizer: Object exposing ``apply_chat_template(conversation,
            add_generation_prompt=..., tokenize=False)`` that returns the
            rendered conversation as a string.

    Returns:
        A ``(instruction_part, response_part)`` tuple of strings.

    Raises:
        ValueError: If the fallback is needed but the rendered template does
            not contain the placeholder user message.
    """
    placeholder = '<user message content>'
    prefix_conversation = [
        dict(role='user', content='ignore'),
        dict(role='assistant', content='ignore'),
    ]
    example_conversation = prefix_conversation + [
        dict(role='user', content=placeholder)
    ]
    example_text = tokenizer.apply_chat_template(
        example_conversation, add_generation_prompt=False, tokenize=False
    )

    # Order matters: the first pair found in the rendered text wins, so the
    # more specific "\n\n" header variant is listed before the "\n" one.
    known_markers = [
        ("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
        ("<|start_header_id|>user<|end_header_id|>\n", "<|start_header_id|>assistant<|end_header_id|>\n"),
        ("[INST]", "[/INST]"),
        # Fullwidth vertical bar (U+FF5C) variant; written with escapes so it
        # cannot be confused with the ASCII-bar pair on the next line.
        ("<\uff5cUser\uff5c>", "<\uff5cAssistant\uff5c>"),
        ("<|User|>", "<|Assistant|>"),
    ]
    for instruction_part, response_part in known_markers:
        if instruction_part in example_text and response_part in example_text:
            return instruction_part, response_part

    print("Warning: guessing how to train on responses only")
    # Strip the rendered two-turn prefix so only the final user turn remains;
    # whatever precedes the placeholder is taken as the instruction marker.
    prefix = tokenizer.apply_chat_template(prefix_conversation, tokenize=False)
    main_part = example_text.replace(prefix, '')
    instruction_part, found, _ = main_part.partition(placeholder)
    if not found:
        raise ValueError(
            "Cannot infer the instruction marker: the rendered chat template "
            f"does not contain the placeholder {placeholder!r}"
        )
    # The response marker is whatever the generation prompt appends to the
    # rendered conversation.
    response_part = tokenizer.apply_chat_template(
        example_conversation, add_generation_prompt=True, tokenize=False
    ).replace(example_text, '')
    return instruction_part, response_part
def sft_train(training_cfg, dataset, model, tokenizer, test_dataset, **kwargs):
    """Build (but do not run) an SFT trainer for ``model``.

    Datasets that lack a "text" column are given one by rendering their
    "messages" column with the tokenizer's chat template and appending the
    tokenizer's EOS token. The trainer is then assembled from
    ``training_cfg``; when ``training_cfg.train_on_responses_only`` is set it
    is wrapped with ``train_on_responses_only`` using the markers returned by
    ``get_instruct_response_part``.

    Args:
        training_cfg: Configuration object. Attributes read here:
            ``learning_rate``, ``max_seq_length``,
            ``per_device_train_batch_size``, ``gradient_accumulation_steps``,
            ``warmup_steps``, ``optim``, ``weight_decay``,
            ``lr_scheduler_type``, ``seed``, ``epochs``, ``save_strategy``,
            ``save_steps``, ``save_total_limit``, ``output_dir`` and
            ``train_on_responses_only``.
        dataset: Training dataset; must support ``.map(fn, batched=True)``.
        model: Model handed to ``SFTTrainer``.
        tokenizer: Tokenizer providing ``apply_chat_template`` and
            ``eos_token``.
        test_dataset: Evaluation dataset, or ``None`` for no evaluation set.
        **kwargs: Extra keyword arguments forwarded to ``TrainingArguments``.

    Returns:
        The configured trainer; training is not started here.
    """
    # NOTE: maybe this is not needed but we should test it with train_on_responses_only: https://huggingface.co/docs/trl/en/sft_trainer#dataset-format-support
    def apply_chat_template(examples):
        # Batches that already carry rendered text are passed through as-is.
        if "text" in examples:
            return examples
        texts = [
            tokenizer.apply_chat_template(conversation, tokenize=False) + tokenizer.eos_token
            for conversation in examples["messages"]
        ]
        return {"text": texts}

    dataset = dataset.map(apply_chat_template, batched=True)
    # An evaluation set is optional; only render it when one was supplied.
    if test_dataset is not None:
        test_dataset = test_dataset.map(apply_chat_template, batched=True)

    learning_rate = training_cfg.learning_rate
    if isinstance(learning_rate, str):
        # SECURITY NOTE(review): eval() executes arbitrary Python, so the
        # config must come from a trusted source. Kept because string values
        # may be expressions rather than plain float literals.
        learning_rate = eval(learning_rate)
    # A negative value is interpreted as a base-10 exponent (-5 -> 1e-5).
    if learning_rate < 0:
        learning_rate = 10 ** learning_rate

    use_bf16 = is_bfloat16_supported()
    training_args = TrainingArguments(
        per_device_train_batch_size=training_cfg.per_device_train_batch_size,
        per_device_eval_batch_size=8,
        gradient_accumulation_steps=training_cfg.gradient_accumulation_steps,
        warmup_steps=training_cfg.warmup_steps,
        learning_rate=learning_rate,
        fp16=not use_bf16,
        bf16=use_bf16,
        logging_steps=1,
        optim=training_cfg.optim,
        weight_decay=training_cfg.weight_decay,
        lr_scheduler_type=training_cfg.lr_scheduler_type,
        seed=training_cfg.seed,
        # NOTE(review): depending on the transformers version, None may not
        # disable logging integrations (the string "none" does) - confirm intent.
        report_to=None,
        num_train_epochs=training_cfg.epochs,
        save_strategy=training_cfg.save_strategy,
        save_steps=training_cfg.save_steps,
        save_total_limit=training_cfg.save_total_limit,
        output_dir=training_cfg.output_dir,
        **kwargs,
    )

    trainer_kwargs = dict(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        dataset_text_field="text",
        max_seq_length=training_cfg.max_seq_length,
        dataset_num_proc=4,
        packing=False,
        args=training_args,
        callbacks=[],
        eval_dataset=test_dataset,
    )

    if not training_cfg.train_on_responses_only:
        return SFTTrainer(**trainer_kwargs)

    instruction_part, response_part = get_instruct_response_part(tokenizer)
    trainer_kwargs['data_collator'] = DataCollatorForSeq2Seq(tokenizer=tokenizer)
    return train_on_responses_only(
        SFTTrainer(**trainer_kwargs),
        instruction_part=instruction_part,
        response_part=response_part,
        num_proc=1,
    )