Commit f166fad

[Phi4] Add phi3 model
1 parent d17af14 commit f166fad

16 files changed: 1596 additions & 31 deletions

examples/alignment/dpo/run_dpo.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -43,6 +43,8 @@
     AutoTokenizer,
     LlamaForCausalLM,
     LlamaForCausalLMPipe,
+    Phi3ForCausalLM,
+    Phi3ForCausalLMPipe,
     Qwen2ForCausalLM,
     Qwen2ForCausalLMPipe,
     Qwen2MoeForCausalLM,
@@ -68,6 +70,8 @@
     Qwen3ForCausalLMPipe,
     Qwen3MoeForCausalLM,
     Qwen3MoeForCausalLMPipe,
+    Phi3ForCausalLM,
+    Phi3ForCausalLMPipe,
 ]
```
paddleformers/nn/pp_model.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -266,7 +266,7 @@ def forward(self, args):
                     dtype="int64",
                 )
                 .unsqueeze(0)
-                .tile(input_ids.shape[0], 1)
+                .tile([input_ids.shape[0], 1])
             )
             if self.config.fuse_rope:
                 position_embeddings = None
```
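This one-liner is a bug fix rather than a style change: Paddle's `Tensor.tile` takes the repeat counts as a single `repeat_times` list, tuple, or Tensor (unlike `torch.Tensor.repeat`, which takes varargs), so the two separate positional arguments in the old call were invalid. A minimal runnable sketch of the corrected pattern, with toy shapes:

```python
import paddle

input_ids = paddle.zeros([4, 6], dtype="int64")  # toy batch of 4, seq_len 6
position_ids = (
    paddle.arange(input_ids.shape[1], dtype="int64")
    .unsqueeze(0)                    # [1, seq_len]
    .tile([input_ids.shape[0], 1])   # repeat_times must be one sequence
)
print(position_ids.shape)            # [4, 6]
```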

paddleformers/transformers/__init__.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -341,6 +341,9 @@
     "auto": ["AutoModelForCausalLM"],
     "legacy.tokenizer_utils_base": ["EncodingFast"],
     "legacy": [],
+    "phi3.configuration": ["Phi3Config"],
+    "phi3.tokenizer": ["Phi3Tokenizer"],
+    "phi3.modeling": ["Phi3Model", "Phi3ForCausalLM", "Phi3ForCausalLMPipe"],
 }
 
 if TYPE_CHECKING:
@@ -403,6 +406,7 @@
     from .qwen3_moe import *
     from .glm4_moe import *
     from .gpt_oss import *
+    from .phi3 import *
 else:
     sys.modules[__name__] = _LazyModule(
         __name__,
```
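Registering the three `phi3.*` submodules in the lazy-import table, together with the eager `from .phi3 import *` under `TYPE_CHECKING`, makes the new symbols importable from the package root. A quick smoke test (assuming a build that includes this commit):

```python
from paddleformers.transformers import Phi3Config, Phi3ForCausalLM, Phi3Tokenizer

# Resolved through _LazyModule at runtime; nothing is loaded until first access.
print(Phi3Config.__name__, Phi3ForCausalLM.__name__, Phi3Tokenizer.__name__)
```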

paddleformers/transformers/auto/configuration.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -46,6 +46,7 @@
         ("qwen3_moe", "Qwen3MoeConfig"),
         ("glm4_moe", "Glm4MoeConfig"),
         ("gpt_oss", "GptOssConfig"),
+        ("phi3", "Phi3Config"),
     ]
 )
```
paddleformers/transformers/auto/modeling.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -65,6 +65,7 @@
         ("Qwen3Moe", "qwen3_moe"),
         ("Glm4Moe", "glm4_moe"),
         ("GptOss", "gpt_oss"),
+        ("Phi3", "phi3"),
     ]
 )
```
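With the config mapping and the model-name prefix both registered, the Auto classes should route Phi-3 checkpoints to the new classes by name. A hedged sketch; the checkpoint id is illustrative, and `AutoConfig` is assumed to be exported alongside `AutoModelForCausalLM`:

```python
from paddleformers.transformers import AutoConfig, AutoModelForCausalLM

# Name-based resolution: the "Phi3" prefix maps to the phi3 module.
config = AutoConfig.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
print(type(config).__name__)  # expected: Phi3Config

model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
print(type(model).__name__)   # expected: Phi3ForCausalLM
```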

paddleformers/transformers/conversion_utils.py

Lines changed: 248 additions & 19 deletions
```diff
@@ -332,7 +332,7 @@ def naive_fuse_merge_tp(weight_list, is_column=True, fuse_tensor_parts=2):
 
 
 def naive_fuse_split_tp(
-    weight, tensor_parallel_degree, tensor_parallel_rank=None, is_column=True, fuse_tensor_parts=2
+    weight, tensor_parallel_degree, tensor_parallel_rank=None, is_column=True, fuse_tensor_parts=2, num_kv_groups=1
 ):
     """
 
```
```diff
@@ -353,19 +353,57 @@ def naive_fuse_split_tp(
         size = weight.get_shape()[axis]
         block_size = size // (fuse_tensor_parts * tensor_parallel_degree)
 
-        splited = []
-        if tensor_parallel_rank is None:
-            begin, end, step = 0, fuse_tensor_parts * tensor_parallel_degree, 1
+        # for qkv tp split
+        if fuse_tensor_parts == 3 and num_kv_groups > 1:
+            q_size = num_kv_groups * size // (num_kv_groups + 2)
+            kv_size = size - q_size
+            q_block_size = q_size // tensor_parallel_degree
+            kv_block_size = kv_size // (tensor_parallel_degree * 2)
+            q_end = q_size // q_block_size
+            kv_end = kv_size // kv_block_size
+
+            splited = []
+            if tensor_parallel_rank is None:
+                begin, step = 0, 1
+            else:
+                begin, step = tensor_parallel_rank, tensor_parallel_degree
+            # for q split
+            for rank in range(begin, q_end, step):
+                start = rank * q_block_size
+                stop = (rank + 1) * q_block_size
+                if axis == 0 or len(weight.get_shape()) == 1:
+                    tensor = weight[start:stop]
+                else:
+                    tensor = weight[:, start:stop]
+                splited.append(tensor)
+            # for kv split
+            for rank in range(begin, kv_end, step):
+                start = rank * kv_block_size + q_size
+                stop = (rank + 1) * kv_block_size + q_size
+                if axis == 0 or len(weight.get_shape()) == 1:
+                    tensor = weight[start:stop]
+                else:
+                    tensor = weight[:, start:stop]
+                splited.append(tensor)
+
         else:
-            begin, end, step = tensor_parallel_rank, fuse_tensor_parts * tensor_parallel_degree, tensor_parallel_degree
-        for rank in range(begin, end, step):
-            start = rank * block_size
-            stop = (rank + 1) * block_size
-            if axis == 0 or len(weight.get_shape()) == 1:
-                tensor = weight[start:stop]
+            splited = []
+            if tensor_parallel_rank is None:
+                begin, end, step = 0, fuse_tensor_parts * tensor_parallel_degree, 1
             else:
-                tensor = weight[:, start:stop]
-            splited.append(tensor)
+                begin, end, step = (
+                    tensor_parallel_rank,
+                    fuse_tensor_parts * tensor_parallel_degree,
+                    tensor_parallel_degree,
+                )
+            for rank in range(begin, end, step):
+                start = rank * block_size
+                stop = (rank + 1) * block_size
+                if axis == 0 or len(weight.get_shape()) == 1:
+                    tensor = weight[start:stop]
+                else:
+                    tensor = weight[:, start:stop]
+                splited.append(tensor)
 
         if tensor_parallel_rank is None:
             ret = []
```
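The new branch assumes the fused axis is laid out as [Q | K | V], with Q occupying num_kv_groups / (num_kv_groups + 2) of the width; each rank then takes one contiguous Q block plus a strided K slice and V slice. A runnable NumPy trace of the slicing arithmetic, with toy sizes chosen purely for illustration:

```python
import numpy as np

# Toy fused-QKV axis: num_kv_groups = 4 query heads per KV head, head_dim = 2,
# one KV head -> q (8) + k (2) + v (2) = 12 entries. tp_degree = 2.
size, num_kv_groups, tp_degree = 12, 4, 2
fused = np.arange(size)

q_size = num_kv_groups * size // (num_kv_groups + 2)   # 8
kv_size = size - q_size                                # 4 (k and v together)
q_block = q_size // tp_degree                          # 4
kv_block = kv_size // (tp_degree * 2)                  # 1

for tp_rank in range(tp_degree):
    # one contiguous Q block per rank
    q = fused[tp_rank * q_block:(tp_rank + 1) * q_block]
    # two strided KV blocks per rank: its K slice, then its V slice
    kv = [
        fused[q_size + r * kv_block:q_size + (r + 1) * kv_block]
        for r in range(tp_rank, kv_size // kv_block, tp_degree)
    ]
    print(tp_rank, q, kv)
# 0 [0 1 2 3] [array([8]), array([10])]
# 1 [4 5 6 7] [array([9]), array([11])]
```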
```diff
@@ -377,8 +415,10 @@ def naive_fuse_split_tp(
 
     if isinstance(weight, paddle.Tensor):
 
-        def slice_concat_by_axis(weight, fuse_tensor_parts, tensor_parallel_degree, tensor_parallel_rank, axis=0):
-            total_splits = fuse_tensor_parts * tensor_parallel_degree
+        def slice_concat_by_axis(
+            weight, fuse_tensor_parts, tensor_parallel_degree, tensor_parallel_rank, num_kv_groups=1, axis=0
+        ):
+            total_splits = fuse_tensor_parts * tensor_parallel_degree * num_kv_groups
             dim_size = weight.shape[axis]
             split_size = dim_size // total_splits
 
```
```diff
@@ -395,16 +435,21 @@ def slice_concat_by_axis(weight, fuse_tensor_parts, tensor_parallel_degree, tens
 
         if tensor_parallel_rank is not None:
             return slice_concat_by_axis(
-                weight, fuse_tensor_parts, tensor_parallel_degree, tensor_parallel_rank, axis=axis
+                weight,
+                fuse_tensor_parts,
+                tensor_parallel_degree,
+                tensor_parallel_rank,
+                num_kv_groups=num_kv_groups,
+                axis=axis,
             )
         else:
-            splited = paddle.split(weight, fuse_tensor_parts * tensor_parallel_degree, axis=axis)
+            splited = paddle.split(weight, fuse_tensor_parts * tensor_parallel_degree * num_kv_groups, axis=axis)
             ret = []
             for tensor_parallel_rank in range(tensor_parallel_degree):
                 ret.append(paddle.cat(splited[tensor_parallel_rank::tensor_parallel_degree], axis=axis))
             return ret
     else:
-        splited = np.split(weight, fuse_tensor_parts * tensor_parallel_degree, axis=axis)
+        splited = np.split(weight, fuse_tensor_parts * tensor_parallel_degree * num_kv_groups, axis=axis)
 
         if tensor_parallel_rank is None:
             ret = []
```
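Outside the GQA branch (and in both the paddle and NumPy fallbacks), the `tensor_parallel_rank is None` paths keep the original round-robin deal: cut the fused weight into `fuse_tensor_parts * tensor_parallel_degree` chunks (now additionally scaled by `num_kv_groups`) and give rank r the chunks `[r::tensor_parallel_degree]`, i.e. [A1, A2, B1, B2] => [A1 B1], [A2 B2]. A pure-NumPy illustration:

```python
import numpy as np

# 2-way fuse (e.g. gate+up), tp_degree 2, 8 columns -> 4 chunks dealt round-robin.
w = np.arange(8)
parts = np.split(w, 2 * 2)                          # [0 1] [2 3] [4 5] [6 7]
shards = [np.concatenate(parts[r::2]) for r in range(2)]
print(shards)  # [array([0, 1, 4, 5]), array([2, 3, 6, 7])]
```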
```diff
@@ -415,6 +460,90 @@ def slice_concat_by_axis(weight, fuse_tensor_parts, tensor_parallel_degree, tens
         return np.concatenate(splited[tensor_parallel_rank::tensor_parallel_degree], axis=axis)
 
 
+# def naive_fuse_split_tp(
+#     weight, tensor_parallel_degree, tensor_parallel_rank=None, is_column=True, fuse_tensor_parts=2
+# ):
+#     """
+
+#     [A1, A2, B1, B2] => [A1 B1],[A2 B2]
+
+#     Args:
+#         weight (numpy.ndarray): the tensor weight,
+#         tensor_parallel_degree (int): tensor_parallel_degree
+#         tensor_parallel_rank (int): tensor_parallel_rank
+#         is_column (bool, optional): is ColumnLinear . Defaults to True.
+
+#     Returns:
+#         tensor (numpy.ndarray): splited weight.
+
+#     """
+#     axis = -1 if is_column else 0
+#     if "PySafeSlice" in str(type(weight)):
+#         size = weight.get_shape()[axis]
+#         block_size = size // (fuse_tensor_parts * tensor_parallel_degree)
+
+#         splited = []
+#         if tensor_parallel_rank is None:
+#             begin, end, step = 0, fuse_tensor_parts * tensor_parallel_degree, 1
+#         else:
+#             begin, end, step = tensor_parallel_rank, fuse_tensor_parts * tensor_parallel_degree, tensor_parallel_degree
+#         for rank in range(begin, end, step):
+#             start = rank * block_size
+#             stop = (rank + 1) * block_size
+#             if axis == 0 or len(weight.get_shape()) == 1:
+#                 tensor = weight[start:stop]
+#             else:
+#                 tensor = weight[:, start:stop]
+#             splited.append(tensor)
+
+#         if tensor_parallel_rank is None:
+#             ret = []
+#             for tensor_parallel_rank in range(tensor_parallel_degree):
+#                 ret.append(np.concatenate(splited[tensor_parallel_rank::tensor_parallel_degree], axis=axis))
+#             return ret
+
+#         return np.concatenate(splited, axis=axis)
+
+#     if isinstance(weight, paddle.Tensor):
+
+#         def slice_concat_by_axis(weight, fuse_tensor_parts, tensor_parallel_degree, tensor_parallel_rank, axis=0):
+#             total_splits = fuse_tensor_parts * tensor_parallel_degree
+#             dim_size = weight.shape[axis]
+#             split_size = dim_size // total_splits
+
+#             slices = []
+#             for idx in range(tensor_parallel_rank, total_splits, tensor_parallel_degree):
+#                 start = idx * split_size
+#                 end = (start + split_size) if (idx != total_splits - 1) else dim_size
+#                 slice_idx = [slice(None)] * len(weight.shape)
+#                 slice_idx[axis] = slice(start, end)
+#                 block = weight[tuple(slice_idx)]
+#                 slices.append(block)
+#             result = paddle.cat(slices, axis=axis)
+#             return result
+
+#         if tensor_parallel_rank is not None:
+#             return slice_concat_by_axis(
+#                 weight, fuse_tensor_parts, tensor_parallel_degree, tensor_parallel_rank, axis=axis
+#             )
+#         else:
+#             splited = paddle.split(weight, fuse_tensor_parts * tensor_parallel_degree, axis=axis)
+#             ret = []
+#             for tensor_parallel_rank in range(tensor_parallel_degree):
+#                 ret.append(paddle.cat(splited[tensor_parallel_rank::tensor_parallel_degree], axis=axis))
+#             return ret
+#     else:
+#         splited = np.split(weight, fuse_tensor_parts * tensor_parallel_degree, axis=axis)
+
+#         if tensor_parallel_rank is None:
+#             ret = []
+#             for tensor_parallel_rank in range(tensor_parallel_degree):
+#                 ret.append(np.concatenate(splited[tensor_parallel_rank::tensor_parallel_degree], axis=axis))
+#             return ret
+
+#         return np.concatenate(splited[tensor_parallel_rank::tensor_parallel_degree], axis=axis)
+
+
 def normal_fuse_merge_tp(weight_list, is_column=True):
     """
```
```diff
@@ -740,7 +869,15 @@ def fn(
 
 
 def get_tensor_parallel_split_func(tensor_parallel_degree, tensor_parallel_rank, num_attention_heads=None):
-    def fn(x, is_column=True, transpose=False, is_old_qkv=False, is_naive_2fuse=False, is_naive_3fuse=False):
+    def fn(
+        x,
+        is_column=True,
+        transpose=False,
+        is_old_qkv=False,
+        is_naive_2fuse=False,
+        is_naive_3fuse=False,
+        num_kv_groups=1,
+    ):
         if x is None:
             return None
         if transpose:
```
```diff
@@ -758,14 +895,106 @@ def fn(x, is_column=True, transpose=False, is_old_qkv=False, is_naive_2fuse=Fals
             )
         if is_naive_3fuse:
             return naive_fuse_split_tp(
-                x, tensor_parallel_degree, tensor_parallel_rank, is_column=is_column, fuse_tensor_parts=3
+                x,
+                tensor_parallel_degree,
+                tensor_parallel_rank,
+                is_column=is_column,
+                fuse_tensor_parts=3,
+                num_kv_groups=num_kv_groups,
             )
 
         return normal_fuse_split_tp(x, tensor_parallel_degree, tensor_parallel_rank, is_column=is_column)
 
     return fn
 
 
+# def get_tensor_parallel_split_func(tensor_parallel_degree, tensor_parallel_rank, num_attention_heads=None):
+#     def fn(x, is_column=True, transpose=False, is_old_qkv=False, is_naive_2fuse=False, is_naive_3fuse=False):
+#         if x is None:
+#             return None
+#         if transpose:
+#             if isinstance(x, paddle.Tensor):
+#                 x = paddle.transpose(x, [1, 0])
+#             else:
+#                 x = np.transpose(x, [1, 0])
+#         if is_old_qkv:
+#             assert is_column, "QKV tensor should be column parallel linear."
+#             assert num_attention_heads is not None, "is_old_qkv need num_attention_heads"
+#             x = naive_merged_qkv_to_tensor_parallel_qkv(x, num_attention_heads)
+#         if is_naive_2fuse:
+#             return naive_fuse_split_tp(
+#                 x, tensor_parallel_degree, tensor_parallel_rank, is_column=is_column, fuse_tensor_parts=2
+#             )
+#         if is_naive_3fuse:
+#             return naive_fuse_split_tp(
+#                 x, tensor_parallel_degree, tensor_parallel_rank, is_column=is_column, fuse_tensor_parts=3
+#             )
+
+#         return normal_fuse_split_tp(x, tensor_parallel_degree, tensor_parallel_rank, is_column=is_column)
+
+#     return fn
+
+
+# def get_tensor_parallel_split_func(tensor_parallel_degree, tensor_parallel_rank, num_attention_heads=None):
+#     def fn(x, is_column=True, transpose=False, is_old_qkv=False, is_naive_2fuse=False, is_naive_3fuse=False):
+#         # print(f"\nis_column={is_column}, is_old_qkv={is_old_qkv}, is_naive_2fuse={is_naive_2fuse}")
+#         if x is None:
+#             return None
+#         if transpose:
+#             if isinstance(x, paddle.Tensor):
+#                 x = paddle.transpose(x, [1, 0])
+#             else:
+#                 x = np.transpose(x, [1, 0])
+
+#         # if is_old_qkv:
+#         #     assert is_column, "QKV tensor should be column parallel linear."
+#         #     assert num_attention_heads is not None, "is_old_qkv need num_attention_heads"
+#         #     x = naive_merged_qkv_to_tensor_parallel_qkv(x, num_attention_heads)
+#         # if is_naive_2fuse:
+#         #     return naive_fuse_split_tp(
+#         #         x, tensor_parallel_degree, tensor_parallel_rank, is_column=is_column, fuse_tensor_parts=2
+#         #     )
+
+#         if is_old_qkv:
+#             assert num_attention_heads is not None, "is_old_qkv need num_attention_heads"
+#             if not is_column:
+#                 if isinstance(x, paddle.Tensor):
+#                     x = paddle.transpose(x, [1, 0])
+#                 else:
+#                     x = np.transpose(x, [1, 0])
+#             x = naive_merged_qkv_to_tensor_parallel_qkv(x, num_attention_heads)
+#             if not is_column:
+#                 if isinstance(x, paddle.Tensor):
+#                     x = paddle.transpose(x, [1, 0])
+#                 else:
+#                     x = np.transpose(x, [1, 0])
+
+#         if is_naive_2fuse:
+#             # if not is_column:
+#             #     if isinstance(x, paddle.Tensor):
+#             #         x = paddle.transpose(x, [1, 0])
+#             #     else:
+#             #         x = np.transpose(x, [1, 0])
+#             x = naive_fuse_split_tp(
+#                 x, tensor_parallel_degree, tensor_parallel_rank, is_column=is_column, fuse_tensor_parts=2
+#             )
+#             # if not is_column:
+#             #     if isinstance(x, paddle.Tensor):
+#             #         x = paddle.transpose(x, [1, 0])
+#             #     else:
+#             #         x = np.transpose(x, [1, 0])
+#             return x
+
+#         if is_naive_3fuse:
+#             return naive_fuse_split_tp(
+#                 x, tensor_parallel_degree, tensor_parallel_rank, is_column=is_column, fuse_tensor_parts=3
+#             )
+
+#         return normal_fuse_split_tp(x, tensor_parallel_degree, tensor_parallel_rank, is_column=is_column)
+
+#     return fn
+
+
 def split_or_merge_func(is_split, tensor_parallel_degree, tensor_parallel_rank, num_attention_heads=None):
     if is_split:
         return get_tensor_parallel_split_func(tensor_parallel_degree, tensor_parallel_rank, num_attention_heads)
```
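For orientation, here is how the extended closure is reached end to end. The sketch below exercises the default `num_kv_groups=1` path through the NumPy fallback (the GQA q/kv arithmetic above only engages on the safetensors `PySafeSlice` branch); sizes are toy values and the import assumes this branch of paddleformers is installed:

```python
import numpy as np

from paddleformers.transformers.conversion_utils import get_tensor_parallel_split_func

# Fused [Q|K|V] projection with 12 output columns, split column-wise for
# rank 0 of a 2-way tensor-parallel group via the 3-fuse path.
qkv = np.arange(4 * 12, dtype="float32").reshape(4, 12)
split_fn = get_tensor_parallel_split_func(tensor_parallel_degree=2, tensor_parallel_rank=0)
shard = split_fn(qkv, is_column=True, is_naive_3fuse=True)  # num_kv_groups defaults to 1
print(shard.shape)  # (4, 6): rank 0's Q, K and V slices concatenated
```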
