-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
150 lines (127 loc) · 4.75 KB
/
train.py
File metadata and controls
150 lines (127 loc) · 4.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#!/usr/bin/env python3
"""
Fine-tunes Gemma-3-4B with Unsloth LoRA adapters.
Task: extract de-contextualized propositions from entity records.
"""
from __future__ import annotations
import json, torch
from datasets import Dataset
from unsloth import FastLanguageModel
from trl import SFTTrainer, SFTConfig
# ────────────────────────── Model ──────────────────────────
def load_model(
    *,
    max_seq_length: int = 4096,
    dtype: torch.dtype | None = None,
    load_in_4bit: bool = True,
    model_name: str = "unsloth/gemma-3-4b-it-unsloth-bnb-4bit",
    lora_rank: int = 8,
    lora_alpha: int = 8,
    random_state: int = 3407,
):
    """Load the base checkpoint and attach LoRA adapters to it.

    All arguments are keyword-only. The defaults reproduce the original
    hard-coded configuration, so ``load_model()`` behaves as before.

    Args:
        max_seq_length: Forwarded to ``FastLanguageModel.from_pretrained``.
        dtype: Forwarded unchanged to ``from_pretrained``.
        load_in_4bit: Forwarded unchanged to ``from_pretrained``.
        model_name: Checkpoint identifier handed to ``from_pretrained``.
        lora_rank: LoRA rank, passed as ``r`` to ``get_peft_model``.
        lora_alpha: LoRA alpha, passed to ``get_peft_model``.
        random_state: Seed passed to ``get_peft_model``.

    Returns:
        A ``(model, tokenizer)`` tuple: the model returned by
        ``get_peft_model`` and the tokenizer returned by ``from_pretrained``.
    """
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_name,
        max_seq_length=max_seq_length,
        dtype=dtype,
        load_in_4bit=load_in_4bit,
    )
    # Adapters go on the language side only (attention + MLP); the vision
    # layers are explicitly excluded from fine-tuning.
    model = FastLanguageModel.get_peft_model(
        model,
        finetune_vision_layers=False,
        finetune_language_layers=True,
        finetune_attention_modules=True,
        finetune_mlp_modules=True,
        r=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=0,  # int to satisfy Pyright
        bias="none",
        random_state=random_state,
    )
    return model, tokenizer
# ─────────────────────────── Data ───────────────────────────
def load_training_data(
    path: str = "processed/unsloth_training_data.json",
) -> list[dict]:
    """Read the training records from a UTF-8 JSON file.

    Args:
        path: Location of a JSON document whose top-level value is a list
            of objects.

    Returns:
        The decoded records, one dict per training example.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        ValueError: If the document is not a list of JSON objects.
    """
    with open(path, "r", encoding="utf-8") as f:
        records = json.load(f)
    # json.load returns whatever the document holds; fail here with a clear
    # message instead of handing a wrong shape to the dataset builder.
    if not isinstance(records, list) or not all(
        isinstance(record, dict) for record in records
    ):
        raise ValueError(
            f"{path}: expected a JSON array of objects, "
            f"got {type(records).__name__}"
        )
    return records
def format_for_chat(batch: dict) -> dict:
texts: list[str] = []
for inst, inp, out in zip(
batch["instruction"], batch["input"], batch["output"], strict=True
):
prompt = (
"<start_of_turn>user\n"
f"{inst}\n\n{inp}<end_of_turn>\n"
"<start_of_turn>model\n"
f"{out}<end_of_turn>"
)
texts.append(prompt)
return {"text": texts}
def build_dataset(data: list[dict]) -> Dataset:
    """Turn raw records into a dataset of rendered chat transcripts.

    Args:
        data: Records as returned by ``load_training_data``.

    Returns:
        A ``Dataset`` produced by mapping ``format_for_chat`` over the
        records in batches, with the source columns removed.
    """
    raw = Dataset.from_list(data)
    # Capture the source column names so the map can drop them.
    source_columns = raw.column_names
    formatted = raw.map(
        format_for_chat,
        batched=True,
        remove_columns=source_columns,
    )
    return formatted
# ──────────────────────── Trainer ───────────────────────────
def make_trainer(model, tokenizer, dataset: Dataset) -> SFTTrainer:
    """Build the supervised fine-tuning trainer for a short run.

    Args:
        model: Model to train, as returned by ``load_model``.
        tokenizer: Passed to the trainer as ``processing_class``.
        dataset: Training set, as returned by ``build_dataset``.

    Returns:
        An ``SFTTrainer`` configured for 30 steps with batch size 2 per
        device, 4 gradient-accumulation steps and checkpoints written to
        ``checkpoints``.
    """
    hyperparams = dict(
        # Batch shape
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        # Schedule
        warmup_steps=5,
        max_steps=30,
        learning_rate=2e-4,
        lr_scheduler_type="linear",
        # Optimizer
        optim="adamw_8bit",
        weight_decay=0.01,
        # Bookkeeping
        logging_steps=1,
        seed=3407,
        output_dir="checkpoints",
    )
    training_args = SFTConfig(**hyperparams)
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,  # TRL ≥ 0.16
    )
    return trainer
# ─────────────────────── Quick smoke test ───────────────────
def smoke_test(model, tokenizer):
    """Generate one sample completion and print it as a sanity check.

    Builds a fixed prompt for a fictional entity using the same turn markup
    as the training transcripts, samples up to 256 new tokens, and prints
    only the newly generated text.

    Args:
        model: Model exposing ``device`` and ``generate``.
        tokenizer: Callable that encodes text to ``"pt"`` tensors and
            exposes ``decode`` and ``eos_token_id``.
    """
    instr = (
        "Extract meaningful, de-contextualized propositions from this entity data. "
        "Use full nouns, no pronouns, always include the person's name."
    )
    entity = (
        "Entity Information:\n"
        "Name: John Smith\n"
        "Summary: Senior Data Scientist at AI Corp\n"
        "Long Summary: John has 5 years of experience in machine learning and data analysis.\n"
        "Research: Met at AI conference 2023\n"
        "Location: Boston, MA"
    )
    prompt = (
        "<start_of_turn>user\n"
        f"{instr}\n\n{entity}<end_of_turn>\n"
        "<start_of_turn>model\n"
    )
    # NOTE(review): the prompt is passed by keyword because a multimodal
    # processor may take images as its first positional argument; plain
    # tokenizers accept ``text=`` as well. Confirm which one load_model yields.
    inputs = tokenizer(text=[prompt], return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=256,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    # generate() returns the prompt tokens followed by the new ones. Decoding
    # with skip_special_tokens=True can strip the turn markers, so splitting
    # the text on "<start_of_turn>model\n" is not a reliable way to drop the
    # prompt. Slice it off by token count instead.
    completion = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
    print(
        "\n--- Smoke test output ---\n",
        # Defensive: cut at the end-of-turn marker if it survived decoding.
        completion.split("<end_of_turn>")[0].strip(),
    )
# ─────────────────────────── Main ───────────────────────────
def main():
    """Run the pipeline: load data, fine-tune, save adapters, smoke-test."""
    data = load_training_data()
    dataset = build_dataset(data)
    model, tokenizer = load_model()
    trainer = make_trainer(model, tokenizer, dataset)
    print("Starting fine-tuning…")
    stats = trainer.train()
    print(f"Loss {stats.training_loss:.4f} at step {stats.global_step}")
    # Save before the smoke test so that a failure during sample generation
    # cannot discard the adapters from a completed training run.
    print("Saving adapters…")
    model.save_pretrained("lora_model")
    tokenizer.save_pretrained("lora_model")
    smoke_test(model, tokenizer)
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()