
Commit 02bdd12

Merge branch 'master' into summit_mismatch
2 parents 59fe0c2 + b418cf6 commit 02bdd12

10 files changed (+450, -182 lines)


deepspeed/module_inject/auto_tp.py (+3 -6)

@@ -11,7 +11,7 @@
 from typing import Optional
 import torch
 from deepspeed import comm as dist
-from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce, Yuan_LinearAllreduce, Yuan_LinearLayer, GateUpPack_LinearLayer, Conv_LinearALlreduce, fused_LinearLayer, conv_LinearLayer
+from .layers import *
 from deepspeed.accelerator import get_accelerator
 from .fusedqkv_utils import require_tp_fused_qkvw
 from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
@@ -211,7 +211,7 @@ def __init__(self,
         self.orig_layer_impl = orig_layer_impl
         self.linear_policies = None
         self.conv_linear_layer = False
-        self.keep_module_on_host = keep_module_on_host
+        TensorParallel_Layer.set_keep_module_on_host(keep_module_on_host)

 def in_module_list(module, module_list):
     for item in module_list:
@@ -350,10 +350,7 @@ def _replace(self, child, name, conv_linear_layer):
         # and avoid any complex shard-related logic.
         if getattr(child, "replaced", False) == True:
             return
-        device_name = 'cpu' if self.keep_module_on_host else get_accelerator().current_device_name()
-        # keep_module_on_host is used to keep the module on the host. Checkpoints are loaded to the host first (in some
-        # cases it can be done from the disk even to prevent filling host's memory), thus no need to create a new copy.
-        return_new_copy = not self.keep_module_on_host
+
         weight_shape = child.weight.shape
         mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
         # For TP layer skip, e.g., MoE gate, deepseek low rank layer skip
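
The net effect of this file is that keep_module_on_host is no longer per-AutoTP-instance state; AutoTP.__init__ now pushes it into a class-level flag on TensorParallel_Layer, so every layer subclass reads the same setting when it decides where to place sharded weights. A minimal standalone sketch of that pattern (hypothetical names, not the DeepSpeed classes themselves):

    class TPLayerBase:
        # shared, class-level flag (mirrors TensorParallel_Layer.keep_module_on_host)
        keep_module_on_host = False

        @classmethod
        def set_keep_module_on_host(cls, value: bool):
            cls.keep_module_on_host = value


    class MyLinearLayer(TPLayerBase):
        def target_device(self):
            # subclasses read the shared flag instead of an instance attribute
            return 'cpu' if self.__class__.keep_module_on_host else 'accelerator'


    TPLayerBase.set_keep_module_on_host(True)   # what AutoTP.__init__ now does once
    print(MyLinearLayer().target_device())      # -> 'cpu'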

deepspeed/module_inject/layers.py (+56 -32)

@@ -17,6 +17,11 @@
 from copy import deepcopy
 from typing import Union

+__all__ = [
+    "TensorParallel_Layer", "LinearAllreduce", "LinearLayer", "LmHeadLinearAllreduce", "Yuan_LinearAllreduce",
+    "Yuan_LinearLayer", "GateUpPack_LinearLayer", "Conv_LinearALlreduce", "fused_LinearLayer", "conv_LinearLayer"
+]
+
 DEEPSPEED_AUTOTP_MODE = AUTOTP_MODE.INFERENCE
 DS_IS_REPLACED_MODULE = 'ds_is_replaced_module'
 DS_TENSOR_MODEL_PARALLEL = 'tensor_model_parallel'
@@ -43,26 +48,6 @@ def set_autotp_mode(training=False):
         DEEPSPEED_AUTOTP_MODE = AUTOTP_MODE.INFERENCE


-def move(tensor, device):
-    # TODO: consider the timing of deletion
-    # to save host resources when DP > 1。
-
-    if tensor.is_meta:
-        # Keep tensor in meta device if tensor is meta.
-        return tensor
-    else:
-        # Using new tensors help in freeing memory (after split for example) was done before by calling clone().
-        # Using copy=True instead of clone() will help in case of cpu --> cpu.
-        # Otherwise to() will not create a new copy for the view of the full tensor, and it will not be de-referenced.
-        cloned_tensor = tensor.to(device, copy=True)
-
-        # free the memory of the original tensor to reduce memory peak
-        # Equivalent to directly deleting the tensor reference outside the function.
-        # see https://github.com/microsoft/DeepSpeed/pull/4353
-        tensor.data = torch.empty(0, device=tensor.device)
-        return cloned_tensor
-
-
 class RowParallel(torch.autograd.Function):
     """
     A custom autograd function for performing row-wise parallelism.
@@ -140,6 +125,10 @@ class TensorParallel_Layer(nn.Module, ABC):
         name (Optional[str]): The name of the layer, if provided.
     """

+    # keep_module_on_host is used to keep the module on the host. Checkpoints are loaded to the host first (in some
+    # cases it can be done from the disk even to prevent filling host's memory), thus no need to create a new copy.
+    keep_module_on_host: bool = False
+
     def __init__(self, mp_group: Optional[dist.ProcessGroup], **kwargs: Any):
         """
         Initializes the TensorParallel_Layer with optional model parallelism group and layer name.
@@ -163,6 +152,16 @@ def __init__(self, mp_group: Optional[dist.ProcessGroup], **kwargs: Any):
         if kwargs.get('name') is not None:
             self.name = kwargs.get('name')  # Set the layer name if provided.

+    @classmethod
+    def set_keep_module_on_host(cls, value: bool):
+        """
+        Set the static variable keep_module_on_host.
+
+        Args:
+            value (bool): The new value for keep_module_on_host.
+        """
+        cls.keep_module_on_host = value
+
     @abstractmethod
     def forward(self, input):
         """
@@ -235,6 +234,31 @@ def extra_repr(self):
                                   in_features, out_features, self.bias is not None, dtype)
         return extra_repr_str

+    def move(self, tensor):
+        # TODO: consider the timing of deletion
+        # to save host resources when DP > 1。
+
+        # keep_module_on_host is used to keep the module on the host. Checkpoints are loaded to the host first (in some
+        # cases it can be done from the disk even to prevent filling host's memory), thus no need to create a new copy.
+        if tensor.is_meta:
+            # Keep tensor in meta device if tensor is meta.
+            return tensor
+        else:
+            device = 'cpu' if self.__class__.keep_module_on_host else get_accelerator().current_device_name()
+            return_new_copy = not self.__class__.keep_module_on_host
+
+            # Using new tensors help in freeing memory (after split for example) was done before by calling clone().
+            # Using copy=True instead of clone() will help in case of cpu --> cpu.
+            # Otherwise to() will not create a new copy for the view of the full tensor, and it will not be de-referenced.
+            cloned_tensor = tensor.to(device, copy=return_new_copy)
+
+            if return_new_copy:
+                # free the memory of the original tensor to reduce memory peak
+                # Equivalent to directly deleting the tensor reference outside the function.
+                # see https://github.com/microsoft/DeepSpeed/pull/4353
+                tensor.data = torch.empty(0, device=tensor.device)
+            return cloned_tensor
+

 class GatherReplacedLayerParams:
     """
@@ -349,7 +373,7 @@ def _tp_partition(self, params_list):
                 return
             _partition = torch.chunk(param, self.tp_world_size, dim=-1)[self.tp_index]

-            _partition = move(_partition, get_accelerator().current_device_name()).detach()
+            _partition = self.move(_partition).detach()

             params_list[idx].data = _partition

@@ -363,7 +387,7 @@ def uneven_partition(self, params_list):
                                                            self.name),
                                             dim=1)[self.tp_index]

-            _partition = move(_partition, get_accelerator().current_device_name()).detach()
+            _partition = self.move(_partition).detach()
             params_list[idx].data = _partition


@@ -414,7 +438,7 @@ def _tp_partition(self, params_list):
             #split bias if provide
             _partition = torch.chunk(param, self.tp_world_size, dim=0)[self.tp_index]

-            _partition = move(_partition, get_accelerator().current_device_name()).detach()
+            _partition = self.move(_partition).detach()

             params_list[idx].data = _partition

@@ -429,7 +453,7 @@ def uneven_partition(self, params_list):
                                                            self.name),
                                             dim=0)[self.tp_index]

-            _partition = move(_partition, get_accelerator().current_device_name()).detach()
+            _partition = self.move(_partition).detach()

             params_list[idx].data = _partition

@@ -475,7 +499,7 @@ def _tp_partition(self, params_list):

             _partition = prepare_tp_fused_qkvw(self.fused_module.module, param, self.tp_world_size, self.tp_index)

-            _partition = move(_partition, get_accelerator().current_device_name()).detach()
+            _partition = self.move(_partition).detach()

             params_list[idx].data = _partition

@@ -492,13 +516,13 @@ def _tp_partition(self, params_list):
         weight, bias = params_list[0], params_list[1]
         _partition = weight.data.split(get_shard_size_list(weight.shape[0], self.tp_world_size, self.name),
                                        dim=1)[self.tp_index]
-        _partition = move(_partition, get_accelerator().current_device_name()).detach()
+        _partition = self.move(_partition).detach()
         weight.data = _partition

         if bias is not None:
             _partition = bias.data.split(get_shard_size_list(weight.shape[1], self.tp_world_size, self.name),
                                          dim=0)[self.tp_index]
-            _partition = move(_partition, get_accelerator().current_device_name()).detach()
+            _partition = self.move(_partition).detach()

             bias.data = _partition

@@ -522,19 +546,19 @@ class Yuan_LinearLayer(LinearLayer):
     def _tp_partition(self, params_list):
         weight, bias = shard_value_with_share_qk(params_list[0].data, params_list[1], self.tp_index,
                                                  self.tp_world_size, True)
-        params_list[0].data = move(weight, get_accelerator().current_device_name()).detach()
+        params_list[0].data = self.move(weight).detach()
         if bias is not None:
-            params_list[1].data = move(bias, get_accelerator().current_device_name()).detach()
+            params_list[1].data = self.move(bias).detach()


 class GateUpPack_LinearLayer(LinearLayer):
     # chatGLM2, chatGLM2
     @torch.no_grad()
     def _tp_partition(self, params_list):
         weight, bias = shard_chunk_mlp(params_list[0].data, params_list[1], self.tp_index, self.tp_world_size)
-        params_list[0].data = move(weight, device=get_accelerator().current_device_name()).detach()
+        params_list[0].data = self.move(weight).detach()
         if bias is not None:
-            params_list[1].data = move(bias, device=get_accelerator().current_device_name()).detach()
+            params_list[1].data = self.move(bias).detach()


 class Conv_LinearALlreduce(LinearAllreduce):
@@ -549,7 +573,7 @@ def _tp_partition(self, params_list):
             _partition = param.split(get_shard_size_list(param.shape[0], self.tp_world_size, self.name),
                                      dim=1)[self.tp_index]

-            _partition = move(_partition, get_accelerator().current_device_name()).detach()
+            _partition = self.move(_partition).detach()

             params_list[idx].data = _partition
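
With move() now a method, the target device and the decision to make a fresh copy both come from the shared keep_module_on_host flag instead of being passed in at every call site, and the copy-then-release trick only runs when a new copy is actually made. A standalone sketch of that release pattern in plain PyTorch (CPU-only so it runs without an accelerator; move_and_release is a hypothetical stand-in, not the DeepSpeed method):

    import torch

    def move_and_release(tensor, device, return_new_copy=True):
        # copy=True forces a fresh allocation even for cpu -> cpu, so the shard
        # stops aliasing the full tensor's storage
        cloned = tensor.to(device, copy=return_new_copy)
        if return_new_copy:
            # drop this reference's storage to lower the memory peak,
            # mirroring the comment in the diff (cf. DeepSpeed PR #4353)
            tensor.data = torch.empty(0, device=tensor.device)
        return cloned

    full = torch.randn(4, 8)
    shard = torch.chunk(full, 2, dim=-1)[0]      # a view into `full`
    shard = move_and_release(shard, 'cpu')       # materialize the shard as its own tensor
    print(shard.shape)                           # torch.Size([4, 4])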

deepspeed/runtime/bf16_optimizer.py (+11 -8)

@@ -44,7 +44,7 @@ def __init__(self,
                  timers=None,
                  grad_acc_dtype=None,
                  graph_harvesting=False,
-                 immediate_grad_update=False,
+                 immediate_grad_update=True,
                  has_moe_layers=False):
         super().__init__()
         see_memory_usage('begin bf16_optimizer', force=True)
@@ -313,7 +313,7 @@ def step(self, closure=None):

         self.clear_hp_grads()

-    def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs):
+    def backward(self, loss, retain_graph=False, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs):
         """Perform a backward pass and copy the low-precision gradients to the
         high-precision copy.

@@ -323,7 +323,7 @@ def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwarg
        The low-precision grads are deallocated during this procedure.
        """
        self.clear_lp_grads()
-       loss.backward(**bwd_kwargs)
+       loss.backward(retain_graph=retain_graph, **bwd_kwargs)

        if update_hp_grads:
            self.update_hp_grads(clear_lp_grads=clear_lp_grads)
@@ -425,9 +425,6 @@ def update_lp_params(self):
                 fp32_partition) in enumerate(zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition)):
            partition_id = dist.get_rank(group=self.real_dp_process_group[i])
            bf16_partitions[partition_id].data.copy_(fp32_partition.data)
-            # print_rank_0(f'update_lp_params {i=} {partition_id=}', force=True)
-            # if i == 0:
-            #     print_rank_0(f'{fp32_partition[:10]=}', force=True)

        all_gather_dp_groups(groups_flat=self.bf16_groups_flat,
                             partitioned_param_groups=self.bf16_partitioned_groups,
@@ -442,10 +439,12 @@ def clear_hp_grads(self):
         for i, group in enumerate(self.fp32_groups_gradients):
             self.fp32_groups_has_gradients[i] = [False] * len(group)

-    def clear_lp_grads(self):
+    def clear_lp_grads(self, set_to_none=False):

         # using zero_() fixed memory address for graph replay
-        set_to_none = False if self.graph_harvesting else True
+        if self.graph_harvesting:
+            assert not set_to_none, "graph harvesting is incompatible with setting lp grads to None"
+
         zero_grads_list = []
         for group in self.bf16_groups:
             for param in group:
@@ -458,6 +457,10 @@ def clear_lp_grads(self):
         if not set_to_none and len(zero_grads_list) > 0:
             torch._foreach_zero_(zero_grads_list)

+    def zero_grad(self, set_to_none=True):
+        self.clear_lp_grads(set_to_none)
+        self.clear_hp_grads()
+
     def state_dict(self):
         state_dict = {}
         state_dict[CLIP_GRAD] = self.clip_grad
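
Two caller-facing additions fall out of this file: backward() now forwards a retain_graph argument to loss.backward(), and the optimizer exposes a conventional zero_grad(set_to_none=True) that clears both the low- and high-precision gradients. A hedged usage sketch, assuming bf16_opt is an already-constructed BF16_Optimizer instance and loss1/loss2 share part of one autograd graph:

    bf16_opt.zero_grad(set_to_none=True)         # clears lp grads (optionally to None) and hp grads

    bf16_opt.backward(loss1, retain_graph=True)  # keep the graph alive for a second backward
    bf16_opt.backward(loss2)                     # default retain_graph=False frees the graph

    bf16_opt.step()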

deepspeed/runtime/constants.py (+1 -1)

@@ -128,7 +128,7 @@

 # BFLOAT16 optimizer immediate gradient update
 BFLOAT16_IMMEDIATE_GRAD_UPDATE = "immediate_grad_update"
-BFLOAT16_IMMEDIATE_GRAD_UPDATE_DEFAULT = False
+BFLOAT16_IMMEDIATE_GRAD_UPDATE_DEFAULT = True

 #########################################
 # FP16 support
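
Together with the bf16_optimizer.py change above, this flips the default for immediate gradient updates from False to True. A hedged sketch of opting back into the old behaviour; it assumes the key is read from the "bf16" section of the DeepSpeed config dict, as the BFLOAT16_ prefix of the constant suggests:

    ds_config = {
        "train_micro_batch_size_per_gpu": 1,
        "bf16": {
            "enabled": True,
            "immediate_grad_update": False,   # restore the pre-change default
        },
    }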
