Commit 17f4ddd

Fix MPS SFT training: Use TrainingArguments instead of SFTConfig

The working MPS finetuning test uses TrainingArguments from transformers, not SFTConfig from trl. Key changes:
- Changed from SFTConfig to TrainingArguments
- Added fp16/bf16 flags based on model dtype detection
- Moved dataset_text_field and max_seq_length to the SFTTrainer constructor
- Kept existing MPS-specific handling (no 4-bit, adamw_torch optimizer)

This aligns with test_mps_finetuning.py, which successfully trains on MPS.
1 parent 859c5eb commit 17f4ddd
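The precision-detection pattern this commit adds can be sketched standalone. This is a minimal illustration only: `DummyConfig` and `DummyModel` are hypothetical stand-ins for the real transformers model object, whose `config.torch_dtype` may be unset while `model.dtype` reflects the live parameters.

```python
# Minimal sketch of the fp16/bf16 detection logic from this commit.
# DummyConfig / DummyModel are hypothetical stand-ins; in the real
# script these attributes come from the loaded unsloth/transformers
# model, and the comparisons use torch.float16 / torch.bfloat16.

class DummyConfig:
    torch_dtype = None        # config may not record a dtype

class DummyModel:
    config = DummyConfig()
    dtype = "float16"         # dtype of the live parameters

model = DummyModel()

# Prefer the dtype recorded in the config; fall back to the model itself.
model_dtype = getattr(model.config, "torch_dtype", None)
if model_dtype is None:
    model_dtype = model.dtype

is_bf16 = (model_dtype == "bfloat16")
is_fp16 = (model_dtype == "float16")
print(f"fp16={is_fp16}, bf16={is_bf16}")
```

Exactly one of the two flags ends up True for a half-precision model, which is what TrainingArguments expects (setting both would be rejected).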

File tree

1 file changed (+17 −5 lines)


test_sft_training_mac.py

Lines changed: 17 additions & 5 deletions
@@ -14,7 +14,8 @@
 import platform
 from unsloth import FastLanguageModel
 from datasets import Dataset
-from trl import SFTTrainer, SFTConfig
+from trl import SFTTrainer
+from transformers import TrainingArguments

 def main():
     print("="*50)
@@ -59,10 +60,20 @@ def main():
     ] * 10  # 20 samples total
     dataset = Dataset.from_list(data)

+    # Detect correct precision from the model itself
     print("\n[4] Configuring training arguments...")
+    model_dtype = getattr(model.config, "torch_dtype", None)
+    if model_dtype is None:
+        model_dtype = model.dtype
+
+    is_bf16 = (model_dtype == torch.bfloat16)
+    is_fp16 = (model_dtype == torch.float16)
+
     if is_mps:
         print("    Using adamw_torch optimizer (8-bit not supported on MPS)")
-    training_args = SFTConfig(
+    print(f"    Using fp16={is_fp16}, bf16={is_bf16}")
+
+    training_args = TrainingArguments(
         output_dir="./test_sft_output",
         per_device_train_batch_size=2,
         gradient_accumulation_steps=2,
@@ -71,16 +82,17 @@ def main():
         logging_steps=1,
         optim="adamw_torch" if is_mps else "adamw_8bit",
         seed=42,
-        max_seq_length=max_seq_length,
-        dataset_text_field="text",
-        packing=False,  # Test without packing first
+        fp16=is_fp16,
+        bf16=is_bf16,
     )

     print("\n[5] Initializing SFTTrainer...")
     trainer = SFTTrainer(
         model=model,
         tokenizer=tokenizer,
         train_dataset=dataset,
+        dataset_text_field="text",
+        max_seq_length=max_seq_length,
         args=training_args,
     )
