
Commit 831d66e

EB4.5 supports SFT dataflow
1 parent 85529fe commit 831d66e

File tree

4 files changed (+137 −14 lines changed)

examples/experiments/ernie_pretrain/ernie/pretrain.py

Lines changed: 46 additions & 1 deletion
@@ -18,6 +18,7 @@
 import re
 import time
 from dataclasses import dataclass
+from functools import partial

 import numpy as np
 import paddle
@@ -60,6 +61,10 @@
     build_train_valid_test_datasets,
     check_data_split,
 )
+from paddleformers.datasets.finetuning import collate_fn
+from paddleformers.datasets.finetuning import create_dataset as create_dataset_sft
+from paddleformers.trainer import TrainingArguments
+from paddleformers.trl import ModelConfig

 try:
     from paddleformers.trainer.trainer_utils import log_trainer_start
@@ -459,7 +464,47 @@ def sname_to_tname(pp_model):

     logger.info(f"using model={type(model)}, cfg={cfg}")

-    train_dataset, eval_dataset, test_dataset, data_collator = create_pretrained_dataset(args)
+    dataset_config = {
+        "tokenizer": tokenizer,
+        "max_seq_len": args.max_seq_length + 1,
+        "random_seed": args.seed,
+        "num_replicas": args.dataset_world_size,
+        "rank": args.dataset_rank,
+        "num_samples_each_epoch": trainer_args.get("num_samples_each_epoch", 6000000),
+        "random_shuffle": True,
+        "greedy_intokens": True,
+        "packing": True,
+        "mix_strategy": "concat",
+        "encode_one_turn": True,
+        "use_template": True,
+        "is_pretraining": False,
+    }
+
+    if trainer_args.get("stage") == "sft":
+        train_dataset = create_dataset_sft(
+            task_group=trainer_args["train_dataset_path"],
+            task_group_prob=trainer_args.get("train_dataset_prob", 1.0),
+            sub_dataset_type=trainer_args.get("train_dataset_type", "erniekit"),
+            **dataset_config,
+        )
+        eval_dataset = create_dataset_sft(
+            task_group=trainer_args["eval_dataset_path"],
+            task_group_prob=trainer_args.get("eval_dataset_prob", 1.0),
+            sub_dataset_type=trainer_args.get("eval_dataset_type", "erniekit"),
+            is_valid=True,
+            **dataset_config,
+        )
+        data_collator = partial(
+            collate_fn,
+            tokenizer=tokenizer,
+            training_args=TrainingArguments(
+                output_dir=args.output_dir, num_nextn_predict_layers=args.multi_token_pred_depth
+            ),
+            model_args=ModelConfig(stage="SFT", use_attn_mask_startend_row_indices=True),
+            max_seq_len=args.max_seq_length + 1,
+        )
+    else:
+        train_dataset, eval_dataset, _, data_collator = create_pretrained_dataset(args)

     callbacks = []
     callbacks += [GlobalRNGCallback()]
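
For reference, the SFT branch above is driven entirely by keys read from trainer_args. A minimal sketch of that mapping follows; the key names come from the diff, while the paths and values are purely illustrative placeholders.

# Hypothetical trainer_args entries consumed by the SFT branch in pretrain.py.
# Only the keys are taken from the diff above; every value here is an example.
trainer_args = {
    "stage": "sft",                                   # anything else falls back to create_pretrained_dataset(args)
    "train_dataset_path": "./data/sft/train.jsonl",   # forwarded as task_group
    "train_dataset_prob": 1.0,
    "train_dataset_type": "erniekit",
    "eval_dataset_path": "./data/sft/eval.jsonl",
    "eval_dataset_prob": 1.0,
    "eval_dataset_type": "erniekit",
    "num_samples_each_epoch": 6000000,
}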

examples/experiments/ernie_pretrain/models/ernie/modeling.py

Lines changed: 3 additions & 1 deletion
@@ -357,7 +357,7 @@ def scaled_dot_product_attention(
             value_states.astype(value_states.dtype),
             startend_row_indices=startend_row_indices,
             dropout=config.attention_probs_dropout_prob,
-            causal=False,
+            causal=True,
         )
         attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads])
         return attn_output, None
@@ -1104,6 +1104,7 @@ def rope_attn(
         past_key_value=None,
         use_cache=False,
         inbatch_pack_offset=None,
+        attn_mask_startend_row_indices=None,
     ):
         if mix_layer is not None:
            query_states, key_states, value_states = paddle.split(mix_layer, 3, axis=-1)
@@ -1186,6 +1187,7 @@ def rope_attn(
             config=self.config,
             inbatch_pack_offset=inbatch_pack_offset,
             training=self.training,
+            startend_row_indices=attn_mask_startend_row_indices,
         )
         return attn_output, attn_weights, past_key_value
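
The switch to causal=True pairs with the startend_row_indices now threaded into the attention call. Below is a rough NumPy sketch of one plausible reading of those indices for packed sequences; the exact flash-attention semantics are not defined in this diff, so treat the shape and interpretation as assumptions.

import numpy as np

# Sketch only: interpret row_indices[j] as the first query row that may no
# longer attend to key column j (i.e. the end of j's packed sub-sequence), so
# the resulting mask is causal *within* each sub-sequence of a packed batch.
def dense_mask_from_row_indices(row_indices, seq_len):
    i = np.arange(seq_len)[:, None]   # query rows
    j = np.arange(seq_len)[None, :]   # key columns
    return (i >= j) & (i < row_indices[None, :])

# Two packed sub-sequences of lengths 3 and 2 (seq_len = 5): tokens 0-2 end at
# row 3, tokens 3-4 end at row 5.
print(dense_mask_from_row_indices(np.array([3, 3, 3, 5, 5]), 5).astype(int))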

examples/experiments/ernie_pretrain/models/ernie/modeling_moe.py

Lines changed: 23 additions & 0 deletions
@@ -821,6 +821,7 @@ def forward(
         use_cache: bool = False,
         inbatch_pack_offset: Optional[Tuple[paddle.Tensor]] = None,
         token_type_ids: Optional[Tuple[paddle.Tensor]] = None,
+        attn_mask_startend_row_indices: Optional[paddle.Tensor] = None,
     ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
         if token_type_ids is not None:
             token_type_ids = token_type_ids[:, :-1]
@@ -901,6 +902,7 @@ def forward(
                 past_key_value,
                 use_cache,
                 inbatch_pack_offset,
+                attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                 use_reentrant=False,
             )
         else:
@@ -915,6 +917,7 @@ def forward(
                 past_key_value=past_key_value,
                 use_cache=use_cache,
                 inbatch_pack_offset=inbatch_pack_offset,
+                attn_mask_startend_row_indices=attn_mask_startend_row_indices,
             )
         if self.config.sequence_parallel:
             attn_output = attn_output.reshape([-1, attn_output.shape[-1]])
@@ -1152,6 +1155,7 @@ def forward(
         use_cache: Optional[bool] = False,
         inbatch_pack_offset: Optional[paddle.Tensor] = None,
         output_gate_logits=True,
+        attn_mask_startend_row_indices: Optional[paddle.Tensor] = None,
     ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]:
         residual = hidden_states
         if token_type_ids is not None:
@@ -1178,6 +1182,7 @@ def forward(
             use_cache=use_cache,
             inbatch_pack_offset=inbatch_pack_offset,
             token_type_ids=token_type_ids,
+            attn_mask_startend_row_indices=attn_mask_startend_row_indices,
         )

         if self.use_linear_residual_norm_recompute is True:
@@ -1660,6 +1665,7 @@ def forward(
         output_hidden_states=None,
         return_dict=False,
         inbatch_pack_offset=None,
+        attn_mask_startend_row_indices=None,
         **kwargs,
     ):
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -1719,6 +1725,12 @@ def forward(
             )
         hidden_states = inputs_embeds

+        attn_mask_startend_row_indices_ori = attn_mask_startend_row_indices
+        if attn_mask_startend_row_indices is not None:
+            attn_mask_startend_row_indices = attn_mask_startend_row_indices[
+                :, :, : -self.config.multi_token_pred_depth
+            ]
+
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
         next_decoder_cache = () if use_cache else None
@@ -1743,6 +1755,7 @@ def forward(
                     past_key_value,
                     use_cache,
                     inbatch_pack_offset,
+                    attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -1754,6 +1767,7 @@ def forward(
                     past_key_value,
                     use_cache,
                     inbatch_pack_offset,
+                    attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                 )

             if isinstance(layer_outputs, (tuple, list)):
@@ -1786,6 +1800,11 @@ def forward(
                 ],
                 axis=1,
             )
+            attn_mask_startend_row_indices_cur_depth = None
+            if attn_mask_startend_row_indices is not None:
+                attn_mask_startend_row_indices_cur_depth = attn_mask_startend_row_indices_ori[
+                    :, :, (depth + 1) : inputs_embeds_ori.shape[1] + (depth + 1)
+                ] - (depth + 1)

             inputs_embeds_cur_depth_norm = self.mtp_emb_norm[depth](inputs_embeds_cur_depth)
             hidden_states_norm = self.mtp_hidden_norm[depth](hidden_states)
@@ -1809,6 +1828,7 @@ def forward(
                 past_key_value,
                 use_cache,
                 inbatch_pack_offset,
+                attn_mask_startend_row_indices=attn_mask_startend_row_indices_cur_depth,
             )

             if isinstance(layer_outputs, (tuple, list)):
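
The two multi-token-prediction hunks above slice the mask indices rather than rebuild them per depth. A toy walk-through of that index arithmetic, with illustrative shapes and assuming inputs_embeds_ori spans the main sequence length, looks like this:

import numpy as np

# Toy walk-through of the MTP mask slicing above (shapes are illustrative).
# Suppose the mask covers seq_len + multi_token_pred_depth positions: two
# packed sub-sequences of lengths 4 and 3, depth D = 2, so the main length L = 5.
D = 2
full = np.array([[[4, 4, 4, 4, 7, 7, 7]]])   # assumed [batch=1, heads=1, L + D]
L = full.shape[-1] - D

main = full[:, :, :-D]                        # mask used by the backbone pass
for depth in range(D):
    cur = full[:, :, (depth + 1): L + (depth + 1)] - (depth + 1)
    # Shifting the window by depth + 1 and subtracting depth + 1 keeps each
    # sub-sequence boundary expressed in the shifted token frame.
    print(depth, cur)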
@@ -2132,6 +2152,8 @@ def forward(
         data_id=None,
         src_id=None,
         inbatch_pack_offset=None,
+        attn_mask_startend_row_indices=None,
+        **kwargs,
     ):
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -2151,6 +2173,7 @@ def forward(
             output_hidden_states=output_hidden_states,
             return_dict=True,
             inbatch_pack_offset=inbatch_pack_offset,
+            attn_mask_startend_row_indices=attn_mask_startend_row_indices,
         )

         hidden_states = outputs.last_hidden_state
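
Since the causal-LM forward now accepts the keyword directly (and tolerates extras via **kwargs), a caller-side sketch could look like the following; the [batch, heads, seq] layout is an assumption based on the [:, :, ...] slicing above, and model and input_ids are placeholders, not names from this commit.

import paddle

# Hypothetical packed batch: one row holding two sub-sequences of lengths 3 and 2.
attn_mask_startend_row_indices = paddle.to_tensor([[[3, 3, 3, 5, 5]]], dtype="int32")

# outputs = model(
#     input_ids=input_ids,
#     attn_mask_startend_row_indices=attn_mask_startend_row_indices,
# )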
