Commit d56b5d3

Merge branch 'master' into add_conditional_expression

2 parents 154dfc9 + b418cf6

8 files changed: +391 -144 lines changed

deepspeed/runtime/bf16_optimizer.py (+11 -8)

@@ -44,7 +44,7 @@ def __init__(self,
                  timers=None,
                  grad_acc_dtype=None,
                  graph_harvesting=False,
-                 immediate_grad_update=False,
+                 immediate_grad_update=True,
                  has_moe_layers=False):
         super().__init__()
         see_memory_usage('begin bf16_optimizer', force=True)

@@ -313,7 +313,7 @@ def step(self, closure=None):

         self.clear_hp_grads()

-    def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs):
+    def backward(self, loss, retain_graph=False, update_hp_grads=True, clear_lp_grads=False, **bwd_kwargs):
         """Perform a backward pass and copy the low-precision gradients to the
         high-precision copy.

@@ -323,7 +323,7 @@ def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwarg
         The low-precision grads are deallocated during this procedure.
         """
         self.clear_lp_grads()
-        loss.backward(**bwd_kwargs)
+        loss.backward(retain_graph=retain_graph, **bwd_kwargs)

         if update_hp_grads:
             self.update_hp_grads(clear_lp_grads=clear_lp_grads)

@@ -425,9 +425,6 @@ def update_lp_params(self):
                 fp32_partition) in enumerate(zip(self.bf16_partitioned_groups, self.fp32_groups_flat_partition)):
             partition_id = dist.get_rank(group=self.real_dp_process_group[i])
             bf16_partitions[partition_id].data.copy_(fp32_partition.data)
-            # print_rank_0(f'update_lp_params {i=} {partition_id=}', force=True)
-            # if i == 0:
-            #     print_rank_0(f'{fp32_partition[:10]=}', force=True)

         all_gather_dp_groups(groups_flat=self.bf16_groups_flat,
                              partitioned_param_groups=self.bf16_partitioned_groups,

@@ -442,10 +439,12 @@ def clear_hp_grads(self):
         for i, group in enumerate(self.fp32_groups_gradients):
             self.fp32_groups_has_gradients[i] = [False] * len(group)

-    def clear_lp_grads(self):
+    def clear_lp_grads(self, set_to_none=False):

         # using zero_() fixed memory address for graph replay
-        set_to_none = False if self.graph_harvesting else True
+        if self.graph_harvesting:
+            assert not set_to_none, "graph harvesting is incompatible with setting lp grads to None"
+
         zero_grads_list = []
         for group in self.bf16_groups:
             for param in group:

@@ -458,6 +457,10 @@ def clear_lp_grads(self):
         if not set_to_none and len(zero_grads_list) > 0:
             torch._foreach_zero_(zero_grads_list)

+    def zero_grad(self, set_to_none=True):
+        self.clear_lp_grads(set_to_none)
+        self.clear_hp_grads()
+
     def state_dict(self):
         state_dict = {}
         state_dict[CLIP_GRAD] = self.clip_grad
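
The new clear_lp_grads(set_to_none=...) / zero_grad(set_to_none=True) split matters for CUDA graph harvesting: a replayed graph reads and writes the same gradient storage it was captured with, so low-precision grads must be zeroed in place rather than dropped. A minimal standalone sketch (plain PyTorch, not DeepSpeed internals; the parameters and loss are illustrative):

import torch

params = [torch.nn.Parameter(torch.randn(4)) for _ in range(2)]
loss = sum((p * p).sum() for p in params)
loss.backward()

addrs_before = [p.grad.data_ptr() for p in params]

# set_to_none=False path: in-place zeroing keeps every .grad buffer at the
# same address, which is what a captured/replayed graph expects.
torch._foreach_zero_([p.grad for p in params])
assert [p.grad.data_ptr() for p in params] == addrs_before

# set_to_none=True path: the next backward allocates fresh .grad tensors,
# so a captured graph pointing at the old buffers would go stale.
for p in params:
    p.grad = None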

deepspeed/runtime/constants.py (+1 -1)

@@ -128,7 +128,7 @@

 # BFLOAT16 optimizer immediate gradient update
 BFLOAT16_IMMEDIATE_GRAD_UPDATE = "immediate_grad_update"
-BFLOAT16_IMMEDIATE_GRAD_UPDATE_DEFAULT = False
+BFLOAT16_IMMEDIATE_GRAD_UPDATE_DEFAULT = True

 #########################################
 # FP16 support
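
Since the default for immediate_grad_update flips to True here, users who relied on deferred low-precision gradient updates have to opt out explicitly. A hedged sketch of what that override could look like, assuming the key is read from the bf16 section of the DeepSpeed config (the surrounding BFLOAT16 constants suggest this, but treat the exact placement as an assumption):

import deepspeed  # config-only sketch; initialize() still needs a model and a launcher

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "bf16": {
        "enabled": True,
        # restore the pre-commit behavior (the new default is True)
        "immediate_grad_update": False,
    },
}

# engine, optimizer, _, _ = deepspeed.initialize(model=model,
#                                                model_parameters=model.parameters(),
#                                                config=ds_config)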

deepspeed/runtime/engine.py (+105 -78)

@@ -272,6 +272,9 @@ def __init__(self,
         # Configure distributed model
         self._configure_distributed_model(model)

+        self.module_forward_pre_hook = self._create_module_forward_pre_hook()
+        self.module_forward_post_hook = self._create_module_forward_post_hook()
+
         # needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict
         self.param_names = {param: name for name, param in model.named_parameters()}

@@ -1875,7 +1878,6 @@ def deepspeed_io(self,
             GLOBAL_RANK: self.global_rank,
             DATA_SAMPLING_NUM_WORKERS: self.data_sampling_config()[DATA_SAMPLING_NUM_WORKERS]
         }
-
         return DeepSpeedDataLoader(dataset=dataset,
                                    batch_size=batch_size,
                                    pin_memory=pin_memory,

19221924

19231925
return scaled_loss
19241926

1925-
@instrument_w_nvtx
1926-
def forward(self, *inputs, **kwargs):
1927-
r"""Execute forward propagation
1928-
Arguments:
1929-
*inputs: Variable length input list
1930-
**kwargs: variable length keyword arguments
1931-
"""
1927+
def _create_module_forward_pre_hook(self):
19321928

1933-
if self.autotuning_profile_model_info():
1934-
ma = get_ma_status()
1935-
else:
1929+
def _module_forward_pre_hook(module, inputs, kwargs):
1930+
return self._forward_prologue(inputs, kwargs)
1931+
1932+
return self.module.register_forward_pre_hook(_module_forward_pre_hook, prepend=False, with_kwargs=True)
1933+
1934+
def _create_module_forward_post_hook(self):
1935+
1936+
def _module_forward_post_hook(module, input, output):
1937+
self._forward_epilogue()
1938+
1939+
return self.module.register_forward_hook(_module_forward_post_hook)
1940+
1941+
def _forward_prologue(self, inputs, kwargs):
1942+
return_modified = False
1943+
1944+
if not self.autotuning_profile_model_info():
19361945
see_memory_usage("Engine before forward", force=self.memory_breakdown())
19371946

19381947
flops_profiler_active = (self.flops_profiler_enabled()
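
The forward prologue/epilogue now run as module hooks instead of inline in forward(). The relevant PyTorch mechanic: a pre-hook registered with with_kwargs=True receives (module, args, kwargs) and may return a new (args, kwargs) pair to rewrite the inputs, which is why _forward_prologue only returns something when it actually modified them. A toy illustration outside DeepSpeed (the linear layer and the dtype cast are made up for the example):

import torch

net = torch.nn.Linear(4, 2)  # fp32 weights

def pre_hook(module, args, kwargs):
    # Returning (args, kwargs) replaces the forward inputs, mirroring how
    # _forward_prologue can inject kwargs or cast inputs before module(*inputs).
    return tuple(a.float() for a in args), kwargs

def post_hook(module, args, output):
    # Runs after forward returns, mirroring _forward_epilogue.
    print("forward done, output shape:", tuple(output.shape))

net.register_forward_pre_hook(pre_hook, prepend=False, with_kwargs=True)
net.register_forward_hook(post_hook)

out = net(torch.randn(3, 4).half())  # would dtype-mismatch without the pre-hook
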
@@ -1951,58 +1960,81 @@ def forward(self, *inputs, **kwargs):
             self.eigenvalue_enabled(),
             None,
         )
+        return_modified = True

         if flops_profiler_active:
             self.flops_profiler.start_profile(ignore_list=None)

-        if self.module.training:
-            if self.progressive_layer_drop:
-                kwargs.update(self.progressive_layer_drop.get_state())
+        if kwargs is not None:
+            if self.module.training:
+                if self.progressive_layer_drop:
+                    kwargs.update(self.progressive_layer_drop.get_state())

-        if self.__class__.__name__ != "PipelineEngine":
-            # TODO: The above if condition is a HACK since for PipelineEngine
-            # it's difficult to inject argument in forward pass.
-            if self.module.training and self.curriculum_enabled_legacy():
-                self.curriculum_scheduler_legacy.update_difficulty(self.global_steps + 1)
-                if self.curriculum_params_legacy()["curriculum_type"] == "seqlen":
-                    kwargs.update({"curriculum_seqlen": self.curriculum_scheduler_legacy.get_current_difficulty()})
+            if self.__class__.__name__ != "PipelineEngine":
+                # TODO: The above if condition is a HACK since for PipelineEngine
+                # it's difficult to inject argument in forward pass.
+                if self.module.training and self.curriculum_enabled_legacy():
+                    self.curriculum_scheduler_legacy.update_difficulty(self.global_steps + 1)
+                    if self.curriculum_params_legacy()["curriculum_type"] == "seqlen":
+                        kwargs.update({"curriculum_seqlen": self.curriculum_scheduler_legacy.get_current_difficulty()})
+            return_modified = True

         if self.module.training and self.random_ltd_enabled():
             self.random_ltd_scheduler.update_seq(self.global_steps)

+        if self.training_dataloader is None:
+            self.tput_timer.start()
+
+        self._start_timers(self.engine_timers.forward_timers)
+
         if self.zero_optimization_partition_weights():
             # Enable automated discovery of external parameters by indicating that
             # we are in a forward pass.
             for module in self.module.modules():
                 module._parameters._in_forward = True

-        self._start_timers(self.engine_timers.forward_timers)
-
-        if self.training_dataloader is None:
-            self.tput_timer.start()
-
         if self.fp16_auto_cast():
             inputs = self._cast_inputs_half(inputs)
+            return_modified = True

-        loss = self.module(*inputs, **kwargs)
+        if return_modified:
+            return inputs, kwargs

+    def _forward_epilogue(self):
         if self.zero_optimization_partition_weights():
             # Disable automated discovery of external parameters
             for module in self.module.modules():
                 module._parameters._in_forward = False

         self._stop_timers(self.engine_timers.forward_timers)

+        flops_profiler_active = (self.flops_profiler_enabled()
+                                 and self.global_steps == self.flops_profiler_profile_step() and self.global_rank == 0)
+
         if flops_profiler_active:
             self.flops_profiler.stop_profile()

+        if not self.autotuning_profile_model_info():
+            see_memory_usage("Engine after forward", force=self.memory_breakdown())
+
+    @instrument_w_nvtx
+    def forward(self, *inputs, **kwargs):
+        r"""Execute forward propagation
+        Arguments:
+            *inputs: Variable length input list
+            **kwargs: variable length keyword arguments
+        """
+        if self.autotuning_profile_model_info():
+            ma = get_ma_status()
+
+        loss = self.module(*inputs, **kwargs)
+
         if self.autotuning_profile_model_info():
             activation_mem = get_ma_status() - ma
             self.autotuning_model_info["activation_mem_per_gpu"] = activation_mem
             print_json_dist(self.autotuning_model_info, [0], path=self.autotuning_model_info_path())
             exit()
-        else:
-            see_memory_usage("Engine after forward", force=self.memory_breakdown())
+
         return loss

     def _cast_inputs_half(self, inputs):

@@ -2061,43 +2093,13 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
             grads = None
         self.buffered_allreduce_fallback(grads=grads, elements_per_buffer=bucket_size)

-    @contextmanager
-    def no_sync(self):
-        r"""
-        Context manager to disable gradient reduction during backward pass.
-        This context manager has the following effects on other DeepSpeed features.
-        1. Incompatible with ZeRO stage 2/3 which rely on reduction for gradient partitioning.
-        2. It is illegal to call engine.step() within the context manager.
-        3. Tracking of gradient accumulation steps is disabled.
-        """
-        assert not self.zero_optimization_partition_gradients(), \
-            f"no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage {self.zero_optimization_stage()}"
-
-        assert not self.inside_no_sync_ctxt, f"no_sync context manager reentry is unsupported"
-
-        self.inside_no_sync_ctxt = True
-        try:
-            yield
-        finally:
-            self.inside_no_sync_ctxt = False
-
-    @instrument_w_nvtx
-    def backward(self, loss, release_loss=False, retain_graph=False, scale_wrt_gas=True):
-        r"""Execute backward pass on the loss
-        Arguments:
-            loss: Torch tensor on which to execute backward propagation
-            retain_graph: bool, default: false
-                forward on user defined choice of retain_graph
-        """
-
+    def _backward_prologue(self, loss, scale_wrt_gas=True):
         see_memory_usage("Engine before backward", force=self.memory_breakdown())
-
         if self.scale_wrt_gas is not None:
             scale_wrt_gas = self.scale_wrt_gas

-        do_gradient_reduction = self.enable_backward_allreduce and not self.inside_no_sync_ctxt
-
         # scale loss w.r.t. gradient accumulation if reduction is not disabled
+        do_gradient_reduction = self.enable_backward_allreduce and not self.inside_no_sync_ctxt
         if do_gradient_reduction and self.gradient_accumulation_steps() > 1 and scale_wrt_gas:
             loss = self._scale_loss_by_gas(loss.float())
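
_backward_prologue scales the loss by the gradient-accumulation factor so that summing gradients over micro-batches yields the gradient of the mean loss rather than of the sum. A small numeric sanity check (values are placeholders):

gradient_accumulation_steps = 4
micro_losses = [2.0, 1.0, 3.0, 2.0]

# Scaling each micro-batch loss by 1/GAS before backward...
scaled = [l / gradient_accumulation_steps for l in micro_losses]

# ...makes the accumulated (summed) contribution equal the mean loss.
assert abs(sum(scaled) - sum(micro_losses) / gradient_accumulation_steps) < 1e-12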

@@ -2114,13 +2116,18 @@ def backward(self, loss, release_loss=False, retain_graph=False, scale_wrt_gas=T
             )]
             self.monitor.write_events(self.summary_events)

-        self._start_timers(self.engine_timers.backward_timers)
+        return loss

-        assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \
-            "must provide optimizer during init in order to use backward"
+    def _backward_epilogue(self):
+        self._start_timers(self.engine_timers.backward_reduce_timers)
+        if self.enable_backward_allreduce and not self.inside_no_sync_ctxt:
+            # Traditional code path that allreduces the module parameter grads
+            self.allreduce_gradients()
+        self._stop_timers(self.engine_timers.backward_reduce_timers)
+        see_memory_usage("Engine after backward", force=self.memory_breakdown())

+    def _do_optimizer_backward(self, loss, retain_graph):
         self._start_timers(self.engine_timers.backward_inner_timers)
-
         if self.zero_optimization():
             self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary()
             self.optimizer.backward(loss, retain_graph=retain_graph)
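
With retain_graph now threaded through _do_optimizer_backward (and into the bf16 optimizer in the next hunk), a caller can run more than one backward pass over a shared autograd graph before stepping. A hedged sketch of the intent (engine, loss_a, loss_b are placeholders; whether multiple backward calls per step fit a given training loop depends on its gradient-accumulation setup):

def backward_two_losses(engine, loss_a, loss_b):
    # First backward keeps the autograd graph alive for the second pass.
    engine.backward(loss_a, retain_graph=True)
    # Second backward reuses and then frees the graph; gradients accumulate.
    engine.backward(loss_b)
    engine.step()
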
@@ -2136,30 +2143,50 @@ def backward(self, loss, release_loss=False, retain_graph=False, scale_wrt_gas=T
             else:
                 self.optimizer.backward(loss, retain_graph=retain_graph)
         elif self.bfloat16_enabled():
-            self.optimizer.backward(loss)
+            self.optimizer.backward(loss, retain_graph=retain_graph)
         else:
             if self.eigenvalue_enabled():
                 loss.backward(create_graph=True, retain_graph=True)
             else:
                 loss.backward(retain_graph=retain_graph)
-
         self._stop_timers(self.engine_timers.backward_inner_timers)

-        self._start_timers(self.engine_timers.backward_reduce_timers)
-
-        if do_gradient_reduction:
-            # Traditional code path that allreduces the module parameter grads
-            self.allreduce_gradients()
+    @contextmanager
+    def no_sync(self):
+        r"""
+        Context manager to disable gradient reduction during backward pass.
+        This context manager has the following effects on other DeepSpeed features:
+        1. Incompatible with ZeRO stage 2/3 which rely on reduction for gradient partitioning.
+        2. It is illegal to call engine.step() within the context manager.
+        3. Tracking of gradient accumulation steps is disabled.
+        """
+        assert not self.zero_optimization_partition_gradients(), \
+            f"no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage {self.zero_optimization_stage()}"

-        self._stop_timers(self.engine_timers.backward_reduce_timers)
+        assert not self.inside_no_sync_ctxt, f"no_sync context manager reentry is unsupported"

-        self._stop_timers(self.engine_timers.backward_timers)
+        self.inside_no_sync_ctxt = True
+        try:
+            yield
+        finally:
+            self.inside_no_sync_ctxt = False

-        if release_loss:
-            # loss.data = None
-            pass
+    @instrument_w_nvtx
+    def backward(self, loss, retain_graph=False, scale_wrt_gas=True):
+        r"""Execute backward pass on the loss
+        Arguments:
+            loss: Torch tensor on which to execute backward propagation
+            retain_graph: bool, default: false
+                forward on user defined choice of retain_graph
+        """
+        assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \
+            "must provide optimizer during init in order to use backward"

-        see_memory_usage("Engine after backward", force=self.memory_breakdown())
+        self._start_timers(self.engine_timers.backward_timers)
+        loss = self._backward_prologue(loss, scale_wrt_gas)
+        self._do_optimizer_backward(loss, retain_graph)
+        self._backward_epilogue()
+        self._stop_timers(self.engine_timers.backward_timers)

         return loss
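
A hedged usage sketch of the no_sync() context manager restored above (the helper and its argument names are placeholders, not part of this diff): gradient allreduce is skipped on intermediate micro-batches and performed only on the boundary backward, and step() stays outside the context as the docstring requires. Not applicable under ZeRO stage 2/3.

def train_with_accumulation(engine, batches, accumulation_steps):
    for step, batch in enumerate(batches):
        is_boundary = (step + 1) % accumulation_steps == 0
        if not is_boundary:
            with engine.no_sync():       # gradient allreduce disabled here
                loss = engine(batch)
                engine.backward(loss)
        else:
            loss = engine(batch)
            engine.backward(loss)        # allreduce happens on this backward
            engine.step()                # step() must remain outside no_sync()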
