How to use unsloth for DPO fine-tuning (QLoRA) #7324
Answered by SnowFox4004
SnowFox4004 asked this question in Q&A
How do I use unsloth for DPO fine-tuning in LLaMA-Factory? Training raises an error when unsloth is enabled.
My YAML config is as follows:
### model
model_name_or_path: Qwen/Qwen2.5-3B-Instruct
quantization_bit: 4
quantization_method: bitsandbytes # choices: [bitsandbytes (4/8), hqq (2/3/4/5/6/8), eetq (8)]
trust_remote_code: true
### method
stage: dpo
do_train: true
finetuning_type: lora
lora_rank: 8
lora_target: all
pref_loss: orpo
upcast_layernorm: true
### dataset
dataset: fic_smzh_dpo
template: qwen
cutoff_len: 2048
max_samples: 2500
overwrite_cache: true
preprocessing_num_workers: 16
### output
output_dir: saves/qwen2fiction_dpo/lora/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 2
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000
### eval
# val_size: 0.1
# per_device_eval_batch_size: 1
# eval_strategy: steps
# eval_steps: 500
### speedup
flash_attn: fa2
# enable_liger_kernel: True
use_unsloth: True
### monitoring
use_swanlab: true
swanlab_project: llamafactory
swanlab_run_name: 3b_DPO_fiction
# report_to: tensorboard
Training fails with the following error:
File "/home/snowfox/miniconda3/envs/lmfac/bin/llamafactory-cli", line 8, in <module>
sys.exit(main())
File "/home/snowfox/coding/llm_finetuning/LLaMA-Factory/src/llamafactory/cli.py", line 118, in main
run_exp()
File "/home/snowfox/coding/llm_finetuning/LLaMA-Factory/src/llamafactory/train/tuner.py", line 103, in run_exp
_training_function(config={"args": args, "callbacks": callbacks})
File "/home/snowfox/coding/llm_finetuning/LLaMA-Factory/src/llamafactory/train/tuner.py", line 74, in _training_function
run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
File "/home/snowfox/coding/llm_finetuning/LLaMA-Factory/src/llamafactory/train/dpo/workflow.py", line 83, in run_dpo
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
File "/home/snowfox/miniconda3/envs/lmfac/lib/python3.10/site-packages/transformers/trainer.py", line 2241, in train
return inner_training_loop(
File "<string>", line 306, in _fast_inner_training_loop
File "<string>", line 31, in _unsloth_training_step
File "/home/snowfox/coding/llm_finetuning/LLaMA-Factory/src/llamafactory/train/dpo/trainer.py", line 272, in compute_loss
return super().compute_loss(model, inputs, return_outputs)
File "/home/snowfox/miniconda3/envs/lmfac/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 1408, in compute_loss
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
File "/home/snowfox/coding/llm_finetuning/LLaMA-Factory/src/llamafactory/train/dpo/trainer.py", line 239, in get_batch_loss_metrics
) = self.concatenated_forward(model, batch)
File "/home/snowfox/coding/llm_finetuning/LLaMA-Factory/src/llamafactory/train/dpo/trainer.py", line 189, in concatenated_forward
all_logits: torch.Tensor = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
'NoneType' object has no attribute 'to'
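0%| | 0/624 [00:07<?, ?it/s]
Training runs normally once unsloth acceleration is disabled, but I see that unsloth also supports reinforcement-learning fine-tuning, so I would like to ask how to use it here.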
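The last frame of the traceback calls `.to(torch.float32)` on `.logits`, so the failure suggests the forward pass did not return a tensor in that field. A minimal diagnostic sketch, assuming only that the output object exposes a `logits` attribute; `describe_logits` is a hypothetical helper (not part of LLaMA-Factory, TRL, or unsloth), and the stand-in objects below replace a real model output:

```python
from types import SimpleNamespace


def describe_logits(output):
    """Return a short description of output.logits.

    Hypothetical debugging helper: it only inspects the attribute
    and never calls tensor methods such as .to().
    """
    logits = getattr(output, "logits", None)
    if logits is None:
        return "logits is None"
    if not hasattr(logits, "shape"):
        # Anything without a shape is not tensor-like,
        # e.g. a placeholder that only carries a message.
        return f"logits is a non-tensor placeholder of type {type(logits).__name__}"
    return f"logits has shape {tuple(logits.shape)}"


# Stand-ins for a real model output (assumption: real outputs expose .logits).
missing = SimpleNamespace(loss=0.4934, logits=None)
present = SimpleNamespace(loss=0.4934, logits=SimpleNamespace(shape=(2, 16, 32)))

print(describe_logits(missing))
print(describe_logits(present))
```

In a real run, such a check would go just before the `.logits.to(torch.float32)` call in `concatenated_forward`. A placeholder object may respond to attribute lookups in its own way, so printing the whole output object is the more direct check.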
Answered by SnowFox4004 on Mar 16, 2025
After printing the model output in the code, I found the following:
all_logits: CausalLMOutputWithPast(loss=tensor(0.4934, device='cuda:0', grad_fn=<LinearCrossEntropyFunctionBackward>), logits=Unsloth: Logits are empty from 2024.11 onwards. To get raw logits again, please set the environment variable `UNSLOTH_RETURN_LOGITS` to `"1" BEFORE starting to train ie before `trainer.train()`. For example:
\```
import os
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
trainer.train()
\```
No need to restart your console - just add `os.environ['UNSLOTH_RETURN_LOGITS'] = '1'` before trainer.train() and re-run the cell!, past_key_values=None, hidden_states=None, attentions=None)
So unsloth changed its behavior and no longer returns logits; setting the environment variable is enough to fix it.
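Since training here is started through `llamafactory-cli` rather than a notebook cell, one way to apply the fix is to set the variable in the environment of the launching process. A minimal sketch, assuming the `llamafactory-cli train <config>` entry point is on `PATH`; `build_env` and `launch_training` are illustrative names, and the config path is a placeholder:

```python
import os
import subprocess


def build_env(base=None):
    """Copy the environment and ask unsloth to return raw logits."""
    env = dict(os.environ if base is None else base)
    env["UNSLOTH_RETURN_LOGITS"] = "1"  # must be set before training starts
    return env


def launch_training(config_path):
    """Run the trainer as a child process with the variable set.

    Assumes `llamafactory-cli` is installed and on PATH;
    config_path is a placeholder for the YAML file shown in the question.
    """
    return subprocess.run(
        ["llamafactory-cli", "train", config_path],
        env=build_env(),
        check=True,
    )
```

Calling `launch_training("dpo_unsloth.yaml")` (a hypothetical file name) would start the run with the variable in place. Exporting `UNSLOTH_RETURN_LOGITS=1` in the shell before invoking the CLI should have the same effect.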
Answer selected by SnowFox4004