@@ -272,6 +272,9 @@ def __init__(self,
        # Configure distributed model
        self._configure_distributed_model(model)

+        self.module_forward_pre_hook = self._create_module_forward_pre_hook()
+        self.module_forward_post_hook = self._create_module_forward_post_hook()
+
        # needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict
        self.param_names = {param: name for name, param in model.named_parameters()}

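The two attributes added above hold ordinary torch.nn.Module hook handles, so the engine's forward prologue/epilogue run around the wrapped module rather than inside an overridden forward(). A minimal standalone sketch of the same pattern in plain PyTorch (illustrative only; with_kwargs=True requires PyTorch >= 2.0, and the model/hook names here are hypothetical, not DeepSpeed code):

import torch
import torch.nn as nn

class TinyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x, scale=1.0):
        return self.linear(x) * scale

model = TinyModel()

# Pre-hook: runs before forward(). With with_kwargs=True it receives
# (module, args, kwargs) and may return a modified (args, kwargs) tuple,
# or None to leave the arguments unchanged.
def pre_hook(module, args, kwargs):
    print("prologue: about to run forward")
    return args, kwargs

# Post-hook: runs after forward() and receives (module, args, output).
def post_hook(module, args, output):
    print("epilogue: forward finished")

pre_handle = model.register_forward_pre_hook(pre_hook, prepend=False, with_kwargs=True)
post_handle = model.register_forward_hook(post_hook)

out = model(torch.randn(2, 4), scale=0.5)  # prints the prologue line, then the epilogue line

# The handles can later be detached, which is why the engine keeps them as attributes.
pre_handle.remove()
post_handle.remove()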
@@ -1875,7 +1878,6 @@ def deepspeed_io(self,
                GLOBAL_RANK: self.global_rank,
                DATA_SAMPLING_NUM_WORKERS: self.data_sampling_config()[DATA_SAMPLING_NUM_WORKERS]
            }
-
        return DeepSpeedDataLoader(dataset=dataset,
                                   batch_size=batch_size,
                                   pin_memory=pin_memory,
@@ -1922,17 +1924,24 @@ def _scale_loss_by_gas(self, prescaled_loss, eval_micro_batches=None):

        return scaled_loss

-    @instrument_w_nvtx
-    def forward(self, *inputs, **kwargs):
-        r"""Execute forward propagation
-        Arguments:
-            *inputs: Variable length input list
-            **kwargs: variable length keyword arguments
-        """
+    def _create_module_forward_pre_hook(self):

-        if self.autotuning_profile_model_info():
-            ma = get_ma_status()
-        else:
+        def _module_forward_pre_hook(module, inputs, kwargs):
+            return self._forward_prologue(inputs, kwargs)
+
+        return self.module.register_forward_pre_hook(_module_forward_pre_hook, prepend=False, with_kwargs=True)
+
+    def _create_module_forward_post_hook(self):
+
+        def _module_forward_post_hook(module, input, output):
+            self._forward_epilogue()
+
+        return self.module.register_forward_hook(_module_forward_post_hook)
+
+    def _forward_prologue(self, inputs, kwargs):
+        return_modified = False
+
+        if not self.autotuning_profile_model_info():
            see_memory_usage("Engine before forward", force=self.memory_breakdown())

        flops_profiler_active = (self.flops_profiler_enabled()
@@ -1951,58 +1960,81 @@ def forward(self, *inputs, **kwargs):
                        self.eigenvalue_enabled(),
                        None,
                    )
+                    return_modified = True

        if flops_profiler_active:
            self.flops_profiler.start_profile(ignore_list=None)

-        if self.module.training:
-            if self.progressive_layer_drop:
-                kwargs.update(self.progressive_layer_drop.get_state())
+        if kwargs is not None:
+            if self.module.training:
+                if self.progressive_layer_drop:
+                    kwargs.update(self.progressive_layer_drop.get_state())

-        if self.__class__.__name__ != "PipelineEngine":
-            # TODO: The above if condition is a HACK since for PipelineEngine
-            # it's difficult to inject argument in forward pass.
-            if self.module.training and self.curriculum_enabled_legacy():
-                self.curriculum_scheduler_legacy.update_difficulty(self.global_steps + 1)
-                if self.curriculum_params_legacy()["curriculum_type"] == "seqlen":
-                    kwargs.update({"curriculum_seqlen": self.curriculum_scheduler_legacy.get_current_difficulty()})
+            if self.__class__.__name__ != "PipelineEngine":
+                # TODO: The above if condition is a HACK since for PipelineEngine
+                # it's difficult to inject argument in forward pass.
+                if self.module.training and self.curriculum_enabled_legacy():
+                    self.curriculum_scheduler_legacy.update_difficulty(self.global_steps + 1)
+                    if self.curriculum_params_legacy()["curriculum_type"] == "seqlen":
+                        kwargs.update({"curriculum_seqlen": self.curriculum_scheduler_legacy.get_current_difficulty()})
+                        return_modified = True

        if self.module.training and self.random_ltd_enabled():
            self.random_ltd_scheduler.update_seq(self.global_steps)

+        if self.training_dataloader is None:
+            self.tput_timer.start()
+
+        self._start_timers(self.engine_timers.forward_timers)
+
        if self.zero_optimization_partition_weights():
            # Enable automated discovery of external parameters by indicating that
            # we are in a forward pass.
            for module in self.module.modules():
                module._parameters._in_forward = True

-        self._start_timers(self.engine_timers.forward_timers)
-
-        if self.training_dataloader is None:
-            self.tput_timer.start()
-
        if self.fp16_auto_cast():
            inputs = self._cast_inputs_half(inputs)
+            return_modified = True

-        loss = self.module(*inputs, **kwargs)
+        if return_modified:
+            return inputs, kwargs

+    def _forward_epilogue(self):
        if self.zero_optimization_partition_weights():
            # Disable automated discovery of external parameters
            for module in self.module.modules():
                module._parameters._in_forward = False

        self._stop_timers(self.engine_timers.forward_timers)

+        flops_profiler_active = (self.flops_profiler_enabled()
+                                 and self.global_steps == self.flops_profiler_profile_step() and self.global_rank == 0)
+
        if flops_profiler_active:
            self.flops_profiler.stop_profile()

+        if not self.autotuning_profile_model_info():
+            see_memory_usage("Engine after forward", force=self.memory_breakdown())
+
+    @instrument_w_nvtx
+    def forward(self, *inputs, **kwargs):
+        r"""Execute forward propagation
+        Arguments:
+            *inputs: Variable length input list
+            **kwargs: variable length keyword arguments
+        """
+        if self.autotuning_profile_model_info():
+            ma = get_ma_status()
+
+        loss = self.module(*inputs, **kwargs)
+
        if self.autotuning_profile_model_info():
            activation_mem = get_ma_status() - ma
            self.autotuning_model_info["activation_mem_per_gpu"] = activation_mem
            print_json_dist(self.autotuning_model_info, [0], path=self.autotuning_model_info_path())
            exit()
-        else:
-            see_memory_usage("Engine after forward", force=self.memory_breakdown())
+
        return loss

    def _cast_inputs_half(self, inputs):
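_forward_prologue above only returns (inputs, kwargs) when it actually changed them (the return_modified flag). That matches the with_kwargs=True pre-hook contract: returning None leaves the caller's arguments untouched, while returning a tuple replaces what forward() receives. A small runnable sketch of that contract in plain PyTorch (hook and variable names are illustrative, not DeepSpeed code):

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
x = torch.randn(2, 4)

def passthrough_hook(module, args, kwargs):
    # Returning None keeps the original arguments (the common case,
    # mirroring _forward_prologue when return_modified stays False).
    return None

def rewrite_hook(module, args, kwargs):
    # Returning an (args, kwargs) tuple replaces what forward() receives,
    # mirroring _forward_prologue when it casts inputs or injects kwargs.
    return tuple(2 * a for a in args), kwargs

h1 = model.register_forward_pre_hook(passthrough_hook, with_kwargs=True)
h2 = model.register_forward_pre_hook(rewrite_hook, with_kwargs=True)

y_hooked = model(x)    # rewrite_hook doubled the input before forward() ran
h1.remove()
h2.remove()
y_plain = model(2 * x)  # same computation without hooks

print(torch.allclose(y_hooked, y_plain))  # True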
@@ -2061,43 +2093,13 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
            grads = None
        self.buffered_allreduce_fallback(grads=grads, elements_per_buffer=bucket_size)

-    @contextmanager
-    def no_sync(self):
-        r"""
-        Context manager to disable gradient reduction during backward pass.
-        This context manager has the following effects on other DeepSpeed features.
-        1. Incompatible with ZeRO stage 2/3 which rely on reduction for gradient partitioning.
-        2. It is illegal to call engine.step() within the context manager.
-        3. Tracking of gradient accumulation steps is disabled.
-        """
-        assert not self.zero_optimization_partition_gradients(), \
-            f"no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage {self.zero_optimization_stage()}"
-
-        assert not self.inside_no_sync_ctxt, f"no_sync context manager reentry is unsupported"
-
-        self.inside_no_sync_ctxt = True
-        try:
-            yield
-        finally:
-            self.inside_no_sync_ctxt = False
-
-    @instrument_w_nvtx
-    def backward(self, loss, release_loss=False, retain_graph=False, scale_wrt_gas=True):
-        r"""Execute backward pass on the loss
-        Arguments:
-            loss: Torch tensor on which to execute backward propagation
-            retain_graph: bool, default: false
-                forward on user defined choice of retain_graph
-        """
-
+    def _backward_prologue(self, loss, scale_wrt_gas=True):
        see_memory_usage("Engine before backward", force=self.memory_breakdown())
-
        if self.scale_wrt_gas is not None:
            scale_wrt_gas = self.scale_wrt_gas

-        do_gradient_reduction = self.enable_backward_allreduce and not self.inside_no_sync_ctxt
-
        # scale loss w.r.t. gradient accumulation if reduction is not disabled
+        do_gradient_reduction = self.enable_backward_allreduce and not self.inside_no_sync_ctxt
        if do_gradient_reduction and self.gradient_accumulation_steps() > 1 and scale_wrt_gas:
            loss = self._scale_loss_by_gas(loss.float())

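The prologue above scales the loss before backward when gradient accumulation is active; _scale_loss_by_gas, as used here, amounts to dividing the loss by the number of accumulation steps so that accumulated gradients average rather than sum. A small self-contained sketch of that equivalence in plain PyTorch (assuming straight division; the variable names are illustrative):

import torch

gas = 4                                  # gradient accumulation steps
w = torch.randn(3, requires_grad=True)
data = torch.randn(gas, 3)               # one row per micro-batch

# Accumulate gradients micro-batch by micro-batch, scaling each loss by 1/gas.
for x in data:
    loss = (w * x).sum()
    (loss / gas).backward()
accumulated = w.grad.clone()

# Reference: one backward pass on the mean of the micro-batch losses.
w.grad = None
mean_loss = torch.stack([(w * x).sum() for x in data]).mean()
mean_loss.backward()

print(torch.allclose(accumulated, w.grad))  # True: the two gradients match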
@@ -2114,13 +2116,18 @@ def backward(self, loss, release_loss=False, retain_graph=False, scale_wrt_gas=T
                    )]
                    self.monitor.write_events(self.summary_events)

-        self._start_timers(self.engine_timers.backward_timers)
+        return loss

-        assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \
-            "must provide optimizer during init in order to use backward"
+    def _backward_epilogue(self):
+        self._start_timers(self.engine_timers.backward_reduce_timers)
+        if self.enable_backward_allreduce and not self.inside_no_sync_ctxt:
+            # Traditional code path that allreduces the module parameter grads
+            self.allreduce_gradients()
+        self._stop_timers(self.engine_timers.backward_reduce_timers)
+        see_memory_usage("Engine after backward", force=self.memory_breakdown())

+    def _do_optimizer_backward(self, loss, retain_graph):
        self._start_timers(self.engine_timers.backward_inner_timers)
-
        if self.zero_optimization():
            self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary()
            self.optimizer.backward(loss, retain_graph=retain_graph)
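The split keeps the original timer structure: backward_timers bracket the whole pass in backward(), while backward_inner_timers (here in _do_optimizer_backward) and backward_reduce_timers (in _backward_epilogue) time the autograd work and the gradient allreduce separately. A toy illustration of that nesting with ordinary context-manager timers (names and sleeps are placeholders, not the DeepSpeed timer API):

import time
from contextlib import contextmanager

timings = {}

@contextmanager
def timer(name):
    start = time.perf_counter()
    try:
        yield
    finally:
        timings[name] = timings.get(name, 0.0) + time.perf_counter() - start

def do_optimizer_backward():
    with timer("backward_inner"):    # autograd / optimizer backward work
        time.sleep(0.01)

def backward_epilogue():
    with timer("backward_reduce"):   # gradient allreduce
        time.sleep(0.005)

def backward():
    with timer("backward"):          # outer bracket, as in the refactored backward()
        do_optimizer_backward()
        backward_epilogue()

backward()
print(timings)  # "backward" covers both inner phases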
@@ -2136,30 +2143,50 @@ def backward(self, loss, release_loss=False, retain_graph=False, scale_wrt_gas=T
            else:
                self.optimizer.backward(loss, retain_graph=retain_graph)
        elif self.bfloat16_enabled():
-            self.optimizer.backward(loss)
+            self.optimizer.backward(loss, retain_graph=retain_graph)
        else:
            if self.eigenvalue_enabled():
                loss.backward(create_graph=True, retain_graph=True)
            else:
                loss.backward(retain_graph=retain_graph)
-
        self._stop_timers(self.engine_timers.backward_inner_timers)

-        self._start_timers(self.engine_timers.backward_reduce_timers)
-
-        if do_gradient_reduction:
-            # Traditional code path that allreduces the module parameter grads
-            self.allreduce_gradients()
+    @contextmanager
+    def no_sync(self):
+        r"""
+        Context manager to disable gradient reduction during backward pass.
+        This context manager has the following effects on other DeepSpeed features:
+        1. Incompatible with ZeRO stage 2/3 which rely on reduction for gradient partitioning.
+        2. It is illegal to call engine.step() within the context manager.
+        3. Tracking of gradient accumulation steps is disabled.
+        """
+        assert not self.zero_optimization_partition_gradients(), \
+            f"no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage {self.zero_optimization_stage()}"

-        self._stop_timers(self.engine_timers.backward_reduce_timers)
+        assert not self.inside_no_sync_ctxt, f"no_sync context manager reentry is unsupported"

-        self._stop_timers(self.engine_timers.backward_timers)
+        self.inside_no_sync_ctxt = True
+        try:
+            yield
+        finally:
+            self.inside_no_sync_ctxt = False

-        if release_loss:
-            # loss.data = None
-            pass
+    @instrument_w_nvtx
+    def backward(self, loss, retain_graph=False, scale_wrt_gas=True):
+        r"""Execute backward pass on the loss
+        Arguments:
+            loss: Torch tensor on which to execute backward propagation
+            retain_graph: bool, default: false
+                forward on user defined choice of retain_graph
+        """
+        assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \
+            "must provide optimizer during init in order to use backward"

-        see_memory_usage("Engine after backward", force=self.memory_breakdown())
+        self._start_timers(self.engine_timers.backward_timers)
+        loss = self._backward_prologue(loss, scale_wrt_gas)
+        self._do_optimizer_backward(loss, retain_graph)
+        self._backward_epilogue()
+        self._stop_timers(self.engine_timers.backward_timers)

        return loss

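With this refactor, backward() reduces to: bracket the outer timer, run _backward_prologue (loss scaling and monitoring), _do_optimizer_backward, then _backward_epilogue (allreduce unless inside no_sync()). A hedged end-to-end sketch of driving gradient accumulation with the new no_sync() context manager (the model, config values, and loop are placeholders; run under the deepspeed launcher, and note that no_sync() rejects ZeRO stage 2/3):

import torch
import torch.nn as nn
import deepspeed

# Accumulation is driven manually via no_sync() here, so the config keeps
# gradient_accumulation_steps at its default of 1 and ZeRO at stage 0.
ds_config = {
    "train_micro_batch_size_per_gpu": 2,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
}

model = nn.Linear(8, 1)
engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

accum_steps = 4
data = [(torch.randn(2, 8), torch.randn(2, 1)) for _ in range(2 * accum_steps)]

for step, (x, y) in enumerate(data):
    x, y = x.to(engine.device), y.to(engine.device)
    loss = nn.functional.mse_loss(engine(x), y) / accum_steps  # forward runs via the registered hooks

    if (step + 1) % accum_steps == 0:
        engine.backward(loss)    # _backward_epilogue allreduces the accumulated grads
        engine.step()            # step() must stay outside no_sync()
    else:
        with engine.no_sync():   # skip gradient reduction on non-boundary micro-batches
            engine.backward(loss)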