@@ -332,7 +332,7 @@ def naive_fuse_merge_tp(weight_list, is_column=True, fuse_tensor_parts=2):
332332
333333
334334def naive_fuse_split_tp (
335- weight , tensor_parallel_degree , tensor_parallel_rank = None , is_column = True , fuse_tensor_parts = 2
335+ weight , tensor_parallel_degree , tensor_parallel_rank = None , is_column = True , fuse_tensor_parts = 2 , num_kv_groups = 1
336336):
337337 """
338338
@@ -353,19 +353,57 @@ def naive_fuse_split_tp(
353353 size = weight .get_shape ()[axis ]
354354 block_size = size // (fuse_tensor_parts * tensor_parallel_degree )
355355
356- splited = []
357- if tensor_parallel_rank is None :
358- begin , end , step = 0 , fuse_tensor_parts * tensor_parallel_degree , 1
356+ # Fused QKV split when Q and K/V differ in width (q : k : v = num_kv_groups : 1 : 1)
357+ if fuse_tensor_parts == 3 and num_kv_groups > 1 :
358+ q_size = num_kv_groups * size // (num_kv_groups + 2 )
359+ kv_size = size - q_size
360+ q_block_size = q_size // tensor_parallel_degree
361+ kv_block_size = kv_size // (tensor_parallel_degree * 2 )
362+ q_end = q_size // q_block_size
363+ kv_end = kv_size // kv_block_size
364+
365+ splited = []
366+ if tensor_parallel_rank is None :
367+ begin , step = 0 , 1
368+ else :
369+ begin , step = tensor_parallel_rank , tensor_parallel_degree
370+ # Slice the Q region, which spans [0, q_size) along the split axis
371+ for rank in range (begin , q_end , step ):
372+ start = rank * q_block_size
373+ stop = (rank + 1 ) * q_block_size
374+ if axis == 0 or len (weight .get_shape ()) == 1 :
375+ tensor = weight [start :stop ]
376+ else :
377+ tensor = weight [:, start :stop ]
378+ splited .append (tensor )
379+ # Slice the fused K/V region, which starts at offset q_size along the split axis
380+ for rank in range (begin , kv_end , step ):
381+ start = rank * kv_block_size + q_size
382+ stop = (rank + 1 ) * kv_block_size + q_size
383+ if axis == 0 or len (weight .get_shape ()) == 1 :
384+ tensor = weight [start :stop ]
385+ else :
386+ tensor = weight [:, start :stop ]
387+ splited .append (tensor )
388+
359389 else :
360- begin , end , step = tensor_parallel_rank , fuse_tensor_parts * tensor_parallel_degree , tensor_parallel_degree
361- for rank in range (begin , end , step ):
362- start = rank * block_size
363- stop = (rank + 1 ) * block_size
364- if axis == 0 or len (weight .get_shape ()) == 1 :
365- tensor = weight [start :stop ]
390+ splited = []
391+ if tensor_parallel_rank is None :
392+ begin , end , step = 0 , fuse_tensor_parts * tensor_parallel_degree , 1
366393 else :
367- tensor = weight [:, start :stop ]
368- splited .append (tensor )
394+ begin , end , step = (
395+ tensor_parallel_rank ,
396+ fuse_tensor_parts * tensor_parallel_degree ,
397+ tensor_parallel_degree ,
398+ )
399+ for rank in range (begin , end , step ):
400+ start = rank * block_size
401+ stop = (rank + 1 ) * block_size
402+ if axis == 0 or len (weight .get_shape ()) == 1 :
403+ tensor = weight [start :stop ]
404+ else :
405+ tensor = weight [:, start :stop ]
406+ splited .append (tensor )
369407
370408 if tensor_parallel_rank is None :
371409 ret = []
@@ -377,8 +415,10 @@ def naive_fuse_split_tp(
377415
378416 if isinstance (weight , paddle .Tensor ):
379417
380- def slice_concat_by_axis (weight , fuse_tensor_parts , tensor_parallel_degree , tensor_parallel_rank , axis = 0 ):
381- total_splits = fuse_tensor_parts * tensor_parallel_degree
418+ def slice_concat_by_axis (
419+ weight , fuse_tensor_parts , tensor_parallel_degree , tensor_parallel_rank , num_kv_groups = 1 , axis = 0
420+ ):
421+ total_splits = fuse_tensor_parts * tensor_parallel_degree * num_kv_groups
382422 dim_size = weight .shape [axis ]
383423 split_size = dim_size // total_splits
384424
@@ -395,16 +435,21 @@ def slice_concat_by_axis(weight, fuse_tensor_parts, tensor_parallel_degree, tens
395435
396436 if tensor_parallel_rank is not None :
397437 return slice_concat_by_axis (
398- weight , fuse_tensor_parts , tensor_parallel_degree , tensor_parallel_rank , axis = axis
438+ weight ,
439+ fuse_tensor_parts ,
440+ tensor_parallel_degree ,
441+ tensor_parallel_rank ,
442+ num_kv_groups = num_kv_groups ,
443+ axis = axis ,
399444 )
400445 else :
401- splited = paddle .split (weight , fuse_tensor_parts * tensor_parallel_degree , axis = axis )
446+ splited = paddle .split (weight , fuse_tensor_parts * tensor_parallel_degree * num_kv_groups , axis = axis )
402447 ret = []
403448 for tensor_parallel_rank in range (tensor_parallel_degree ):
404449 ret .append (paddle .cat (splited [tensor_parallel_rank ::tensor_parallel_degree ], axis = axis ))
405450 return ret
406451 else :
407- splited = np .split (weight , fuse_tensor_parts * tensor_parallel_degree , axis = axis )
452+ splited = np .split (weight , fuse_tensor_parts * tensor_parallel_degree * num_kv_groups , axis = axis )
408453
409454 if tensor_parallel_rank is None :
410455 ret = []
@@ -415,6 +460,90 @@ def slice_concat_by_axis(weight, fuse_tensor_parts, tensor_parallel_degree, tens
415460 return np .concatenate (splited [tensor_parallel_rank ::tensor_parallel_degree ], axis = axis )
416461
417462
463+ # def naive_fuse_split_tp(
464+ # weight, tensor_parallel_degree, tensor_parallel_rank=None, is_column=True, fuse_tensor_parts=2
465+ # ):
466+ # """
467+
468+ # [A1, A2, B1, B2] => [A1 B1],[A2 B2]
469+
470+ # Args:
471+ # weight (numpy.ndarray): the tensor weight,
472+ # tensor_parallel_degree (int): tensor_parallel_degree
473+ # tensor_parallel_rank (int): tensor_parallel_rank
474+ # is_column (bool, optional): is ColumnLinear . Defaults to True.
475+
476+ # Returns:
477+ # tensor (numpy.ndarray): split weight.
478+
479+ # """
480+ # axis = -1 if is_column else 0
481+ # if "PySafeSlice" in str(type(weight)):
482+ # size = weight.get_shape()[axis]
483+ # block_size = size // (fuse_tensor_parts * tensor_parallel_degree)
484+
485+ # splited = []
486+ # if tensor_parallel_rank is None:
487+ # begin, end, step = 0, fuse_tensor_parts * tensor_parallel_degree, 1
488+ # else:
489+ # begin, end, step = tensor_parallel_rank, fuse_tensor_parts * tensor_parallel_degree, tensor_parallel_degree
490+ # for rank in range(begin, end, step):
491+ # start = rank * block_size
492+ # stop = (rank + 1) * block_size
493+ # if axis == 0 or len(weight.get_shape()) == 1:
494+ # tensor = weight[start:stop]
495+ # else:
496+ # tensor = weight[:, start:stop]
497+ # splited.append(tensor)
498+
499+ # if tensor_parallel_rank is None:
500+ # ret = []
501+ # for tensor_parallel_rank in range(tensor_parallel_degree):
502+ # ret.append(np.concatenate(splited[tensor_parallel_rank::tensor_parallel_degree], axis=axis))
503+ # return ret
504+
505+ # return np.concatenate(splited, axis=axis)
506+
507+ # if isinstance(weight, paddle.Tensor):
508+
509+ # def slice_concat_by_axis(weight, fuse_tensor_parts, tensor_parallel_degree, tensor_parallel_rank, axis=0):
510+ # total_splits = fuse_tensor_parts * tensor_parallel_degree
511+ # dim_size = weight.shape[axis]
512+ # split_size = dim_size // total_splits
513+
514+ # slices = []
515+ # for idx in range(tensor_parallel_rank, total_splits, tensor_parallel_degree):
516+ # start = idx * split_size
517+ # end = (start + split_size) if (idx != total_splits - 1) else dim_size
518+ # slice_idx = [slice(None)] * len(weight.shape)
519+ # slice_idx[axis] = slice(start, end)
520+ # block = weight[tuple(slice_idx)]
521+ # slices.append(block)
522+ # result = paddle.cat(slices, axis=axis)
523+ # return result
524+
525+ # if tensor_parallel_rank is not None:
526+ # return slice_concat_by_axis(
527+ # weight, fuse_tensor_parts, tensor_parallel_degree, tensor_parallel_rank, axis=axis
528+ # )
529+ # else:
530+ # splited = paddle.split(weight, fuse_tensor_parts * tensor_parallel_degree, axis=axis)
531+ # ret = []
532+ # for tensor_parallel_rank in range(tensor_parallel_degree):
533+ # ret.append(paddle.cat(splited[tensor_parallel_rank::tensor_parallel_degree], axis=axis))
534+ # return ret
535+ # else:
536+ # splited = np.split(weight, fuse_tensor_parts * tensor_parallel_degree, axis=axis)
537+
538+ # if tensor_parallel_rank is None:
539+ # ret = []
540+ # for tensor_parallel_rank in range(tensor_parallel_degree):
541+ # ret.append(np.concatenate(splited[tensor_parallel_rank::tensor_parallel_degree], axis=axis))
542+ # return ret
543+
544+ # return np.concatenate(splited[tensor_parallel_rank::tensor_parallel_degree], axis=axis)
545+
546+
418547def normal_fuse_merge_tp (weight_list , is_column = True ):
419548 """
420549
@@ -740,7 +869,15 @@ def fn(
740869
741870
742871def get_tensor_parallel_split_func (tensor_parallel_degree , tensor_parallel_rank , num_attention_heads = None ):
743- def fn (x , is_column = True , transpose = False , is_old_qkv = False , is_naive_2fuse = False , is_naive_3fuse = False ):
872+ def fn (
873+ x ,
874+ is_column = True ,
875+ transpose = False ,
876+ is_old_qkv = False ,
877+ is_naive_2fuse = False ,
878+ is_naive_3fuse = False ,
879+ num_kv_groups = 1 ,
880+ ):
744881 if x is None :
745882 return None
746883 if transpose :
@@ -758,14 +895,106 @@ def fn(x, is_column=True, transpose=False, is_old_qkv=False, is_naive_2fuse=Fals
758895 )
759896 if is_naive_3fuse :
760897 return naive_fuse_split_tp (
761- x , tensor_parallel_degree , tensor_parallel_rank , is_column = is_column , fuse_tensor_parts = 3
898+ x ,
899+ tensor_parallel_degree ,
900+ tensor_parallel_rank ,
901+ is_column = is_column ,
902+ fuse_tensor_parts = 3 ,
903+ num_kv_groups = num_kv_groups ,
762904 )
763905
764906 return normal_fuse_split_tp (x , tensor_parallel_degree , tensor_parallel_rank , is_column = is_column )
765907
766908 return fn
767909
768910
911+ # def get_tensor_parallel_split_func(tensor_parallel_degree, tensor_parallel_rank, num_attention_heads=None):
912+ # def fn(x, is_column=True, transpose=False, is_old_qkv=False, is_naive_2fuse=False, is_naive_3fuse=False):
913+ # if x is None:
914+ # return None
915+ # if transpose:
916+ # if isinstance(x, paddle.Tensor):
917+ # x = paddle.transpose(x, [1, 0])
918+ # else:
919+ # x = np.transpose(x, [1, 0])
920+ # if is_old_qkv:
921+ # assert is_column, "QKV tensor should be column parallel linear."
922+ # assert num_attention_heads is not None, "is_old_qkv need num_attention_heads"
923+ # x = naive_merged_qkv_to_tensor_parallel_qkv(x, num_attention_heads)
924+ # if is_naive_2fuse:
925+ # return naive_fuse_split_tp(
926+ # x, tensor_parallel_degree, tensor_parallel_rank, is_column=is_column, fuse_tensor_parts=2
927+ # )
928+ # if is_naive_3fuse:
929+ # return naive_fuse_split_tp(
930+ # x, tensor_parallel_degree, tensor_parallel_rank, is_column=is_column, fuse_tensor_parts=3
931+ # )
932+
933+ # return normal_fuse_split_tp(x, tensor_parallel_degree, tensor_parallel_rank, is_column=is_column)
934+
935+ # return fn
936+
937+
938+ # def get_tensor_parallel_split_func(tensor_parallel_degree, tensor_parallel_rank, num_attention_heads=None):
939+ # def fn(x, is_column=True, transpose=False, is_old_qkv=False, is_naive_2fuse=False, is_naive_3fuse=False):
940+ # # print(f"\nis_column={is_column}, is_old_qkv={is_old_qkv}, is_naive_2fuse={is_naive_2fuse}")
941+ # if x is None:
942+ # return None
943+ # if transpose:
944+ # if isinstance(x, paddle.Tensor):
945+ # x = paddle.transpose(x, [1, 0])
946+ # else:
947+ # x = np.transpose(x, [1, 0])
948+
949+ # # if is_old_qkv:
950+ # # assert is_column, "QKV tensor should be column parallel linear."
951+ # # assert num_attention_heads is not None, "is_old_qkv need num_attention_heads"
952+ # # x = naive_merged_qkv_to_tensor_parallel_qkv(x, num_attention_heads)
953+ # # if is_naive_2fuse:
954+ # # return naive_fuse_split_tp(
955+ # # x, tensor_parallel_degree, tensor_parallel_rank, is_column=is_column, fuse_tensor_parts=2
956+ # # )
957+
958+ # if is_old_qkv:
959+ # assert num_attention_heads is not None, "is_old_qkv need num_attention_heads"
960+ # if not is_column:
961+ # if isinstance(x, paddle.Tensor):
962+ # x = paddle.transpose(x, [1, 0])
963+ # else:
964+ # x = np.transpose(x, [1, 0])
965+ # x = naive_merged_qkv_to_tensor_parallel_qkv(x, num_attention_heads)
966+ # if not is_column:
967+ # if isinstance(x, paddle.Tensor):
968+ # x = paddle.transpose(x, [1, 0])
969+ # else:
970+ # x = np.transpose(x, [1, 0])
971+
972+ # if is_naive_2fuse:
973+ # # if not is_column:
974+ # # if isinstance(x, paddle.Tensor):
975+ # # x = paddle.transpose(x, [1, 0])
976+ # # else:
977+ # # x = np.transpose(x, [1, 0])
978+ # x = naive_fuse_split_tp(
979+ # x, tensor_parallel_degree, tensor_parallel_rank, is_column=is_column, fuse_tensor_parts=2
980+ # )
981+ # # if not is_column:
982+ # # if isinstance(x, paddle.Tensor):
983+ # # x = paddle.transpose(x, [1, 0])
984+ # # else:
985+ # # x = np.transpose(x, [1, 0])
986+ # return x
987+
988+ # if is_naive_3fuse:
989+ # return naive_fuse_split_tp(
990+ # x, tensor_parallel_degree, tensor_parallel_rank, is_column=is_column, fuse_tensor_parts=3
991+ # )
992+
993+ # return normal_fuse_split_tp(x, tensor_parallel_degree, tensor_parallel_rank, is_column=is_column)
994+
995+ # return fn
996+
997+
769998def split_or_merge_func (is_split , tensor_parallel_degree , tensor_parallel_rank , num_attention_heads = None ):
770999 if is_split :
7711000 return get_tensor_parallel_split_func (tensor_parallel_degree , tensor_parallel_rank , num_attention_heads )
0 commit comments