@@ -44,12 +44,14 @@ def _gather(input_, dim, process_group=None, cache=False):
     input_ = input_.contiguous()
     # Size and dimension.
     rank = dist.get_rank(process_group)
+
+    from axonn.intra_layer import ALL_GATHER_DTYPE
 
     tensor_list = [
-        torch.empty_like(input_) for _ in range(dist.get_world_size(process_group))
+        torch.empty_like(input_, dtype=ALL_GATHER_DTYPE) for _ in range(dist.get_world_size(process_group))
     ]
     tensor_list[rank] = input_
-    dist.all_gather(tensor_list, input_, group=process_group)
+    dist.all_gather(tensor_list, input_.to(ALL_GATHER_DTYPE), group=process_group)
 
     # Note: torch.cat already creates a contiguous tensor.
     output = torch.cat(tensor_list, dim=dim).contiguous()
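The hunk above allocates the all-gather buffers in ALL_GATHER_DTYPE and casts the send tensor to it, so the gather runs in a separately configured communication dtype. A minimal standalone sketch of the same pattern follows; the bfloat16 value and the gather_low_precision name are illustrative assumptions, not AxoNN's actual configuration.

import torch
import torch.distributed as dist

# Assumed stand-in for axonn.intra_layer.ALL_GATHER_DTYPE; bfloat16 is only an
# illustrative choice, the real value is whatever AxoNN configures.
ALL_GATHER_DTYPE = torch.bfloat16

def gather_low_precision(input_, dim, process_group=None):
    # All-gather `input_` across `process_group` in the communication dtype,
    # then concatenate along `dim` and cast back to the original dtype.
    input_ = input_.contiguous()
    world_size = dist.get_world_size(process_group)

    # Receive buffers allocated in the communication dtype halve the message
    # size when the model runs in fp32 and ALL_GATHER_DTYPE is a 16-bit type.
    tensor_list = [
        torch.empty_like(input_, dtype=ALL_GATHER_DTYPE) for _ in range(world_size)
    ]
    dist.all_gather(tensor_list, input_.to(ALL_GATHER_DTYPE), group=process_group)

    return torch.cat(tensor_list, dim=dim).to(input_.dtype)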
@@ -70,17 +72,20 @@ def _reduce_scatter(input_, dim, process_group=None, overlap_comm=False):
     assert input_.shape[dim] % total_chunks == 0
     tensor_shape = list(input_.shape)
     tensor_shape[dim] //= total_chunks
+
+    from axonn.intra_layer import REDUCE_SCATTER_DTYPE
+
     output = torch.empty(
-        tensor_shape, dtype=input_.dtype, device=torch.cuda.current_device()
+        tensor_shape, dtype=REDUCE_SCATTER_DTYPE, device=torch.cuda.current_device()
     )
 
     if hasattr(torch.distributed, "reduce_scatter_tensor"):
         handle = torch.distributed.reduce_scatter_tensor(
-            output, input_, group=process_group, async_op=overlap_comm
+            output, input_.to(REDUCE_SCATTER_DTYPE), group=process_group, async_op=overlap_comm
         )
     else:
         handle = torch.distributed._reduce_scatter_base(
-            output, input_, group=process_group, async_op=overlap_comm
+            output, input_.to(REDUCE_SCATTER_DTYPE), group=process_group, async_op=overlap_comm
         )
 
     if overlap_comm:
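Likewise, the second hunk allocates the reduce-scatter output in REDUCE_SCATTER_DTYPE and casts the input before the collective, so both the messages and the reduction use that dtype. A minimal sketch of the same idea, assuming a recent PyTorch that provides torch.distributed.reduce_scatter_tensor; the function name and the bfloat16 value are illustrative assumptions.

import torch
import torch.distributed as dist

# Assumed stand-in for axonn.intra_layer.REDUCE_SCATTER_DTYPE; bfloat16 is an
# illustrative choice only.
REDUCE_SCATTER_DTYPE = torch.bfloat16

def reduce_scatter_low_precision(input_, process_group=None, overlap_comm=False):
    # Reduce-scatter equal chunks of dim 0 of `input_` across `process_group`,
    # communicating and reducing in REDUCE_SCATTER_DTYPE.
    world_size = dist.get_world_size(process_group)
    assert input_.shape[0] % world_size == 0
    out_shape = list(input_.shape)
    out_shape[0] //= world_size

    output = torch.empty(
        out_shape, dtype=REDUCE_SCATTER_DTYPE, device=torch.cuda.current_device()
    )
    handle = dist.reduce_scatter_tensor(
        output,
        input_.to(REDUCE_SCATTER_DTYPE).contiguous(),
        group=process_group,
        async_op=overlap_comm,
    )
    # With overlap_comm=True the caller must wait on `handle` before using `output`.
    return (output, handle) if overlap_comm else output

One trade-off worth noting: because the reduction itself is performed on buffers of the communication dtype, summing many partial results in a 16-bit type can lose accuracy compared with reducing in fp32 and casting the result afterwards.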