Skip to content

Commit 492e9e5

Browse files
committed
rm PrefixModelForCausalLM
1 parent 18a3a1d commit 492e9e5

File tree

4 files changed

+6
-10
lines changed

4 files changed

+6
-10
lines changed

.github/workflows/gpu_ci_test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ concurrency:
1414
jobs:
1515
Test:
1616
name: Test
17-
runs-on: [self-hosted, ernie-8gpu]
17+
runs-on: [self-hosted, ernie-8gpu-1]
1818
steps:
1919
- name: Start Docker
2020
run: |

.github/workflows/lint.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ env:
1313
jobs:
1414
Lint:
1515
name: Lint
16-
runs-on: [self-hosted, ernie-cpu]
16+
runs-on: [self-hosted, ernie-cpu-01]
1717
permissions:
1818
pull-requests: write
1919
contents: read

erniekit/train/ocr_vl_sft/trainer.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040
from distutils.util import strtobool
4141

42-
from paddleformers.peft import LoRAModel, PrefixModelForCausalLM
42+
from paddleformers.peft import LoRAModel
4343
from paddleformers.trainer import (
4444
speed_metrics,
4545
)
@@ -833,9 +833,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
833833
logger.info(
834834
f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
835835
)
836-
if isinstance(self.model, LoRAModel) or isinstance(
837-
self.model, PrefixModelForCausalLM
838-
):
836+
if isinstance(self.model, LoRAModel):
839837
self._load_best_model_from_peft_checkpoint()
840838
else:
841839
weight_name = PADDLE_WEIGHTS_NAME

erniekit/train/vl_sft/trainer.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141

4242
from setuptools._distutils.util import strtobool
4343

44-
from paddleformers.peft import LoRAModel, PrefixModelForCausalLM
44+
from paddleformers.peft import LoRAModel
4545
from paddleformers.trainer import (
4646
speed_metrics,
4747
)
@@ -829,9 +829,7 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
829829
logger.info(
830830
f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
831831
)
832-
if isinstance(self.model, LoRAModel) or isinstance(
833-
self.model, PrefixModelForCausalLM
834-
):
832+
if isinstance(self.model, LoRAModel):
835833
self._load_best_model_from_peft_checkpoint()
836834
else:
837835
weight_name = PADDLE_WEIGHTS_NAME

0 commit comments

Comments
 (0)