Skip to content

Commit d5a8ec5

Browse files
committed
add dtypes for reduce scatters and all gathers
1 parent da66c0a commit d5a8ec5

File tree

2 files changed

+19
-6
lines changed

2 files changed

+19
-6
lines changed

axonn/intra_layer/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,19 @@ def gather(
4343
OVERLAP_REDUCE_SCATTER = False
4444
OVERLAP_ALL_REDUCE = False
4545
ALL_GATHER_ITERATOR = None
46-
ALL_GATHER_DTYPE = torch.bfloat16
46+
ALL_GATHER_DTYPE = torch.float32
47+
REDUCE_SCATTER_DTYPE = torch.float32
4748
handles = []
4849
pending_grad_accumulations = []
4950
weights_cache = {}
5051

52+
def set_all_gather_dtype(dtype):
53+
global ALL_GATHER_DTYPE
54+
ALL_GATHER_DTYPE = dtype
55+
56+
def set_reduce_scatter_dtype(dtype):
57+
global REDUCE_SCATTER_DTYPE
58+
REDUCE_SCATTER_DTYPE = dtype
5159

5260
def register_handle(handle):
5361
# ToDo: This might be unnecessary since

axonn/intra_layer/communication.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,14 @@ def _gather(input_, dim, process_group=None, cache=False):
4444
input_ = input_.contiguous()
4545
# Size and dimension.
4646
rank = dist.get_rank(process_group)
47+
48+
from axonn.intra_layer import ALL_GATHER_DTYPE
4749

4850
tensor_list = [
49-
torch.empty_like(input_) for _ in range(dist.get_world_size(process_group))
51+
torch.empty_like(input_, dtype=ALL_GATHER_DTYPE) for _ in range(dist.get_world_size(process_group))
5052
]
5153
tensor_list[rank] = input_
52-
dist.all_gather(tensor_list, input_, group=process_group)
54+
dist.all_gather(tensor_list, input_.to(ALL_GATHER_DTYPE), group=process_group)
5355

5456
# Note: torch.cat already creates a contiguous tensor.
5557
output = torch.cat(tensor_list, dim=dim).contiguous()
@@ -70,17 +72,20 @@ def _reduce_scatter(input_, dim, process_group=None, overlap_comm=False):
7072
assert input_.shape[dim] % total_chunks == 0
7173
tensor_shape = list(input_.shape)
7274
tensor_shape[dim] //= total_chunks
75+
76+
from axonn.intra_layer import REDUCE_SCATTER_DTYPE
77+
7378
output = torch.empty(
74-
tensor_shape, dtype=input_.dtype, device=torch.cuda.current_device()
79+
tensor_shape, dtype=REDUCE_SCATTER_DTYPE, device=torch.cuda.current_device()
7580
)
7681

7782
if hasattr(torch.distributed, "reduce_scatter_tensor"):
7883
handle = torch.distributed.reduce_scatter_tensor(
79-
output, input_, group=process_group, async_op=overlap_comm
84+
output, input_.to(REDUCE_SCATTER_DTYPE), group=process_group, async_op=overlap_comm
8085
)
8186
else:
8287
handle = torch.distributed._reduce_scatter_base(
83-
output, input_, group=process_group, async_op=overlap_comm
88+
output, input_.to(REDUCE_SCATTER_DTYPE), group=process_group, async_op=overlap_comm
8489
)
8590

8691
if overlap_comm:

0 commit comments

Comments (0)