Skip to content

sft 训练lora实现有bug #56

@lometheus

Description

@lometheus

您发现了代码中的一个实现遗漏

从代码来看:

  1. TrainingConfig 中确实定义了 LoRA 相关配置(第37-42行):

    use_lora: bool = True
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    lora_target_modules: list = field(default_factory=lambda: ["q_proj", "v_proj"])
  2. SFTTrainerWrappersetup_model() 方法(第152-173行)完全没有使用这些参数 —— 它直接加载模型,没有任何 LoRA 适配器的配置逻辑。

这是一个待实现的功能。正确的做法应该是在 setup_model() 中添加类似这样的逻辑:

def setup_model(self):
    # ... 现有的 tokenizer 和 model 加载代码 ...
    
    # 如果启用 LoRA,应用 PEFT 配置
    if self.config.use_lora:
        from peft import LoraConfig, get_peft_model
        
        lora_config = LoraConfig(
            r=self.config.lora_r,
            lora_alpha=self.config.lora_alpha,
            lora_dropout=self.config.lora_dropout,
            target_modules=self.config.lora_target_modules,
            bias="none",
            task_type="CAUSAL_LM"
        )
        self.model = get_peft_model(self.model, lora_config)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions