
Commit e212829

make input and output buffer dtypes same
1 parent 9dd4845 commit e212829

3 files changed: +20, -7 lines

axonn/intra_layer/__init__.py

Lines changed: 9 additions & 1 deletion
@@ -43,11 +43,19 @@ def gather(
 OVERLAP_REDUCE_SCATTER = False
 OVERLAP_ALL_REDUCE = False
 ALL_GATHER_ITERATOR = None
-ALL_GATHER_DTYPE = torch.bfloat16
+ALL_GATHER_DTYPE = torch.float32
+REDUCE_SCATTER_DTYPE = torch.bfloat16
 handles = []
 pending_grad_accumulations = []
 weights_cache = {}
 
+def set_all_gather_dtype(dtype):
+    global ALL_GATHER_DTYPE
+    ALL_GATHER_DTYPE = dtype
+
+def set_reduce_scatter_dtype(dtype):
+    global REDUCE_SCATTER_DTYPE
+    REDUCE_SCATTER_DTYPE = dtype
 
 
 def register_handle(handle):
     # ToDo: This might be unnecesary since
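
The two module-level setters above make the collective buffer dtypes configurable at runtime. A minimal usage sketch, assuming the names are importable from axonn.intra_layer exactly as defined in this diff; the surrounding training code is not part of the commit:

import torch
from axonn.intra_layer import set_all_gather_dtype, set_reduce_scatter_dtype

# Keep both collective buffers in bfloat16 for a mixed-precision run; _gather and
# _reduce_scatter will cast their inputs to these dtypes so input and output
# buffers always match.
set_all_gather_dtype(torch.bfloat16)
set_reduce_scatter_dtype(torch.bfloat16)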

axonn/intra_layer/communication.py

Lines changed: 10 additions & 6 deletions
@@ -44,12 +44,13 @@ def _gather(input_, dim, process_group=None, cache=False):
     input_ = input_.contiguous()
     # Size and dimension.
     rank = dist.get_rank(process_group)
+
+    from axonn.intra_layer import ALL_GATHER_DTYPE
 
     tensor_list = [
-        torch.empty_like(input_) for _ in range(dist.get_world_size(process_group))
+        torch.empty_like(input_, dtype=ALL_GATHER_DTYPE) for _ in range(dist.get_world_size(process_group))
     ]
-    tensor_list[rank] = input_
-    dist.all_gather(tensor_list, input_, group=process_group)
+    dist.all_gather(tensor_list, input_.to(ALL_GATHER_DTYPE), group=process_group)
 
     # Note: torch.cat already creates a contiguous tensor.
     output = torch.cat(tensor_list, dim=dim).contiguous()
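
_gather now allocates its receive buffers in ALL_GATHER_DTYPE and casts the input to the same dtype before the collective, since all_gather requires the gathered tensor and the receive buffers to share a dtype. A minimal single-process sketch of the same pattern, assuming a CPU build of PyTorch with the gloo backend; the process-group setup is illustrative only and not part of this commit:

import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

ALL_GATHER_DTYPE = torch.float32  # mirrors the new module-level default
input_ = torch.randn(4, dtype=torch.bfloat16)

# Receive buffers are allocated in the configured dtype ...
tensor_list = [
    torch.empty_like(input_, dtype=ALL_GATHER_DTYPE)
    for _ in range(dist.get_world_size())
]
# ... and the input is cast to the same dtype, so both sides of the collective match.
dist.all_gather(tensor_list, input_.to(ALL_GATHER_DTYPE))
output = torch.cat(tensor_list, dim=0)
print(output.dtype)  # torch.float32

dist.destroy_process_group()
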
@@ -70,17 +71,20 @@ def _reduce_scatter(input_, dim, process_group=None, overlap_comm=False):
     assert input_.shape[dim] % total_chunks == 0
     tensor_shape = list(input_.shape)
     tensor_shape[dim] //= total_chunks
+
+    from axonn.intra_layer import REDUCE_SCATTER_DTYPE
+
     output = torch.empty(
-        tensor_shape, dtype=input_.dtype, device=torch.cuda.current_device()
+        tensor_shape, dtype=REDUCE_SCATTER_DTYPE, device=torch.cuda.current_device()
     )
 
     if hasattr(torch.distributed, "reduce_scatter_tensor"):
         handle = torch.distributed.reduce_scatter_tensor(
-            output, input_, group=process_group, async_op=overlap_comm
+            output, input_.to(REDUCE_SCATTER_DTYPE), group=process_group, async_op=overlap_comm
         )
     else:
         handle = torch.distributed._reduce_scatter_base(
-            output, input_, group=process_group, async_op=overlap_comm
+            output, input_.to(REDUCE_SCATTER_DTYPE), group=process_group, async_op=overlap_comm
         )
 
     if overlap_comm:
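
Likewise, _reduce_scatter now allocates its output buffer in REDUCE_SCATTER_DTYPE and casts input_ before the collective, because reduce-scatter requires the input and output buffers to have the same dtype. A minimal single-process sketch of the pattern, assuming one CUDA device and the nccl backend (reduce-scatter is not supported by gloo); the setup code is illustrative only and not part of this commit:

import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("nccl", rank=0, world_size=1)

REDUCE_SCATTER_DTYPE = torch.bfloat16  # mirrors the new module-level default
input_ = torch.randn(8, dtype=torch.float32, device="cuda")

# The output buffer holds 1/world_size of the reduced tensor, in the configured dtype.
tensor_shape = [input_.shape[0] // dist.get_world_size()]
output = torch.empty(
    tensor_shape, dtype=REDUCE_SCATTER_DTYPE, device=torch.cuda.current_device()
)

# Casting the input keeps the input and output buffer dtypes the same.
dist.reduce_scatter_tensor(output, input_.to(REDUCE_SCATTER_DTYPE))
print(output.dtype)  # torch.bfloat16

dist.destroy_process_group()
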

axonn/intra_layer/fully_connected.py

Lines changed: 1 addition & 0 deletions
@@ -79,6 +79,7 @@ def forward(
         ctx.backward_comm_async = backward_comm_async
         if not forward_comm_async:
             output = input_.matmul(weight.t())
+
             dist.all_reduce(output, group=forward_all_reduce_group, async_op=False)
         else:
             assert input_.shape[0] % 2 == 0
