Commit 9885ddb

[Dev] Disable ep overlap memory optimization (NVIDIA#2750)

1 parent 1068d77 commit 9885ddb

File tree

2 files changed: +5 -3 lines

megatron/core/models/gpt/fine_grained_callables.py

Lines changed: 3 additions & 2 deletions

@@ -304,8 +304,9 @@ def backward_dw(self):

         # the output grad memory is last used in wgrad compute, should be safe to release.
         assert self.delay_grads_release, "output grad memory should be valid before wgrad."
-        for tensor in self.output_grads:
-            tensor.untyped_storage().resize_(0)
+        if self.manual_release_grads:
+            for tensor in self.output_grads:
+                tensor.untyped_storage().resize_(0)
         self.output_grads = None

         self.bwd_dw_callables = None
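
Both hunks lean on the same PyTorch mechanism: calling resize_(0) on a tensor's untyped storage frees the backing memory immediately, instead of waiting for the last Python reference to be garbage-collected. A minimal standalone sketch, not from the commit, using a small CPU tensor in place of the gradient tensors:

    import torch

    # A 1024x1024 float32 tensor owns 4 MiB of backing storage.
    t = torch.randn(1024, 1024)
    print(t.untyped_storage().size())  # 4194304 bytes

    # Shrinking the untyped storage to zero releases that memory now.
    # The tensor object stays alive but its data is gone, so it must
    # not be read afterwards, which is why the commit gates this
    # release behind an explicit flag.
    t.untyped_storage().resize_(0)
    print(t.untyped_storage().size())  # 0 bytes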

megatron/core/pipeline_parallel/utils.py

Lines changed: 2 additions & 1 deletion

@@ -183,6 +183,7 @@ def __init__(
         self.inputs = None
         self.outputs = None
         self.delay_grads_release = False
+        self.manual_release_grads = False

     def default_backward_func(self, outputs, output_grad):
         """Default backward function"""

@@ -268,7 +269,7 @@ def _backward(self, *output_grad):
         # to avoid delayed garbage collection. If
         # delay_grads_release is True, dgrad is last used in
         # wgrad compute and skip the release here.
-        if not self.delay_grads_release:
+        if self.manual_release_grads and not self.delay_grads_release:
             g.untyped_storage().resize_(0)

         grads = self.get_grad()
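
To see the resulting control flow in one place, here is a hypothetical condensed sketch of the release policy after this commit. The class GradReleasePolicy and method maybe_release are illustrative names, not the Megatron-LM API; only the guard expression mirrors the diff:

    import torch

    class GradReleasePolicy:
        """Illustrative only: condenses the two flags touched by this commit."""

        def __init__(self, manual_release_grads=False, delay_grads_release=False):
            # Both default to False, matching __init__ in utils.py after
            # the change, so eager storage release is now opt-in.
            self.manual_release_grads = manual_release_grads
            self.delay_grads_release = delay_grads_release

        def maybe_release(self, g: torch.Tensor) -> None:
            # Mirrors the guard added in _backward: release dgrad storage
            # eagerly only in manual mode, and only when wgrad compute
            # does not still need the gradient (delay_grads_release unset).
            if self.manual_release_grads and not self.delay_grads_release:
                g.untyped_storage().resize_(0)

With both flags at their new defaults, maybe_release is a no-op, which is how the commit disables the ep overlap memory optimization on the default path while keeping it available behind manual_release_grads.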
