Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix, pipeline model with moe cause error when send grad #7055

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 4 additions & 16 deletions deepspeed/runtime/pipe/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,7 +992,7 @@ def _recv_tensor_meta(self, send_stage):
recv_dtype = self.ID_TO_DTYPE[buffer[1].item()]
recv_ndims = buffer[2].item()
recv_shape = buffer[3:3 + recv_ndims].tolist()
return self._allocate_or_extend_buffers(0, recv_shape, recv_dtype)
return self._allocate_buffer(recv_shape, dtype=recv_dtype, num_buffers=1)[0]

# List or tuple of tensors (recv_type == 1 (list) is currently unused)
elif recv_type == 1 or recv_type == 2:
Expand All @@ -1006,7 +1006,7 @@ def _recv_tensor_meta(self, send_stage):
recv_shape = buffer[offset + 2:offset + 2 + recv_ndims].tolist()
offset += 2 + recv_ndims

buffers.append(self._allocate_or_extend_buffers(idx, recv_shape, recv_dtype))
buffers.append(self._allocate_buffer(recv_shape, dtype=recv_dtype, num_buffers=1)[0])

# Convert to tuples if requested.
if recv_type == 2:
Expand Down Expand Up @@ -1172,7 +1172,7 @@ def _exec_recv_grads(self, buffer_id):
# Allocate gradient if necessary
if self.dynamic_shape or self.grad_layer is None:
if isinstance(outputs, torch.Tensor):
self.grad_layer = self._allocate_or_extend_buffers(0, list(outputs.size()), outputs.dtype)
self.grad_layer = self._allocate_buffer(list(outputs.size()), dtype=outputs.dtype, num_buffers=1)[0]
else:
# XXX This is a HACK
# When we exchange activations/gradients, the two pipe stages
Expand All @@ -1196,7 +1196,7 @@ def _exec_recv_grads(self, buffer_id):
sizes_and_dtypes = [(list(t.size()), t.dtype) for t in outputs if t.is_floating_point()]

self.grad_layer = [
self._allocate_or_extend_buffers(i, size, dtype)
self._allocate_buffer(size, dtype=dtype, num_buffers=1)[0]
for i, (size, dtype) in enumerate(sizes_and_dtypes)
]

Expand Down Expand Up @@ -1279,18 +1279,6 @@ def _allocate_buffer(self, shape, num_buffers=-1, **kwargs):
buffers.append(self._allocate_zeros(shape, **kwargs))
return buffers

def _allocate_or_extend_buffers(self, idx, shape, dtype):
numel = reduce(mul, shape) if len(shape) > 0 else 1
if len(self._grad_layer_buf) <= idx or self._grad_layer_buf[idx].numel() < numel:
new_buf = self._allocate_buffer(shape, dtype=dtype, num_buffers=1)[0]
if len(self._grad_layer_buf) <= idx:
self._grad_layer_buf.append(new_buf)
else:
self._grad_layer_buf[idx] = new_buf
return self._grad_layer_buf[idx]
else:
return self._grad_layer_buf[idx].flatten()[:numel].view(shape)

def forward(self, *args, **kwargs):
"""Disabled for pipeline parallel training. See ``train_batch()``. """
raise PipelineError("Only train_batch() is accessible in pipeline mode.")
Expand Down
Loading