
Commit d36acca

Support messages-type data for ERNIE-VL (#1371)
1 parent bc7fca4 commit d36acca

File tree: 14 files changed, +60 -43 lines
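For reference, a minimal sketch of a messages-type record of the kind this commit adds support for. The `messages`, `role`, `content`, and `tools` fields appear in the diffs below; any schema beyond those fields is an assumption:

{"messages": [{"role": "user", "content": "Describe the image."}, {"role": "assistant", "content": "<think>\nIt is a cat on a mat.\n</think>\n\nA cat sitting on a mat."}], "tools": ""}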

ernie/dataset/text_sft_reader/finetuning.py
Lines changed: 49 additions & 13 deletions
@@ -77,6 +77,7 @@ def __init__(
         simplify=False,
         use_train_part_sharding=False,
         rope_3d=False,
+        chat_template="ernie_vl",
         **kwargs,
     ):
         self.task_group = copy.deepcopy(task_group)
@@ -104,6 +105,7 @@ def __init__(
         self.simplify = simplify
         self.use_train_part_sharding = use_train_part_sharding
         self.rope_3d = rope_3d
+        self.chat_template = chat_template
         self.place = paddle.set_device(device)
 
         # setup special tokens
@@ -331,7 +333,10 @@ def _read_jsonl(self, input_file):
                 if "<think>" in last_tgt and "</think>" in last_tgt:
                     data["prefix"] = ""
                 else:
-                    data["prefix"] = "<think>\n\n</think>\n\n"
+                    if self.chat_template == "ernie_vl_thinking":
+                        data["prefix"] = "\n<think>\n\n</think>\n\n"
+                    else:
+                        data["prefix"] = "<think>\n\n</think>\n\n"
                 data["label"] = [0] * len(data["tgt"])
                 data["label"][-1] = 1
             else:
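A standalone sketch of what the new branch does, assuming chat_template takes only the two values shown in this diff:

def empty_think_prefix(chat_template: str) -> str:
    # When the last target carries no <think> block, the reader injects
    # an empty one; the thinking template adds a leading newline.
    if chat_template == "ernie_vl_thinking":
        return "\n<think>\n\n</think>\n\n"
    return "<think>\n\n</think>\n\n"

assert empty_think_prefix("ernie_vl") == "<think>\n\n</think>\n\n"
assert empty_think_prefix("ernie_vl_thinking") == "\n<think>\n\n</think>\n\n"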
@@ -744,6 +749,25 @@ def _convert_example_to_record(self, example, max_seq_length, tokenizer, index):
             prefix_token = tokenizer.tokenize(example.prefix)
             cur_tokens = tokens_src + prefix_token + tokens_target
             extra_loss_mask = [0] * len(prefix_token)
+        elif (
+            "</think>" in tgt.strip() and self.chat_template == "ernie_vl_thinking"
+        ):
+            reasoning_content = (
+                tgt.strip()
+                .split("</think>")[0]
+                .rstrip("\n")
+                .split("<think>")[-1]
+                .lstrip("\n")
+            )
+            content = tgt.strip().split("</think>")[-1].lstrip("\n")
+            tokens_target = (
+                tokenizer.tokenize("\n<think>\n")
+                + tokenizer.tokenize(reasoning_content.strip("\n"))
+                + tokenizer.tokenize("\n</think>\n\n")
+                + tokenizer.tokenize(content)
+            )
+            cur_tokens = tokens_src + tokens_target
+            extra_loss_mask = []
         else:
             cur_tokens = tokens_src + tokens_target
             extra_loss_mask = []
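The chained split/strip calls above can be hard to read; a self-contained sketch of the same extraction (the helper name split_think is hypothetical):

def split_think(tgt: str) -> tuple[str, str]:
    # Returns (reasoning_content, content) using the same string
    # operations as the new elif branch above.
    reasoning = (
        tgt.strip()
        .split("</think>")[0]
        .rstrip("\n")
        .split("<think>")[-1]
        .lstrip("\n")
    )
    content = tgt.strip().split("</think>")[-1].lstrip("\n")
    return reasoning, content

r, c = split_think("<think>\nstep 1\n</think>\n\nfinal answer")
assert r == "step 1" and c == "final answer"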
@@ -1025,7 +1049,7 @@ def _convert_example_to_record(self, example, max_seq_length, tokenizer, index):
                 # User
                 if "user" in example.messages[index - 1]["role"]:
                     src = example.messages[index - 1]["content"]
-                    tokens_src = self.begin_of_query + tokenizer.tokenize(src)
+                    tokens_src = self.begin_of_query + tokenizer.tokenize(src.strip())
 
                 # Tool
                 if "tool" in example.messages[index - 1]["role"]:
@@ -1041,30 +1065,39 @@ def _convert_example_to_record(self, example, max_seq_length, tokenizer, index):
                 tokens_src = tokens_src + tokenizer.tokenize("\n</tool_output>\n")
 
                 # Assistant
-                if "</think>" in turn["content"]:
+                if "</think>" in turn["content"].strip():
                     reasoning_content = (
                         turn["content"]
+                        .strip()
                         .split("</think>")[0]
                         .rstrip("\n")
                         .split("<think>")[-1]
                         .lstrip("\n")
                     )
-                    content = turn["content"].split("</think>")[-1].lstrip("\n")
+                    content = turn["content"].strip().split("</think>")[-1].lstrip("\n")
                 else:
                     reasoning_content = ""
-                    content = turn["content"]
+                    content = turn["content"].strip()
 
                 tokens_target = []
                 if reasoning_content:
                     tokens_src = tokens_src + self.begin_of_response
-                    tokens_src = tokens_src + tokenizer.tokenize("\n<think>\n")
+                    if self.chat_template == "ernie_vl_thinking":
+                        tokens_target = tokens_target + tokenizer.tokenize(
+                            "\n<think>\n"
+                        )
+                    else:
+                        tokens_target = tokens_target + tokenizer.tokenize("<think>\n")
                     tokens_target = tokens_target + tokenizer.tokenize(
                         reasoning_content.strip("\n")
                     )
                     tokens_target = tokens_target + tokenizer.tokenize("\n</think>\n\n")
                 else:
                     tokens_src = tokens_src + self.begin_of_response
-                    tokens_src = tokens_src + tokenizer.tokenize("\n<think>\n")
+                    if self.chat_template == "ernie_vl_thinking":
+                        tokens_src = tokens_src + tokenizer.tokenize("\n<think>\n")
+                    else:
+                        tokens_src = tokens_src + tokenizer.tokenize("<think>\n")
                     tokens_src = tokens_src + tokenizer.tokenize("\n</think>\n\n")
 
                 if len(content) > 0:
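In plain terms: when the assistant turn contains reasoning, the <think> block becomes part of the loss-bearing target; when it does not, an empty <think></think> pair is appended to the source instead. A sketch under that reading (the function name assistant_target is hypothetical):

def assistant_target(reasoning: str, content: str, chat_template: str) -> str:
    # The thinking template keeps a leading newline before <think>;
    # the default template does not.
    opener = "\n<think>\n" if chat_template == "ernie_vl_thinking" else "<think>\n"
    if reasoning:
        return opener + reasoning.strip("\n") + "\n</think>\n\n" + content
    # With no reasoning, the empty think block goes into tokens_src instead.
    return content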
@@ -1122,12 +1155,12 @@ def _convert_example_to_record(self, example, max_seq_length, tokenizer, index):
 
             previous_cur_len += len(cur_tokens) + len(break_token_multi_turn)
 
-            if len(tokens) <= 4:
-                return []
+        if len(tokens) <= 4:
+            return []
 
-            if tokens[0] != self.begin_token:
-                tokens = [self.begin_token] + tokens
-                loss_mask = [0] + loss_mask
+        if tokens[0] != self.begin_token:
+            tokens = [self.begin_token] + tokens
+            loss_mask = [0] + loss_mask
 
         assert len(tokens) <= self.max_seq_len, f"{len(tokens)}-{self.max_seq_len}"
         assert (
11331166
assert (
@@ -1157,7 +1190,7 @@ def _convert_example_to_record(self, example, max_seq_length, tokenizer, index):
11571190
assert len(pos_ids) == len(pos_ids_extra)
11581191

11591192
if sum(loss_mask) == 0:
1160-
print("[BAD CASE] loss_mask all 0", example.src, example.tgt)
1193+
print("[BAD CASE] loss_mask all 0", example.messages)
11611194
return []
11621195

11631196
records = []
@@ -1211,6 +1244,9 @@ def _read_jsonl(self, input_file):
             ]
             Example = namedtuple("Example", names)
 
+            if "tools" not in data:
+                data["tools"] = ""
+
             # Auto-generate label
             if "label" not in data:
                 data["label"] = []

erniekit/train/ocr_vl_sft/workflow.py
Lines changed: 0 additions & 10 deletions
@@ -152,16 +152,6 @@ def run_ocr_vl_sft(
 
     PipelineParallel.timer_printer = lambda _: None
 
-    # checkpoint O1 quantization is open by default.
-    if (
-        not finetuning_args.disable_ckpt_quant
-        and finetuning_args.ckpt_quant_stage == "O0"
-        and not model_args.lora
-    ):
-        finetuning_args.ckpt_quant_stage = "O1"
-    elif finetuning_args.disable_ckpt_quant:
-        finetuning_args.ckpt_quant_stage = "O0"
-
     finetuning_args.resume_from_checkpoint = get_resume_checkpoint_path(finetuning_args)
     if (
         finetuning_args.resume_from_checkpoint is not None
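With this removal (repeated in the two workflows below), ckpt_quant_stage is no longer silently promoted to O1 for non-LoRA runs; the configured value is used as-is. To keep the old behavior, set it explicitly. A sketch using the argument names from the removed code; whether they are spelled this way as YAML keys is an assumption:

disable_ckpt_quant: false
ckpt_quant_stage: O1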

erniekit/train/sft/workflow.py
Lines changed: 0 additions & 10 deletions
@@ -129,16 +129,6 @@ def run_sft(
     if finetuning_args.release_grads is True:
         finetuning_args.release_grads = False
 
-    # checkpoint O1 quantization is open by default.
-    if (
-        not finetuning_args.disable_ckpt_quant
-        and finetuning_args.ckpt_quant_stage == "O0"
-        and not model_args.lora
-    ):
-        finetuning_args.ckpt_quant_stage = "O1"
-    elif finetuning_args.disable_ckpt_quant:
-        finetuning_args.ckpt_quant_stage = "O0"
-
     finetuning_args.print_config(model_args, "Model")
     finetuning_args.print_config(data_args, "Data")

erniekit/train/vl_sft/workflow.py
Lines changed: 1 addition & 10 deletions
@@ -182,16 +182,6 @@ def run_vl_sft(
 
     PipelineParallel.timer_printer = lambda _: None
 
-    # checkpoint O1 quantization is open by default.
-    if (
-        not finetuning_args.disable_ckpt_quant
-        and finetuning_args.ckpt_quant_stage == "O0"
-        and not model_args.lora
-    ):
-        finetuning_args.ckpt_quant_stage = "O1"
-    elif finetuning_args.disable_ckpt_quant:
-        finetuning_args.ckpt_quant_stage = "O0"
-
     finetuning_args.resume_from_checkpoint = get_resume_checkpoint_path(finetuning_args)
     if (
         finetuning_args.resume_from_checkpoint is not None
@@ -706,6 +696,7 @@ def compute_metrics(p):
         "max_shot": finetuning_args.max_shot,
         "use_train_part_sharding": finetuning_args.text_use_train_part_sharding,
         "rope_3d": model_args.rope_3d,
+        "chat_template": preprocess_args.chat_template,
     }
 
     text_sft_train_reader = create_pyreader(config_dataset_text)
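The text SFT reader now receives chat_template through preprocess_args, so the template can be selected from the run config. A sketch; the key name comes from this diff, its YAML placement is an assumption:

chat_template: ernie_vl_thinking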

examples/configs/ERNIE-4.5-VL-28B-A3B-Thinking/sft/run_sft_128k.yaml
Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ moe_aux_loss_lambda: 0.0
 moe_use_aux_free: true
 moe_use_hard_gate: true
 moe_multimodal_dispatch_use_allgather: v2-alltoall-unpad-text
+pp_seg_method: layer:Ernie4_5_DecoderLayer|ErnieDecoderLayer|EmptyLayer
 
 # data
 train_dataset_path: "examples/data/sft_vl-train_demo1.jsonl"
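The pp_seg_method line added to each of these configs appears to pin how layers are assigned to pipeline-parallel stages by matching layer class names (Ernie4_5_DecoderLayer, ErnieDecoderLayer, or EmptyLayer). A sketch of how it sits in a run config; pipeline_parallel_degree is an assumed companion key, not part of this diff:

# pipeline segmentation by layer class name (sketch)
pp_seg_method: layer:Ernie4_5_DecoderLayer|ErnieDecoderLayer|EmptyLayer
pipeline_parallel_degree: 4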

examples/configs/ERNIE-4.5-VL-28B-A3B-Thinking/sft/run_sft_32k.yaml
Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ moe_aux_loss_lambda: 0.0
 moe_use_aux_free: true
 moe_use_hard_gate: true
 moe_multimodal_dispatch_use_allgather: v2-alltoall-unpad-text
+pp_seg_method: layer:Ernie4_5_DecoderLayer|ErnieDecoderLayer|EmptyLayer
 
 # data
 train_dataset_path: "examples/data/sft_vl-train_demo1.jsonl"

examples/configs/ERNIE-4.5-VL-28B-A3B-Thinking/sft/run_sft_8k.yaml
Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ moe_aux_loss_lambda: 0.0
 moe_use_aux_free: true
 moe_use_hard_gate: true
 moe_multimodal_dispatch_use_allgather: v2-alltoall-unpad-text
+pp_seg_method: layer:Ernie4_5_DecoderLayer|ErnieDecoderLayer|EmptyLayer
 
 # data
 train_dataset_path: "examples/data/sft_vl-train_demo1.jsonl"

examples/configs/ERNIE-4.5-VL-28B-A3B-Thinking/sft/run_sft_lora_32k.yaml
Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ moe_aux_loss_lambda: 0.0
 moe_use_aux_free: true
 moe_use_hard_gate: true
 moe_multimodal_dispatch_use_allgather: v2-alltoall-unpad-text
+pp_seg_method: layer:Ernie4_5_DecoderLayer|ErnieDecoderLayer|EmptyLayer
 
 # data
 train_dataset_path: "examples/data/sft_vl-train_demo1.jsonl"

examples/configs/ERNIE-4.5-VL-28B-A3B-Thinking/sft/run_sft_lora_8k.yaml
Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ moe_aux_loss_lambda: 0.0
 moe_use_aux_free: true
 moe_use_hard_gate: true
 moe_multimodal_dispatch_use_allgather: v2-alltoall-unpad-text
+pp_seg_method: layer:Ernie4_5_DecoderLayer|ErnieDecoderLayer|EmptyLayer
 
 # data
 train_dataset_path: "examples/data/sft_vl-train_demo1.jsonl"

examples/configs/ERNIE-4.5-VL-28B-A3B-Thinking/sft_function_call/run_sft_128k.yaml
Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ moe_aux_loss_lambda: 0.0
 moe_use_aux_free: true
 moe_use_hard_gate: true
 moe_multimodal_dispatch_use_allgather: v2-alltoall-unpad-text
+pp_seg_method: layer:Ernie4_5_DecoderLayer|ErnieDecoderLayer|EmptyLayer
 
 # data
 dataset_name: "FunctionCallSFTReader"
