@@ -273,6 +273,9 @@ def __init__(self,
         # Configure distributed model
         self._configure_distributed_model(model)

+        self.module_forward_pre_hook = self._create_module_forward_pre_hook()
+        self.module_forward_post_hook = self._create_module_forward_post_hook()
+
         # needed for zero_to_fp32 weights reconstruction to remap nameless data to state_dict
         self.param_names = {param: name for name, param in model.named_parameters()}

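For readers unfamiliar with the hook mechanism these two attributes rely on, here is a minimal, self-contained sketch (plain PyTorch 2.x, not the DeepSpeed engine; the `Toy` module and hook names are illustrative) of how `register_forward_pre_hook(..., with_kwargs=True)` and `register_forward_hook(...)` wrap a module's forward call:

```python
import torch

class Toy(torch.nn.Module):
    def forward(self, x, scale=1.0):
        return x * scale

m = Toy()

# With with_kwargs=True the pre-hook receives (module, args, kwargs);
# returning an (args, kwargs) tuple replaces the inputs, returning None keeps them.
def pre_hook(module, args, kwargs):
    kwargs["scale"] = 2.0
    return args, kwargs

# The post-hook receives (module, args, output); returning None leaves the output unchanged.
def post_hook(module, args, output):
    print("forward produced:", output)

pre_handle = m.register_forward_pre_hook(pre_hook, prepend=False, with_kwargs=True)
post_handle = m.register_forward_hook(post_hook)

m(torch.ones(2))      # pre_hook runs, forward runs with scale=2.0, then post_hook runs
pre_handle.remove()   # hook handles can be removed when no longer needed
post_handle.remove()
```

This mirrors the contract used below: the pre-hook may hand back modified `(inputs, kwargs)`, while the post-hook only runs cleanup and leaves the output untouched.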
@@ -1889,7 +1892,6 @@ def deepspeed_io(self,
                 GLOBAL_RANK: self.global_rank,
                 DATA_SAMPLING_NUM_WORKERS: self.data_sampling_config()[DATA_SAMPLING_NUM_WORKERS]
             }
-
         return DeepSpeedDataLoader(dataset=dataset,
                                    batch_size=batch_size,
                                    pin_memory=pin_memory,
@@ -1936,17 +1938,24 @@ def _scale_loss_by_gas(self, prescaled_loss, eval_micro_batches=None):

         return scaled_loss

-    @instrument_w_nvtx
-    def forward(self, *inputs, **kwargs):
-        r"""Execute forward propagation
-        Arguments:
-            *inputs: Variable length input list
-            **kwargs: variable length keyword arguments
-        """
+    def _create_module_forward_pre_hook(self):

-        if self.autotuning_profile_model_info():
-            ma = get_ma_status()
-        else:
+        def _module_forward_pre_hook(module, inputs, kwargs):
+            return self._forward_prologue(inputs, kwargs)
+
+        return self.module.register_forward_pre_hook(_module_forward_pre_hook, prepend=False, with_kwargs=True)
+
+    def _create_module_forward_post_hook(self):
+
+        def _module_forward_post_hook(module, input, output):
+            self._forward_epilogue()
+
+        return self.module.register_forward_hook(_module_forward_post_hook)
+
+    def _forward_prologue(self, inputs, kwargs):
+        return_modified = False
+
+        if not self.autotuning_profile_model_info():
             see_memory_usage("Engine before forward", force=self.memory_breakdown())

         flops_profiler_active = (self.flops_profiler_enabled()
@@ -1965,61 +1974,84 @@ def forward(self, *inputs, **kwargs):
                         self.eigenvalue_enabled(),
                         None,
                     )
+                    return_modified = True

         if flops_profiler_active:
             self.flops_profiler.start_profile(ignore_list=None)

-        if self.module.training:
-            if self.progressive_layer_drop:
-                kwargs.update(self.progressive_layer_drop.get_state())
+        if kwargs is not None:
+            if self.module.training:
+                if self.progressive_layer_drop:
+                    kwargs.update(self.progressive_layer_drop.get_state())

-        if self.__class__.__name__ != "PipelineEngine":
-            # TODO: The above if condition is a HACK since for PipelineEngine
-            # it's difficult to inject argument in forward pass.
-            if self.module.training and self.curriculum_enabled_legacy():
-                self.curriculum_scheduler_legacy.update_difficulty(self.global_steps + 1)
-                if self.curriculum_params_legacy()["curriculum_type"] == "seqlen":
-                    kwargs.update({"curriculum_seqlen": self.curriculum_scheduler_legacy.get_current_difficulty()})
+            if self.__class__.__name__ != "PipelineEngine":
+                # TODO: The above if condition is a HACK since for PipelineEngine
+                # it's difficult to inject argument in forward pass.
+                if self.module.training and self.curriculum_enabled_legacy():
+                    self.curriculum_scheduler_legacy.update_difficulty(self.global_steps + 1)
+                    if self.curriculum_params_legacy()["curriculum_type"] == "seqlen":
+                        kwargs.update({"curriculum_seqlen": self.curriculum_scheduler_legacy.get_current_difficulty()})
+                    return_modified = True

         if self.module.training and self.random_ltd_enabled():
             self.random_ltd_scheduler.update_seq(self.global_steps)

+        if self.training_dataloader is None:
+            self.tput_timer.start()
+
+        self._start_timers(self.engine_timers.forward_timers)
+
         if self.zero_optimization_partition_weights():
             # Enable automated discovery of external parameters by indicating that
             # we are in a forward pass.
             for module in self.module.modules():
                 module._parameters._in_forward = True

-        self._start_timers(self.engine_timers.forward_timers)
-
-        if self.training_dataloader is None:
-            self.tput_timer.start()
-
         if self.fp16_auto_cast():
             inputs = self._cast_inputs_half(inputs)
+            return_modified = True

-        with torch.autocast(device_type=get_accelerator().device_name(),
-                            dtype=self.torch_autocast_dtype(),
-                            enabled=self.torch_autocast_enabled()):
-            loss = self.module(*inputs, **kwargs)
+        if return_modified:
+            return inputs, kwargs

+    def _forward_epilogue(self):
         if self.zero_optimization_partition_weights():
             # Disable automated discovery of external parameters
             for module in self.module.modules():
                 module._parameters._in_forward = False

         self._stop_timers(self.engine_timers.forward_timers)

+        flops_profiler_active = (self.flops_profiler_enabled()
+                                 and self.global_steps == self.flops_profiler_profile_step() and self.global_rank == 0)
+
         if flops_profiler_active:
             self.flops_profiler.stop_profile()

+        if not self.autotuning_profile_model_info():
+            see_memory_usage("Engine after forward", force=self.memory_breakdown())
+
+    @instrument_w_nvtx
+    def forward(self, *inputs, **kwargs):
+        r"""Execute forward propagation
+        Arguments:
+            *inputs: Variable length input list
+            **kwargs: variable length keyword arguments
+        """
+        if self.autotuning_profile_model_info():
+            ma = get_ma_status()
+
+        with torch.autocast(device_type=get_accelerator().device_name(),
+                            dtype=self.torch_autocast_dtype(),
+                            enabled=self.torch_autocast_enabled()):
+            loss = self.module(*inputs, **kwargs)
+
         if self.autotuning_profile_model_info():
             activation_mem = get_ma_status() - ma
             self.autotuning_model_info["activation_mem_per_gpu"] = activation_mem
             print_json_dist(self.autotuning_model_info, [0], path=self.autotuning_model_info_path())
             exit()
-        else:
-            see_memory_usage("Engine after forward", force=self.memory_breakdown())
+
         return loss

     def _cast_inputs_half(self, inputs):
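As a point of reference, a minimal sketch (plain PyTorch, independent of the engine; the device and dtype choices are assumptions standing in for `get_accelerator().device_name()` and `torch_autocast_dtype()`) of the `torch.autocast` wrapper that the relocated `forward()` places around `self.module(*inputs, **kwargs)`:

```python
import torch

device_type = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device_type == "cuda" else torch.bfloat16

model = torch.nn.Linear(8, 4).to(device_type)
x = torch.randn(2, 8, device=device_type)

# enabled=False turns the context into a no-op, so one code path serves both
# autocast and full-precision execution, as in the engine's forward().
with torch.autocast(device_type=device_type, dtype=dtype, enabled=True):
    out = model(x)

print(out.dtype)  # low-precision dtype for autocast-eligible ops such as Linear
```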
@@ -2078,43 +2110,13 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
                 grads = None
                 self.buffered_allreduce_fallback(grads=grads, elements_per_buffer=bucket_size)

-    @contextmanager
-    def no_sync(self):
-        r"""
-        Context manager to disable gradient reduction during backward pass.
-        This context manager has the following effects on other DeepSpeed features.
-        1. Incompatible with ZeRO stage 2/3 which rely on reduction for gradient partitioning.
-        2. It is illegal to call engine.step() within the context manager.
-        3. Tracking of gradient accumulation steps is disabled.
-        """
-        assert not self.zero_optimization_partition_gradients(), \
-            f"no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage {self.zero_optimization_stage()}"
-
-        assert not self.inside_no_sync_ctxt, f"no_sync context manager reentry is unsupported"
-
-        self.inside_no_sync_ctxt = True
-        try:
-            yield
-        finally:
-            self.inside_no_sync_ctxt = False
-
-    @instrument_w_nvtx
-    def backward(self, loss, release_loss=False, retain_graph=False, scale_wrt_gas=True):
-        r"""Execute backward pass on the loss
-        Arguments:
-            loss: Torch tensor on which to execute backward propagation
-            retain_graph: bool, default: false
-                forward on user defined choice of retain_graph
-        """
-
+    def _backward_prologue(self, loss, scale_wrt_gas=True):
         see_memory_usage("Engine before backward", force=self.memory_breakdown())
-
         if self.scale_wrt_gas is not None:
             scale_wrt_gas = self.scale_wrt_gas

-        do_gradient_reduction = self.enable_backward_allreduce and not self.inside_no_sync_ctxt
-
         # scale loss w.r.t. gradient accumulation if reduction is not disabled
+        do_gradient_reduction = self.enable_backward_allreduce and not self.inside_no_sync_ctxt
         if do_gradient_reduction and self.gradient_accumulation_steps() > 1 and scale_wrt_gas:
             loss = self._scale_loss_by_gas(loss.float())

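To illustrate why `_backward_prologue` scales the loss by the gradient accumulation steps, here is a hedged plain-PyTorch sketch (the model, optimizer, and step count are placeholders, not engine code): summing the gradients of the scaled micro-batch losses approximates the gradient of the mean loss over the effective batch.

```python
import torch

model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
gas = 4  # gradient accumulation steps (stand-in for gradient_accumulation_steps())

for step, micro_batch in enumerate(torch.randn(8, 4).split(2)):
    loss = model(micro_batch).pow(2).mean()
    (loss / gas).backward()        # same scaling idea as _scale_loss_by_gas
    if (step + 1) % gas == 0:      # gradient accumulation boundary
        opt.step()
        opt.zero_grad()
```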
@@ -2131,13 +2133,18 @@ def backward(self, loss, release_loss=False, retain_graph=False, scale_wrt_gas=T
                     )]
                    self.monitor.write_events(self.summary_events)

-        self._start_timers(self.engine_timers.backward_timers)
+        return loss

-        assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \
-            "must provide optimizer during init in order to use backward"
+    def _backward_epilogue(self):
+        self._start_timers(self.engine_timers.backward_reduce_timers)
+        if self.enable_backward_allreduce and not self.inside_no_sync_ctxt:
+            # Traditional code path that allreduces the module parameter grads
+            self.allreduce_gradients()
+        self._stop_timers(self.engine_timers.backward_reduce_timers)
+        see_memory_usage("Engine after backward", force=self.memory_breakdown())

+    def _do_optimizer_backward(self, loss, retain_graph):
         self._start_timers(self.engine_timers.backward_inner_timers)
-
         if self.zero_optimization():
             self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary()
             self.optimizer.backward(loss, retain_graph=retain_graph)
@@ -2153,30 +2160,50 @@ def backward(self, loss, release_loss=False, retain_graph=False, scale_wrt_gas=T
             else:
                 self.optimizer.backward(loss, retain_graph=retain_graph)
         elif self.bfloat16_enabled():
-            self.optimizer.backward(loss)
+            self.optimizer.backward(loss, retain_graph=retain_graph)
         else:
             if self.eigenvalue_enabled():
                 loss.backward(create_graph=True, retain_graph=True)
             else:
                 loss.backward(retain_graph=retain_graph)
-
         self._stop_timers(self.engine_timers.backward_inner_timers)

-        self._start_timers(self.engine_timers.backward_reduce_timers)
-
-        if do_gradient_reduction:
-            # Traditional code path that allreduces the module parameter grads
-            self.allreduce_gradients()
+    @contextmanager
+    def no_sync(self):
+        r"""
+        Context manager to disable gradient reduction during backward pass.
+        This context manager has the following effects on other DeepSpeed features:
+        1. Incompatible with ZeRO stage 2/3 which rely on reduction for gradient partitioning.
+        2. It is illegal to call engine.step() within the context manager.
+        3. Tracking of gradient accumulation steps is disabled.
+        """
+        assert not self.zero_optimization_partition_gradients(), \
+            f"no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage {self.zero_optimization_stage()}"

-        self._stop_timers(self.engine_timers.backward_reduce_timers)
+        assert not self.inside_no_sync_ctxt, f"no_sync context manager reentry is unsupported"

-        self._stop_timers(self.engine_timers.backward_timers)
+        self.inside_no_sync_ctxt = True
+        try:
+            yield
+        finally:
+            self.inside_no_sync_ctxt = False

-        if release_loss:
-            # loss.data = None
-            pass
+    @instrument_w_nvtx
+    def backward(self, loss, retain_graph=False, scale_wrt_gas=True):
+        r"""Execute backward pass on the loss
+        Arguments:
+            loss: Torch tensor on which to execute backward propagation
+            retain_graph: bool, default: false
+                forward on user defined choice of retain_graph
+        """
+        assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \
+            "must provide optimizer during init in order to use backward"

-        see_memory_usage("Engine after backward", force=self.memory_breakdown())
+        self._start_timers(self.engine_timers.backward_timers)
+        loss = self._backward_prologue(loss, scale_wrt_gas)
+        self._do_optimizer_backward(loss, retain_graph)
+        self._backward_epilogue()
+        self._stop_timers(self.engine_timers.backward_timers)

         return loss

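Finally, a hedged usage sketch of the relocated `no_sync()` together with the refactored `backward()`; `engine` and `data_loader` are placeholders for objects returned by `deepspeed.initialize()` and a user-supplied data loader, and the loop structure is an assumption rather than engine code:

```python
# Skip gradient allreduce on non-boundary micro-batches; only valid when
# gradients are not partitioned (i.e. not ZeRO stage 2/3).
for step, batch in enumerate(data_loader):
    if (step + 1) % engine.gradient_accumulation_steps() != 0:
        with engine.no_sync():
            loss = engine(batch)       # forward: pre-hook prologue, module, post-hook epilogue
            engine.backward(loss)      # _backward_prologue -> optimizer backward -> _backward_epilogue
    else:
        loss = engine(batch)
        engine.backward(loss)          # gradients reduced here
        engine.step()                  # engine.step() must stay outside no_sync()
```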