
Commit 8d1bc0a

Update torch.norm to torch.linalg.norm and torch.linalg.vector_norm (#6931)
- [x] Update PR since `torch.norm` and `torch.linalg.norm` have [different function signatures](https://pytorch.org/docs/stable/generated/torch.linalg.norm.html#torch.linalg.norm).
- [x] Check if there are any numeric differences between the functions.
- [x] Determine why there appear to be performance differences reported by others [here](pytorch/pytorch#136360).
- [x] Update to `torch.linalg.vector_norm`.

A follow-up PR (#6960) handles the remaining call sites in the comm folder.
1 parent bc76b04 · commit 8d1bc0a
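For reference, a minimal sketch (not part of this commit) of how the deprecated default `torch.norm` call lines up with the `torch.linalg` replacements used below; the tensor here is purely illustrative:

```python
import torch

x = torch.randn(4, 8)  # illustrative tensor, not taken from DeepSpeed

# Legacy default: torch.norm with no p/dim computes the 2-norm of the flattened tensor.
legacy = torch.norm(x)

# torch.linalg.norm with default ord/dim also flattens and returns the 2-norm,
# which is why the comm/*.py call sites can swap it in directly.
flattened = torch.linalg.norm(x)

# torch.linalg.vector_norm always treats the (flattened) input as a vector,
# making the intended reduction explicit; it matches the legacy default as well.
vector = torch.linalg.vector_norm(x)

assert torch.allclose(legacy, flattened) and torch.allclose(legacy, vector)
```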

File tree

5 files changed: +7 -6 lines changed


deepspeed/runtime/comm/compressed.py

+1 -1

@@ -96,7 +96,7 @@ def compressed_allreduce(self, buffer_m: torch.tensor, worker_error, server_erro
 
         compensated_server_m.add_(server_error)
 
-        server_scale = torch.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())
+        server_scale = torch.linalg.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())
 
         server_error.set_(compensated_server_m -
                           server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))

deepspeed/runtime/comm/hccl.py

+1 -1

@@ -83,7 +83,7 @@ def compressed_allreduce(self, buffer_m: torch.tensor, worker_error, server_erro
 
         compensated_server_m.add_(server_error)
 
-        server_scale = torch.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())
+        server_scale = torch.linalg.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())
 
         server_error.set_(compensated_server_m -
                           server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))

deepspeed/runtime/fp16/onebit/lamb.py

+1 -1

@@ -177,7 +177,7 @@ def step(self, closure=None, grads=None):
             # This is used to reduce compression error during compression stage.
             momentum_scales = []
             for group in self.param_groups:
-                momentum_scales.append([(torch.linalg.norm(self.state[p]['exp_avg']) /
+                momentum_scales.append([(torch.linalg.vector_norm(self.state[p]['exp_avg']) /
                                          np.sqrt(torch.numel(self.state[p]['exp_avg']))).item()
                                         for p in group['params']])
             united_scale = sum([sum(x) for x in momentum_scales]) / sum([len(x) for x in momentum_scales])

deepspeed/runtime/zero/stage3.py

+1 -1

@@ -2101,7 +2101,7 @@ def step(self, closure=None):
             return
 
         norm_groups = self._get_norm_groups()
-        scaled_global_grad_norm = torch.linalg.norm(torch.stack(norm_groups))
+        scaled_global_grad_norm = torch.linalg.vector_norm(torch.stack(norm_groups))
 
         # Stash unscaled gradient norm
         self._global_grad_norm = scaled_global_grad_norm / self.loss_scale
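A small sanity-check sketch for the stage3 change above, assuming `norm_groups` is a list of scalar per-group norm tensors as in the surrounding code (the values here are made up):

```python
import torch

# Hypothetical per-group gradient norms standing in for self._get_norm_groups().
norm_groups = [torch.tensor(0.5), torch.tensor(1.25), torch.tensor(2.0)]
stacked = torch.stack(norm_groups)  # 1-D tensor of per-group norms

# On a 1-D input, torch.linalg.norm and torch.linalg.vector_norm both compute
# the vector 2-norm, so this swap should be numerically neutral.
assert torch.allclose(torch.linalg.norm(stacked), torch.linalg.vector_norm(stacked))
```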

deepspeed/runtime/zero/stage_1_and_2.py

+3 -2

@@ -1691,7 +1691,8 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2):
                 continue
             if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
                 all_norms.append(
-                    torch.norm(g.data.double().detach(), norm_type).to(get_accelerator().current_device_name()))
+                    torch.linalg.vector_norm(g.data.double().detach(),
+                                             ord=norm_type).to(get_accelerator().current_device_name()))
         if len(all_norms) > 0:
             total_norm = torch.stack(all_norms).square().sum().float()
         else:
@@ -1795,7 +1796,7 @@ def scaled_global_norm(self, norm_type=2):
         self._average_expert_grad_norms(norm_groups)
 
         # calculating L2 norm
-        return torch.norm(torch.stack(norm_groups), p=norm_type)
+        return torch.linalg.vector_norm(torch.stack(norm_groups), ord=norm_type)
 
     def get_bit16_param_group(self, group_no):
         bit16_partitions = self.parallel_partitioned_bit16_groups[group_no]
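Finally, a hedged sketch of the keyword mapping behind the stage_1_and_2 changes: the positional `p`/`norm_type` argument of `torch.norm` becomes `ord` in `torch.linalg.vector_norm` (the gradient tensor and `norm_type` value here are illustrative):

```python
import torch

g = torch.randn(16, 16).double()  # stand-in for a gradient tensor (g.data)
norm_type = 2

old_norm = torch.norm(g.detach(), norm_type)                    # deprecated form
new_norm = torch.linalg.vector_norm(g.detach(), ord=norm_type)  # replacement form

# Note: torch.linalg.norm(g, 2) would instead compute a matrix (spectral) norm on
# a 2-D input, which is the signature difference noted in the commit message.
assert torch.allclose(old_norm, new_norm)
```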
