Skip to content

Commit b167258

Browse files
[pipeline]refactor ppschedule to support tensor list (#1050)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher" — this reverts commit df7e650.
* refactor ppschedule to support tensor list
* polish
1 parent e3fde4e commit b167258

File tree

4 files changed

+261
-217
lines changed

4 files changed

+261
-217
lines changed

colossalai/communication/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
send_backward_recv_backward, send_forward_recv_backward, send_forward_backward_recv_forward_backward,
44
recv_forward, recv_backward)
55
from .ring import ring_forward
6-
from .utils import send_tensor_meta, recv_tensor_meta
6+
from .utils import send_obj_meta, recv_obj_meta
77

88
__all__ = [
99
'all_gather',
@@ -21,6 +21,6 @@
2121
'recv_backward',
2222
'recv_forward',
2323
'ring_forward',
24-
'send_tensor_meta',
25-
'recv_tensor_meta',
24+
'send_obj_meta',
25+
'recv_obj_meta',
2626
]

colossalai/communication/utils.py

Lines changed: 51 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,21 @@
99
TensorShape = Union[torch.Size, List[int], Tuple[int]]
1010

1111

12-
def send_tensor_meta(tensor, need_meta=True, next_rank=None) -> bool:
13-
"""Sends tensor meta information before sending a specific tensor.
14-
Since the recipient must know the shape of the tensor in p2p communications,
15-
meta information of the tensor should be sent before communications. This function
16-
synchronizes with :func:`recv_tensor_meta`.
12+
def send_meta_helper(obj, next_rank, tensor_kwargs):
13+
send_shape = torch.tensor(obj.size(), **tensor_kwargs)
14+
send_ndims = torch.tensor(len(obj.size()), **tensor_kwargs)
15+
dist.send(send_ndims, next_rank)
16+
dist.send(send_shape, next_rank)
17+
18+
19+
def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool:
20+
"""Sends obj meta information before sending a specific obj.
21+
Since the recipient must know the shape of the obj in p2p communications,
22+
meta information of the obj should be sent before communications. This function
23+
synchronizes with :func:`recv_obj_meta`.
1724
1825
Args:
19-
tensor (:class:`torch.Tensor`): Tensor to be sent.
26+
obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): obj to be sent.
2027
need_meta (bool, optional): If False, meta information won't be sent.
2128
next_rank (int): The rank of the next member in pipeline parallel group.
2229
@@ -28,42 +35,57 @@ def send_tensor_meta(tensor, need_meta=True, next_rank=None) -> bool:
2835
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
2936

3037
tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
31-
32-
send_shape = torch.tensor(tensor.size(), **tensor_kwargs)
33-
send_ndims = torch.tensor(len(tensor.size()), **tensor_kwargs)
34-
dist.send(send_ndims, next_rank)
35-
dist.send(send_shape, next_rank)
38+
if isinstance(obj, torch.Tensor):
39+
send_obj_nums = torch.tensor(1, **tensor_kwargs)
40+
dist.send(send_obj_nums, next_rank)
41+
send_meta_helper(obj, next_rank, tensor_kwargs)
42+
else:
43+
send_obj_nums = torch.tensor(len(obj), **tensor_kwargs)
44+
dist.send(send_obj_nums, next_rank)
45+
for tensor_to_send in obj:
46+
send_meta_helper(tensor_to_send, next_rank, tensor_kwargs)
3647

3748
return False
3849

3950

40-
def recv_tensor_meta(tensor_shape: TensorShape, prev_rank=None) -> torch.Size:
41-
"""Receives tensor meta information before receiving a specific tensor.
42-
Since the recipient must know the shape of the tensor in p2p communications,
43-
meta information of the tensor should be received before communications. This function
44-
synchronizes with :func:`send_tensor_meta`.
51+
def recv_meta_helper(prev_rank, tensor_kwargs):
52+
recv_ndims = torch.empty((), **tensor_kwargs)
53+
dist.recv(recv_ndims, prev_rank)
54+
recv_shape = torch.empty(recv_ndims, **tensor_kwargs)
55+
dist.recv(recv_shape, prev_rank)
56+
return recv_shape
57+
58+
59+
def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size:
60+
"""Receives obj meta information before receiving a specific obj.
61+
Since the recipient must know the shape of the obj in p2p communications,
62+
meta information of the obj should be received before communications. This function
63+
synchronizes with :func:`send_obj_meta`.
4564
4665
Args:
47-
tensor_shape (:class:`torch.Size`): The shape of the tensor to be received.
48-
prev_rank (int): The rank of the source of the tensor.
66+
obj_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the obj to be received.
67+
prev_rank (int): The rank of the source of the obj.
4968
5069
Returns:
51-
:class:`torch.Size`: The shape of the tensor to be received.
70+
Union[:class:`torch.Size`, List[:class:`torch.Size`]]: The shape of the obj to be received.
5271
"""
53-
if tensor_shape is None:
72+
if obj_shape is None:
5473
if prev_rank is None:
5574
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
5675

5776
tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
58-
59-
recv_ndims = torch.empty((), **tensor_kwargs)
60-
dist.recv(recv_ndims, prev_rank)
61-
recv_shape = torch.empty(recv_ndims, **tensor_kwargs)
62-
dist.recv(recv_shape, prev_rank)
63-
64-
tensor_shape = torch.Size(recv_shape)
65-
66-
return tensor_shape
77+
recv_obj_nums = torch.empty((), **tensor_kwargs)
78+
dist.recv(recv_obj_nums, prev_rank)
79+
if recv_obj_nums.item() == 1:
80+
recv_shape = recv_meta_helper(prev_rank, tensor_kwargs)
81+
obj_shape = torch.Size(recv_shape)
82+
else:
83+
obj_shape = []
84+
for i in range(recv_obj_nums.item()):
85+
recv_shape = recv_meta_helper(prev_rank, tensor_kwargs)
86+
obj_shape.append(torch.Size(recv_shape))
87+
88+
return obj_shape
6789

6890

6991
def split_tensor_into_1d_equal_chunks(tensor: torch.Tensor, new_buffer=False) -> torch.Tensor:

0 commit comments

Comments (0)