Skip to content

Commit 171f717

Browse files
committed
rm PrefixModelForCausalLM
1 parent 18a3a1d commit 171f717

File tree

2 files changed

+4
-8
lines changed

2 files changed

+4
-8
lines changed

erniekit/train/ocr_vl_sft/trainer.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040
from distutils.util import strtobool
4141

42-
from paddleformers.peft import LoRAModel, PrefixModelForCausalLM
42+
from paddleformers.peft import LoRAModel
4343
from paddleformers.trainer import (
4444
speed_metrics,
4545
)
@@ -833,9 +833,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
833833
logger.info(
834834
f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
835835
)
836-
if isinstance(self.model, LoRAModel) or isinstance(
837-
self.model, PrefixModelForCausalLM
838-
):
836+
if isinstance(self.model, LoRAModel):
839837
self._load_best_model_from_peft_checkpoint()
840838
else:
841839
weight_name = PADDLE_WEIGHTS_NAME

erniekit/train/vl_sft/trainer.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141

4242
from setuptools._distutils.util import strtobool
4343

44-
from paddleformers.peft import LoRAModel, PrefixModelForCausalLM
44+
from paddleformers.peft import LoRAModel
4545
from paddleformers.trainer import (
4646
speed_metrics,
4747
)
@@ -829,9 +829,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
829829
logger.info(
830830
f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
831831
)
832-
if isinstance(self.model, LoRAModel) or isinstance(
833-
self.model, PrefixModelForCausalLM
834-
):
832+
if isinstance(self.model, LoRAModel):
835833
self._load_best_model_from_peft_checkpoint()
836834
else:
837835
weight_name = PADDLE_WEIGHTS_NAME

0 commit comments

Comments
 (0)