@@ -44,12 +44,13 @@ def _gather(input_, dim, process_group=None, cache=False):
     input_ = input_.contiguous()
     # Size and dimension.
     rank = dist.get_rank(process_group)
+
+    from axonn.intra_layer import ALL_GATHER_DTYPE
 
     tensor_list = [
-        torch.empty_like(input_) for _ in range(dist.get_world_size(process_group))
+        torch.empty_like(input_, dtype=ALL_GATHER_DTYPE) for _ in range(dist.get_world_size(process_group))
     ]
-    tensor_list[rank] = input_
-    dist.all_gather(tensor_list, input_, group=process_group)
+    dist.all_gather(tensor_list, input_.to(ALL_GATHER_DTYPE), group=process_group)
 
     # Note: torch.cat already creates a contiguous tensor.
     output = torch.cat(tensor_list, dim=dim).contiguous()
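
For reference, a minimal, self-contained sketch of the same pattern outside AxoNN: every rank casts its shard to a shared communication dtype before the all-gather, so the receive buffers and the collective agree on precision. The function name `gather_in_comm_dtype` and the `comm_dtype` default are illustrative assumptions, not part of this commit.

import torch
import torch.distributed as dist

def gather_in_comm_dtype(input_, dim, process_group=None, comm_dtype=torch.bfloat16):
    # Cast once so all ranks communicate in the same (possibly lower-precision) dtype.
    input_ = input_.contiguous().to(comm_dtype)
    world_size = dist.get_world_size(process_group)
    # Receive buffers must match the dtype and shape of the tensor being gathered.
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    dist.all_gather(tensor_list, input_, group=process_group)
    # torch.cat already returns a contiguous tensor.
    return torch.cat(tensor_list, dim=dim)
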
@@ -70,17 +71,20 @@ def _reduce_scatter(input_, dim, process_group=None, overlap_comm=False):
     assert input_.shape[dim] % total_chunks == 0
     tensor_shape = list(input_.shape)
     tensor_shape[dim] //= total_chunks
+
+    from axonn.intra_layer import REDUCE_SCATTER_DTYPE
+
     output = torch.empty(
-        tensor_shape, dtype=input_.dtype, device=torch.cuda.current_device()
+        tensor_shape, dtype=REDUCE_SCATTER_DTYPE, device=torch.cuda.current_device()
     )
 
     if hasattr(torch.distributed, "reduce_scatter_tensor"):
         handle = torch.distributed.reduce_scatter_tensor(
-            output, input_, group=process_group, async_op=overlap_comm
+            output, input_.to(REDUCE_SCATTER_DTYPE), group=process_group, async_op=overlap_comm
         )
     else:
         handle = torch.distributed._reduce_scatter_base(
-            output, input_, group=process_group, async_op=overlap_comm
+            output, input_.to(REDUCE_SCATTER_DTYPE), group=process_group, async_op=overlap_comm
         )
 
     if overlap_comm:
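
And a matching sketch of the reduce-scatter side: the output buffer is allocated in the communication dtype and the input is cast to it before the collective, mirroring the change above. `reduce_scatter_tensor` reduces across ranks and hands each rank one contiguous chunk along dim 0; the function name, `comm_dtype` default, and the dim-0 restriction are assumptions for illustration, not part of this commit.

import torch
import torch.distributed as dist

def reduce_scatter_in_comm_dtype(input_, process_group=None,
                                 comm_dtype=torch.bfloat16, overlap_comm=False):
    # reduce_scatter_tensor splits the input into world_size contiguous chunks
    # along dim 0, sums corresponding chunks across ranks, and leaves one per rank.
    world_size = dist.get_world_size(process_group)
    assert input_.shape[0] % world_size == 0
    out_shape = list(input_.shape)
    out_shape[0] //= world_size
    # Allocate the output in the communication dtype and cast the input to match.
    output = torch.empty(out_shape, dtype=comm_dtype, device=input_.device)
    handle = dist.reduce_scatter_tensor(
        output, input_.to(comm_dtype), group=process_group, async_op=overlap_comm
    )
    if overlap_comm:
        # Caller must call handle.wait() before reading `output`.
        return output, handle
    return output
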