File tree Expand file tree Collapse file tree 2 files changed +5
-3
lines changed
Expand file tree Collapse file tree 2 files changed +5
-3
lines changed Original file line number Diff line number Diff line change @@ -304,8 +304,9 @@ def backward_dw(self):
304304
305305 # the output grad memory is last used in wgrad compute, should be safe to release.
306306 assert self .delay_grads_release , "output grad memory should be valid before wgrad."
307- for tensor in self .output_grads :
308- tensor .untyped_storage ().resize_ (0 )
307+ if self .manual_release_grads :
308+ for tensor in self .output_grads :
309+ tensor .untyped_storage ().resize_ (0 )
309310 self .output_grads = None
310311
311312 self .bwd_dw_callables = None
Original file line number Diff line number Diff line change @@ -183,6 +183,7 @@ def __init__(
183183 self .inputs = None
184184 self .outputs = None
185185 self .delay_grads_release = False
186+ self .manual_release_grads = False
186187
187188 def default_backward_func (self , outputs , output_grad ):
188189 """Default backward function"""
@@ -268,7 +269,7 @@ def _backward(self, *output_grad):
268269 # to avoid delayed garbage collection. If
269270 # delay_grads_release is True, dgrad is last used in
270271 # wgrad compute and skip the release here.
271- if not self .delay_grads_release :
272+ if self . manual_release_grads and not self .delay_grads_release :
272273 g .untyped_storage ().resize_ (0 )
273274
274275 grads = self .get_grad ()
You can’t perform that action at this time.
0 commit comments