-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathexample_train_30b_a3b_unsloth.py
More file actions
108 lines (91 loc) · 3.29 KB
/
example_train_30b_a3b_unsloth.py
File metadata and controls
108 lines (91 loc) · 3.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#!/usr/bin/env python3
#
# Example to train a LoRA on the fused and quantized version of Qwen3-30B-A3B using Unsloth
# Runs with 24 GB VRAM
#
# Important: We cache autotuned Triton kernels by default. If you did some small-scale tests, then you should
# clear the Triton cache and the TorchInductor cache before the actual training
import os
from unsloth import FastModel
# Import unsloth before others
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
from qwen3_moe_fused.compile_utils import compile_layers
from qwen3_moe_fused.fast_lora import patch_Qwen3MoeFusedSparseMoeBlock_forward
from qwen3_moe_fused.lora import patch_lora_config
from qwen3_moe_fused.modular_qwen3_moe_fused import Qwen3MoeFusedForCausalLM, patch_Qwen3MoeSparseMoeBlock_init
from qwen3_moe_fused.quantize.quantizer import patch_bnb_quantizer
# Log Triton autotuning decisions so the selected kernel configs can be inspected.
# NOTE(review): Triton appears to read this lazily at autotune time, so setting it
# after the imports above should still take effect — confirm if autotune logs are missing.
os.environ["TRITON_PRINT_AUTOTUNING"] = "1"
def main():
    """Fine-tune a LoRA adapter on the fused, bnb-4bit Qwen3-30B-A3B model via Unsloth.

    Applies the fused-MoE monkey patches, loads the quantized checkpoint, wraps it
    with a rank-patterned rsLoRA adapter, and runs one SFT epoch on the IMDB corpus.
    """
    # The patches must be installed before the model class is instantiated.
    patch_Qwen3MoeSparseMoeBlock_init()
    patch_bnb_quantizer()
    patch_lora_config()
    patch_Qwen3MoeFusedSparseMoeBlock_forward()

    model_id = "bash99/Qwen3-30B-A3B-Instruct-2507-fused-bnb-4bit"
    model, tokenizer = FastModel.from_pretrained(model_id, auto_model=Qwen3MoeFusedForCausalLM)

    # We can set a smaller rank for MoE layers,
    # see https://github.com/woct0rdho/transformers-qwen3-moe-fused/issues/3#issuecomment-3144009673
    # With rslora, we don't need to set a different alpha for them
    # It's possible to create a LoRA on the routing gate, but this may make the training unstable
    attn_rank = 16
    moe_rank = 4
    lora_targets = [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        # "gate",
        "gate_proj",
        "up_proj",
        "down_proj",
    ]
    lora_ranks = {
        "q_proj": attn_rank,
        "k_proj": attn_rank,
        "v_proj": attn_rank,
        "o_proj": attn_rank,
        # "gate": attn_rank,
        "gate_proj": moe_rank,
        "up_proj": moe_rank,
        "down_proj": moe_rank,
    }
    model = FastModel.get_peft_model(
        model,
        target_modules=lora_targets,
        rank_pattern=lora_ranks,
        lora_alpha=1,
        use_rslora=True,
        use_gradient_checkpointing="unsloth",
        random_state=3407,
    )
    compile_layers(model)

    train_data = load_dataset("stanfordnlp/imdb", split="train")

    training_args = SFTConfig(
        per_device_train_batch_size=1,  # Increase batch size if you have more memory
        gradient_accumulation_steps=1,
        learning_rate=1e-4,
        # For MoE models, weight decay can be smaller than for dense models,
        # because not every expert has gradient in every step, but weight decay is applied to every expert
        weight_decay=1e-3,
        num_train_epochs=1,
        lr_scheduler_type="linear",
        warmup_steps=1000,
        logging_steps=1,
        save_steps=100,
        save_total_limit=5,
        bf16=True,
        optim="adamw_8bit",
        dataset_text_field="text",
        dataset_num_proc=1,
        torch_compile=True,
        torch_compile_mode="max-autotune",
        report_to="none",  # You may report to Wandb
        seed=3407,
    )
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=train_data,
        args=training_args,
    )

    trainer_stats = trainer.train()
    print("trainer_stats")
    print(trainer_stats)
# Script entry point: only run training when executed directly, not on import.
if __name__ == "__main__":
    main()