9
9
# Accepted ways to describe a tensor shape in the p2p helpers below.
# Note: ``Tuple[int, ...]`` (variable length) rather than ``Tuple[int]``,
# which would only match one-element tuples.
TensorShape = Union[torch.Size, List[int], Tuple[int, ...]]
10
10
11
11
12
def send_meta_helper(obj, next_rank, tensor_kwargs):
    """Send the meta information (dimensionality and shape) of one tensor.

    Two long tensors are transmitted to ``next_rank``: first the number of
    dimensions of *obj*, then its shape. The receiver needs the count first
    so it can allocate a buffer of the right length for the shape.

    Args:
        obj (:class:`torch.Tensor`): Tensor whose meta information is sent.
        next_rank (int): Destination rank for the point-to-point sends.
        tensor_kwargs (dict): ``dtype``/``device`` keyword args used to build
            the meta tensors.
    """
    shape = obj.size()
    ndims_tensor = torch.tensor(len(shape), **tensor_kwargs)
    shape_tensor = torch.tensor(shape, **tensor_kwargs)
    dist.send(ndims_tensor, next_rank)
    dist.send(shape_tensor, next_rank)
18
+
19
def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool:
    """Sends obj meta information before sending a specific obj.

    Since the recipient must know the shape of the obj in p2p communications,
    meta information of the obj should be sent before communications. This
    function synchronizes with :func:`recv_obj_meta`.

    Args:
        obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): obj to be sent.
        need_meta (bool, optional): If False, meta information won't be sent.
        next_rank (int): The rank of the next member in pipeline parallel group.

    Returns:
        bool: Always ``False``, so callers can write
        ``need_meta = send_obj_meta(obj, need_meta)`` and send the meta only once.
    """
    if need_meta:
        if next_rank is None:
            next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)

        tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
        # Normalize a single tensor to a one-element list so both cases share
        # one code path: send the number of tensors, then each tensor's meta.
        tensors = [obj] if isinstance(obj, torch.Tensor) else obj
        send_obj_nums = torch.tensor(len(tensors), **tensor_kwargs)
        dist.send(send_obj_nums, next_rank)
        for tensor_to_send in tensors:
            send_meta_helper(tensor_to_send, next_rank, tensor_kwargs)

    return False
40
def recv_meta_helper(prev_rank, tensor_kwargs):
    """Receive the meta information (dimensionality and shape) of one tensor.

    Counterpart of :func:`send_meta_helper`: first receives the number of
    dimensions from ``prev_rank``, then the shape itself.

    Args:
        prev_rank (int): Source rank for the point-to-point receives.
        tensor_kwargs (dict): ``dtype``/``device`` keyword args used to build
            the receive buffers.

    Returns:
        :class:`torch.Tensor`: 1-D long tensor holding the received shape.
    """
    recv_ndims = torch.empty((), **tensor_kwargs)
    dist.recv(recv_ndims, prev_rank)
    # Convert the 0-d tensor to a plain int before using it as a size;
    # passing the tensor itself to torch.empty is fragile across versions.
    recv_shape = torch.empty(recv_ndims.item(), **tensor_kwargs)
    dist.recv(recv_shape, prev_rank)
    return recv_shape
59
def recv_obj_meta(obj_shape, prev_rank=None) -> Union[torch.Size, List[torch.Size]]:
    """Receives obj meta information before receiving a specific obj.

    Since the recipient must know the shape of the obj in p2p communications,
    meta information of the obj should be received before communications. This
    function synchronizes with :func:`send_obj_meta`.

    Args:
        obj_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the obj to be received.
        prev_rank (int): The rank of the source of the obj.

    Returns:
        Union[:class:`torch.Size`, List[:class:`torch.Size`]]: The shape of the obj to be received.
    """
    # If the caller already knows the shape(s), skip the communication entirely.
    if obj_shape is None:
        if prev_rank is None:
            prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)

        tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
        recv_obj_nums = torch.empty((), **tensor_kwargs)
        dist.recv(recv_obj_nums, prev_rank)
        # Read the 0-d count tensor once instead of calling .item() per branch.
        num_objs = recv_obj_nums.item()
        if num_objs == 1:
            # A single tensor was announced: return a bare torch.Size.
            obj_shape = torch.Size(recv_meta_helper(prev_rank, tensor_kwargs))
        else:
            # Multiple tensors: return one torch.Size per announced tensor.
            obj_shape = [
                torch.Size(recv_meta_helper(prev_rank, tensor_kwargs))
                for _ in range(num_objs)
            ]

    return obj_shape
67
89
68
90
69
91
def split_tensor_into_1d_equal_chunks (tensor : torch .Tensor , new_buffer = False ) -> torch .Tensor :
0 commit comments