
Commit 9f5b8c0

Merge branch 'master' into tohtana/support_autocast
2 parents: 2016995 + b418cf6

File tree

8 files changed: +398 −150 lines changed


deepspeed/runtime/bf16_optimizer.py

+11 −8
@@ -44,7 +44,7 @@ def __init__(self,
                  timers=None,
                  grad_acc_dtype=None,
                  graph_harvesting=False,
-                 immediate_grad_update=False,
+                 immediate_grad_update=True,
                  has_moe_layers=False):
         super().__init__()
         see_memory_usage('begin bf16_optimizer', force=True)
@@ -313,7 +313,7 @@ def step(self, closure=None):

         self.clear_hp_grads()

-    def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs):
+    def backward(self, loss, retain_graph=False, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs):
         """Perform a backward pass and copy the low-precision gradients to the
         high-precision copy.

@@ -323,7 +323,7 @@ def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs):
         The low-precision grads are deallocated during this procedure.
         """
         self.clear_lp_grads()
-        loss.backward(**bwd_kwargs)
+        loss.backward(retain_graph=retain_graph, **bwd_kwargs)

         if update_hp_grads:
             self.update_hp_grads(clear_lp_grads=clear_lp_grads)
@@ -425,9 +425,6 @@ def update_lp_params(self):
                 fp32_partition) in enumerate(zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition)):
             partition_id = dist.get_rank(group=self.real_dp_process_group[i])
             bf16_partitions[partition_id].data.copy_(fp32_partition.data)
-            # print_rank_0(f'update_lp_params {i=} {partition_id=}', force=True)
-            # if i == 0:
-            #     print_rank_0(f'{fp32_partition[:10]=}', force=True)

         all_gather_dp_groups(groups_flat=self.bf16_groups_flat,
                              partitioned_param_groups=self.bf16_partitioned_groups,
@@ -442,10 +439,12 @@ def clear_hp_grads(self):
         for i, group in enumerate(self.fp32_groups_gradients):
             self.fp32_groups_has_gradients[i] = [False] * len(group)

-    def clear_lp_grads(self):
+    def clear_lp_grads(self, set_to_none=False):

         # using zero_() fixed memory address for graph replay
-        set_to_none = False if self.graph_harvesting else True
+        if self.graph_harvesting:
+            assert not set_to_none, "graph harvesting is incompatible with setting lp grads to None"
+
         zero_grads_list = []
         for group in self.bf16_groups:
             for param in group:
@@ -458,6 +457,10 @@ def clear_lp_grads(self):
         if not set_to_none and len(zero_grads_list) > 0:
             torch._foreach_zero_(zero_grads_list)

+    def zero_grad(self, set_to_none=True):
+        self.clear_lp_grads(set_to_none)
+        self.clear_hp_grads()
+
     def state_dict(self):
         state_dict = {}
         state_dict[CLIP_GRAD] = self.clip_grad
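
The new zero_grad entry point above simply chains clear_lp_grads and clear_hp_grads, and clear_lp_grads now takes set_to_none from the caller instead of deriving it from graph_harvesting. The following standalone sketch (a toy class, not DeepSpeed's BF16_Optimizer, and it omits the high-precision grad copies) illustrates the trade-off the diff encodes: zeroing in place keeps gradient storage at fixed addresses, which CUDA-graph replay relies on, while set_to_none=True releases the buffers.

import torch

class ToyBf16GradClearer:
    # Toy illustration only; DeepSpeed's BF16_Optimizer also maintains fp32 "hp" grads.
    def __init__(self, params, graph_harvesting=False):
        self.params = list(params)
        self.graph_harvesting = graph_harvesting

    def clear_lp_grads(self, set_to_none=False):
        if self.graph_harvesting:
            # graph replay needs stable grad buffers, so they may only be zeroed in place
            assert not set_to_none, "graph harvesting is incompatible with setting lp grads to None"
        for p in self.params:
            if p.grad is None:
                continue
            if set_to_none:
                p.grad = None      # release the buffer
            else:
                p.grad.zero_()     # keep the buffer, reset its contents

    def zero_grad(self, set_to_none=True):
        # mirrors the shape of the new BF16_Optimizer.zero_grad, minus hp-grad bookkeeping
        self.clear_lp_grads(set_to_none)

p = torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16))
p.grad = torch.ones_like(p)
ToyBf16GradClearer([p]).zero_grad()
assert p.grad is None
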

deepspeed/runtime/constants.py

+1 −1
@@ -128,7 +128,7 @@

 # BFLOAT16 optimizer immediate gradient update
 BFLOAT16_IMMEDIATE_GRAD_UPDATE = "immediate_grad_update"
-BFLOAT16_IMMEDIATE_GRAD_UPDATE_DEFAULT = False
+BFLOAT16_IMMEDIATE_GRAD_UPDATE_DEFAULT = True

 #########################################
 # FP16 support
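
With the default flipped to True, bf16 users who relied on deferred gradient updates need to opt out explicitly. Below is a config sketch; the placement of immediate_grad_update under the "bf16" block is an assumption inferred from the constant's name and is not confirmed by this diff.

# Hypothetical DeepSpeed config dict; key placement under "bf16" is assumed.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "bf16": {
        "enabled": True,
        "immediate_grad_update": False,  # restore the previous (deferred) behaviour
    },
}
# engine, optimizer, _, _ = deepspeed.initialize(model=my_model, config=ds_config)
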

deepspeed/runtime/engine.py

+108 −81
@@ -273,6 +273,9 @@ def __init__(self,
         # Configure distributed model
         self._configure_distributed_model(model)

+        self.module_forward_pre_hook = self._create_module_forward_pre_hook()
+        self.module_forward_post_hook = self._create_module_forward_post_hook()
+
         # needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict
         self.param_names = {param: name for name, param in model.named_parameters()}

@@ -1889,7 +1892,6 @@ def deepspeed_io(self,
                 GLOBAL_RANK: self.global_rank,
                 DATA_SAMPLING_NUM_WORKERS: self.data_sampling_config()[DATA_SAMPLING_NUM_WORKERS]
             }
-
         return DeepSpeedDataLoader(dataset=dataset,
                                    batch_size=batch_size,
                                    pin_memory=pin_memory,
@@ -1936,17 +1938,24 @@ def _scale_loss_by_gas(self, prescaled_loss, eval_micro_batches=None):

         return scaled_loss

-    @instrument_w_nvtx
-    def forward(self, *inputs, **kwargs):
-        r"""Execute forward propagation
-        Arguments:
-            *inputs: Variable length input list
-            **kwargs: variable length keyword arguments
-        """
+    def _create_module_forward_pre_hook(self):

-        if self.autotuning_profile_model_info():
-            ma = get_ma_status()
-        else:
+        def _module_forward_pre_hook(module, inputs, kwargs):
+            return self._forward_prologue(inputs, kwargs)
+
+        return self.module.register_forward_pre_hook(_module_forward_pre_hook, prepend=False, with_kwargs=True)
+
+    def _create_module_forward_post_hook(self):
+
+        def _module_forward_post_hook(module, input, output):
+            self._forward_epilogue()
+
+        return self.module.register_forward_hook(_module_forward_post_hook)
+
+    def _forward_prologue(self, inputs, kwargs):
+        return_modified = False
+
+        if not self.autotuning_profile_model_info():
             see_memory_usage("Engine before forward", force=self.memory_breakdown())

         flops_profiler_active = (self.flops_profiler_enabled()
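
The hook factories above move the old inline prologue/epilogue out of forward() and into module hooks: the pre-hook is registered with with_kwargs=True, so returning (inputs, kwargs) from _forward_prologue rewrites the arguments the module actually receives, and the post-hook fires after module.forward. A minimal standalone PyTorch sketch of that mechanism (plain nn.Linear, not a DeepSpeed engine; requires torch >= 2.0 for the with_kwargs argument):

import torch

model = torch.nn.Linear(4, 2)  # fp32 weights

def pre_hook(module, args, kwargs):
    # like _forward_prologue: optionally rewrite the inputs and return them
    new_args = tuple(a.float() for a in args)
    return new_args, kwargs

def post_hook(module, args, output):
    # like _forward_epilogue: runs after forward completes
    print("forward finished, output dtype:", output.dtype)

pre_handle = model.register_forward_pre_hook(pre_hook, prepend=False, with_kwargs=True)
post_handle = model.register_forward_hook(post_hook)

model(torch.randn(3, 4, dtype=torch.float64))  # double input is cast to fp32 by the pre-hook
pre_handle.remove()
post_handle.remove()
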
@@ -1965,61 +1974,84 @@ def forward(self, *inputs, **kwargs):
             self.eigenvalue_enabled(),
             None,
         )
+        return_modified = True

         if flops_profiler_active:
             self.flops_profiler.start_profile(ignore_list=None)

-        if self.module.training:
-            if self.progressive_layer_drop:
-                kwargs.update(self.progressive_layer_drop.get_state())
+        if kwargs is not None:
+            if self.module.training:
+                if self.progressive_layer_drop:
+                    kwargs.update(self.progressive_layer_drop.get_state())

-        if self.__class__.__name__ != "PipelineEngine":
-            # TODO: The above if condition is a HACK since for PipelineEngine
-            # it's difficult to inject argument in forward pass.
-            if self.module.training and self.curriculum_enabled_legacy():
-                self.curriculum_scheduler_legacy.update_difficulty(self.global_steps + 1)
-                if self.curriculum_params_legacy()["curriculum_type"] == "seqlen":
-                    kwargs.update({"curriculum_seqlen": self.curriculum_scheduler_legacy.get_current_difficulty()})
+            if self.__class__.__name__ != "PipelineEngine":
+                # TODO: The above if condition is a HACK since for PipelineEngine
+                # it's difficult to inject argument in forward pass.
+                if self.module.training and self.curriculum_enabled_legacy():
+                    self.curriculum_scheduler_legacy.update_difficulty(self.global_steps + 1)
+                    if self.curriculum_params_legacy()["curriculum_type"] == "seqlen":
+                        kwargs.update({"curriculum_seqlen": self.curriculum_scheduler_legacy.get_current_difficulty()})
+            return_modified = True

         if self.module.training and self.random_ltd_enabled():
             self.random_ltd_scheduler.update_seq(self.global_steps)

+        if self.training_dataloader is None:
+            self.tput_timer.start()
+
+        self._start_timers(self.engine_timers.forward_timers)
+
         if self.zero_optimization_partition_weights():
             # Enable automated discovery of external parameters by indicating that
             # we are in a forward pass.
             for module in self.module.modules():
                 module._parameters._in_forward = True

-        self._start_timers(self.engine_timers.forward_timers)
-
-        if self.training_dataloader is None:
-            self.tput_timer.start()
-
         if self.fp16_auto_cast():
             inputs = self._cast_inputs_half(inputs)
+            return_modified = True

-        with torch.autocast(device_type=get_accelerator().device_name(),
-                            dtype=self.torch_autocast_dtype(),
-                            enabled=self.torch_autocast_enabled()):
-            loss = self.module(*inputs, **kwargs)
+        if return_modified:
+            return inputs, kwargs

+    def _forward_epilogue(self):
         if self.zero_optimization_partition_weights():
             # Disable automated discovery of external parameters
             for module in self.module.modules():
                 module._parameters._in_forward = False

         self._stop_timers(self.engine_timers.forward_timers)

+        flops_profiler_active = (self.flops_profiler_enabled()
+                                 and self.global_steps == self.flops_profiler_profile_step() and self.global_rank == 0)
+
         if flops_profiler_active:
             self.flops_profiler.stop_profile()

+        if not self.autotuning_profile_model_info():
+            see_memory_usage("Engine after forward", force=self.memory_breakdown())
+
+    @instrument_w_nvtx
+    def forward(self, *inputs, **kwargs):
+        r"""Execute forward propagation
+        Arguments:
+            *inputs: Variable length input list
+            **kwargs: variable length keyword arguments
+        """
+        if self.autotuning_profile_model_info():
+            ma = get_ma_status()
+
+        with torch.autocast(device_type=get_accelerator().device_name(),
+                            dtype=self.torch_autocast_dtype(),
+                            enabled=self.torch_autocast_enabled()):
+            loss = self.module(*inputs, **kwargs)
+
         if self.autotuning_profile_model_info():
             activation_mem = get_ma_status() - ma
             self.autotuning_model_info["activation_mem_per_gpu"] = activation_mem
             print_json_dist(self.autotuning_model_info, [0], path=self.autotuning_model_info_path())
             exit()
-        else:
-            see_memory_usage("Engine after forward", force=self.memory_breakdown())
+
         return loss

     def _cast_inputs_half(self, inputs):
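
The rewritten forward() now only wraps the module call in torch.autocast, driven by torch_autocast_dtype() / torch_autocast_enabled() from the engine config; the rest of the old body moved into the hooks. A standalone sketch of what that wrapper does (CPU and bfloat16 are chosen here purely for illustration):

import torch

model = torch.nn.Linear(8, 8)   # parameters stay fp32
x = torch.randn(2, 8)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True):
    y = model(x)                # matmul runs under autocast

print(y.dtype)                  # torch.bfloat16 inside the autocast region
print(model.weight.dtype)       # torch.float32, parameters untouched
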
@@ -2078,43 +2110,13 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
             grads = None
         self.buffered_allreduce_fallback(grads=grads, elements_per_buffer=bucket_size)

-    @contextmanager
-    def no_sync(self):
-        r"""
-        Context manager to disable gradient reduction during backward pass.
-        This context manager has the following effects on other DeepSpeed features.
-        1. Incompatible with ZeRO stage 2/3 which rely on reduction for gradient partitioning.
-        2. It is illegal to call engine.step() within the context manager.
-        3. Tracking of gradient accumulation steps is disabled.
-        """
-        assert not self.zero_optimization_partition_gradients(), \
-            f"no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage {self.zero_optimization_stage()}"
-
-        assert not self.inside_no_sync_ctxt, f"no_sync context manager reentry is unsupported"
-
-        self.inside_no_sync_ctxt = True
-        try:
-            yield
-        finally:
-            self.inside_no_sync_ctxt = False
-
-    @instrument_w_nvtx
-    def backward(self, loss, release_loss=False, retain_graph=False, scale_wrt_gas=True):
-        r"""Execute backward pass on the loss
-        Arguments:
-            loss: Torch tensor on which to execute backward propagation
-            retain_graph: bool, default: false
-                forward on user defined choice of retain_graph
-        """
-
+    def _backward_prologue(self, loss, scale_wrt_gas=True):
         see_memory_usage("Engine before backward", force=self.memory_breakdown())
-
         if self.scale_wrt_gas is not None:
             scale_wrt_gas = self.scale_wrt_gas

-        do_gradient_reduction = self.enable_backward_allreduce and not self.inside_no_sync_ctxt
-
         # scale loss w.r.t. gradient accumulation if reduction is not disabled
+        do_gradient_reduction = self.enable_backward_allreduce and not self.inside_no_sync_ctxt
         if do_gradient_reduction and self.gradient_accumulation_steps() > 1 and scale_wrt_gas:
             loss = self._scale_loss_by_gas(loss.float())

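
_backward_prologue keeps the loss-scaling logic: when gradient reduction is enabled and more than one accumulation step is configured, the loss is divided by the number of accumulation steps so that summed micro-batch gradients average out. A one-line arithmetic sketch of the effect (the call goes through _scale_loss_by_gas):

gradient_accumulation_steps = 4
micro_batch_loss = 2.0
scaled = micro_batch_loss / gradient_accumulation_steps  # what the prologue effectively applies
print(scaled)  # 0.5
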

@@ -2131,13 +2133,18 @@ def backward(self, loss, release_loss=False, retain_graph=False, scale_wrt_gas=True):
                     )]
                     self.monitor.write_events(self.summary_events)

-        self._start_timers(self.engine_timers.backward_timers)
+        return loss

-        assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \
-            "must provide optimizer during init in order to use backward"
+    def _backward_epilogue(self):
+        self._start_timers(self.engine_timers.backward_reduce_timers)
+        if self.enable_backward_allreduce and not self.inside_no_sync_ctxt:
+            # Traditional code path that allreduces the module parameter grads
+            self.allreduce_gradients()
+        self._stop_timers(self.engine_timers.backward_reduce_timers)
+        see_memory_usage("Engine after backward", force=self.memory_breakdown())

+    def _do_optimizer_backward(self, loss, retain_graph):
         self._start_timers(self.engine_timers.backward_inner_timers)
-
         if self.zero_optimization():
             self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary()
             self.optimizer.backward(loss, retain_graph=retain_graph)
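
_do_optimizer_backward now threads retain_graph through every optimizer path, including the bfloat16 path updated in the next hunk. The flag matters whenever backward runs more than once over the same graph, as in this standalone sketch:

import torch

x = torch.randn(3, requires_grad=True)
loss = (x * 2).sum()

loss.backward(retain_graph=True)  # keep the graph alive for another pass
loss.backward()                   # would raise a RuntimeError without retain_graph above
print(x.grad)                     # gradients accumulate: tensor([4., 4., 4.])
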
@@ -2153,30 +2160,50 @@ def backward(self, loss, release_loss=False, retain_graph=False, scale_wrt_gas=True):
             else:
                 self.optimizer.backward(loss, retain_graph=retain_graph)
         elif self.bfloat16_enabled():
-            self.optimizer.backward(loss)
+            self.optimizer.backward(loss, retain_graph=retain_graph)
         else:
             if self.eigenvalue_enabled():
                 loss.backward(create_graph=True, retain_graph=True)
             else:
                 loss.backward(retain_graph=retain_graph)
-
         self._stop_timers(self.engine_timers.backward_inner_timers)

-        self._start_timers(self.engine_timers.backward_reduce_timers)
-
-        if do_gradient_reduction:
-            # Traditional code path that allreduces the module parameter grads
-            self.allreduce_gradients()
+    @contextmanager
+    def no_sync(self):
+        r"""
+        Context manager to disable gradient reduction during backward pass.
+        This context manager has the following effects on other DeepSpeed features:
+        1. Incompatible with ZeRO stage 2/3 which rely on reduction for gradient partitioning.
+        2. It is illegal to call engine.step() within the context manager.
+        3. Tracking of gradient accumulation steps is disabled.
+        """
+        assert not self.zero_optimization_partition_gradients(), \
+            f"no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage {self.zero_optimization_stage()}"

-        self._stop_timers(self.engine_timers.backward_reduce_timers)
+        assert not self.inside_no_sync_ctxt, f"no_sync context manager reentry is unsupported"

-        self._stop_timers(self.engine_timers.backward_timers)
+        self.inside_no_sync_ctxt = True
+        try:
+            yield
+        finally:
+            self.inside_no_sync_ctxt = False

-        if release_loss:
-            # loss.data = None
-            pass
+    @instrument_w_nvtx
+    def backward(self, loss, retain_graph=False, scale_wrt_gas=True):
+        r"""Execute backward pass on the loss
+        Arguments:
+            loss: Torch tensor on which to execute backward propagation
+            retain_graph: bool, default: false
+                forward on user defined choice of retain_graph
+        """
+        assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \
+            "must provide optimizer during init in order to use backward"

-        see_memory_usage("Engine after backward", force=self.memory_breakdown())
+        self._start_timers(self.engine_timers.backward_timers)
+        loss = self._backward_prologue(loss, scale_wrt_gas)
+        self._do_optimizer_backward(loss, retain_graph)
+        self._backward_epilogue()
+        self._stop_timers(self.engine_timers.backward_timers)

         return loss
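
Taken together, backward() is now a thin composition of _backward_prologue, _do_optimizer_backward, and _backward_epilogue, with no_sync() suppressing the allreduce in the epilogue. A usage sketch under assumed names (engine is a DeepSpeed engine, batches an iterable of input dicts; this helper is illustrative and not part of the commit):

def train_with_accumulation(engine, batches):
    # Hypothetical training loop: skip gradient allreduce on accumulation micro-batches,
    # then reduce and step only at the accumulation boundary.
    for step, batch in enumerate(batches):
        at_boundary = (step + 1) % engine.gradient_accumulation_steps() == 0
        if not at_boundary:
            with engine.no_sync():        # epilogue skips allreduce inside this context
                loss = engine(**batch)
                engine.backward(loss)     # prologue -> optimizer backward -> epilogue
        else:
            loss = engine(**batch)
            engine.backward(loss)         # allreduce runs in _backward_epilogue
            engine.step()                 # step() must stay outside no_sync()
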
