Commit c17dc33

Using explicit GPU upcast for ZeRO-Offload (#6962)
Following the discussion in [PR-6670](#6670), the explicit upcast is much more efficient than the implicit one, so this PR replaces the implicit upcast with an explicit one. The results on a 3B model are shown below:

| Option | BWD (ms) | Speed up |
|----------------|----------|----------|
| Before PR-6670 | 25603.30 | 1x |
| After PR-6670  | 1174.31  | 21.8X |
| After this PR  | 309.2    | 82.8X |
1 parent 8d1bc0a commit c17dc33
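As a minimal sketch of the two copy styles the description contrasts (assuming a CUDA device; the tensor names are illustrative, not DeepSpeed's own variables):

```python
import torch

# GPU gradient partition in low precision and its CPU-side fp32 destination.
grad_buffer = torch.randn(1 << 20, dtype=torch.bfloat16, device="cuda")
fp32_grad_tensor = torch.empty(1 << 20, dtype=torch.float32, device="cpu")

# Implicit upcast: copy_ is handed the bf16 GPU tensor and performs the
# dtype conversion itself as part of the device-to-host copy.
fp32_grad_tensor.copy_(grad_buffer)

# Explicit GPU upcast (what this PR switches to): convert to fp32 on the
# GPU first, then transfer an already-fp32 tensor to the host.
fp32_grad_tensor.copy_(grad_buffer.float())
```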

File tree

1 file changed: +2 -11 lines

deepspeed/runtime/zero/stage3.py (+2 -11)
@@ -546,15 +546,10 @@ def _setup_for_real_optimizer(self):
             self.grad_partitions_flat_buffer = get_accelerator().pin_memory(self.grad_partitions_flat_buffer)
 
         offset = 0
-        max_partition_numel = 0
         for param in all_params:
             self.__param_id_to_grad_partition[param.ds_id] = self.grad_partitions_flat_buffer.narrow(
                 0, offset, param.partition_numel())
             offset += param.partition_numel()
-            max_partition_numel = max(max_partition_numel, param.partition_numel())
-        if self.offload_optimizer:
-            self.pinned_grad_buffer: Tensor = get_accelerator().pin_memory(
-                torch.empty(max_partition_numel, device=self.device))
 
     def _link_all_hp_params(self):
         for p in self.module.parameters():
@@ -1510,13 +1505,9 @@ def partition_grads(self, params_to_release: List[Parameter], grad_partitions: L
                     offload_fp32_gradients[i].append(grad_buffer.float())
                     offload_fp32_offsets[i].append(dest_offset)
                 else:
-                    buffer_numel = grad_buffer.numel()
                     fp32_grad_tensor = self.fp32_partitioned_groups_flat[i].grad.narrow(
-                        0, dest_offset, buffer_numel)
-                    self.pinned_grad_buffer[:buffer_numel].copy_(
-                        grad_buffer.to(dtype=torch.float32, non_blocking=True))
-                    get_accelerator().synchronize()
-                    fp32_grad_tensor.copy_(self.pinned_grad_buffer[:buffer_numel], non_blocking=True)
+                        0, dest_offset, grad_buffer.numel())
+                    fp32_grad_tensor.copy_(grad_buffer.float())
 
             # free the gradient
             if not get_accelerator().is_synchronized_device():
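For a self-contained view, a rough sketch of the staged path removed by the first hunk versus the direct copy added by the second, assuming a CUDA device and using plain torch in place of DeepSpeed's get_accelerator(); the names grad_buffer, fp32_grad, and pinned are illustrative, not the variables from stage3.py:

```python
import torch

# GPU gradient partition and the CPU-side fp32 gradient it feeds.
grad_buffer = torch.randn(1 << 20, dtype=torch.bfloat16, device="cuda")
fp32_grad = torch.empty(1 << 20, dtype=torch.float32, device="cpu")

# Removed path: upcast on the GPU, stage the fp32 result in a pinned host
# buffer, synchronize, then copy from the staging buffer to the destination.
pinned = torch.empty(grad_buffer.numel(), dtype=torch.float32, pin_memory=True)
pinned.copy_(grad_buffer.to(dtype=torch.float32, non_blocking=True))
torch.cuda.synchronize()
fp32_grad.copy_(pinned, non_blocking=True)

# New path: upcast on the GPU and copy straight into the destination,
# skipping the staging buffer and the extra synchronization.
fp32_grad.copy_(grad_buffer.float())
```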
