-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathexample_train_tiny_unsloth.py
More file actions
95 lines (79 loc) · 2.52 KB
/
example_train_tiny_unsloth.py
File metadata and controls
95 lines (79 loc) · 2.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#!/usr/bin/env python3
#
# Example to train a tiny model using Unsloth
# Run example_create_tiny.py first
import os
from unsloth import FastModel
# Import unsloth before others
from datasets import Dataset
from trl import SFTConfig, SFTTrainer
from qwen3_moe_fused.fast_lora import patch_Qwen3MoeFusedSparseMoeBlock_forward
from qwen3_moe_fused.lora import patch_lora_config
from qwen3_moe_fused.modular_qwen3_moe_fused import Qwen3MoeFusedForCausalLM
from qwen3_moe_fused.quantize.quantizer import patch_bnb_quantizer
# Ask Triton to log its autotuning results, so the kernel configs chosen for
# the fused-MoE kernels can be inspected while this example runs.
os.environ["TRITON_PRINT_AUTOTUNING"] = "1"
def main():
    """Train a tiny fused-MoE Qwen3 model with Unsloth + LoRA.

    This is a smoke-test / example script: it loads the quantized tiny model
    produced by example_create_tiny.py, wraps it with a LoRA adapter, and
    runs one exaggerated SFT epoch on a toy 12-document dataset.

    Requires ./pretrained/qwen-moe-tiny-lm-quantized to exist
    (run example_create_tiny.py first).
    """
    # Install the fused-MoE patches (quantizer, LoRA config handling, and the
    # fused sparse-MoE forward) before the model is loaded, so loading and
    # PEFT wrapping pick them up.
    patch_bnb_quantizer()
    patch_lora_config()
    patch_Qwen3MoeFusedSparseMoeBlock_forward()

    model_dir = "./pretrained/qwen-moe-tiny-lm-quantized"
    model, tokenizer = FastModel.from_pretrained(model_dir, auto_model=Qwen3MoeFusedForCausalLM)
    model = FastModel.get_peft_model(
        model,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        # We can set a smaller rank for MoE layers,
        # see https://github.com/woct0rdho/transformers-qwen3-moe-fused/issues/3#issuecomment-3144009673
        # With rslora, we don't need to set a different alpha for them
        rank_pattern={
            "q_proj": 16,
            "k_proj": 16,
            "v_proj": 16,
            "o_proj": 16,
            "gate": 16,
            "gate_proj": 4,
            "up_proj": 4,
            "down_proj": 4,
        },
        lora_alpha=1,
        use_rslora=True,
        use_gradient_checkpointing="unsloth",
        random_state=3407,
    )

    # Toy dataset: 12 "documents", each one character repeated 100 times.
    dataset = Dataset.from_dict({"text": [x * 100 for x in "abcdefghijkl"]})

    # These hyperparameters are for exaggerating the training of the tiny model
    # Don't use them in actual training
    sft_config = SFTConfig(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=1,
        learning_rate=1e-2,
        weight_decay=1e-2,
        num_train_epochs=1,
        logging_steps=1,
        save_steps=3,
        bf16=True,
        optim="adamw_8bit",
        dataset_text_field="text",
        dataset_num_proc=1,
        report_to="none",
        seed=3407,
    )
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=dataset,
        args=sft_config,
    )
    trainer_stats = trainer.train()
    print("trainer_stats")
    print(trainer_stats)


if __name__ == "__main__":
    main()