Skip to content

Commit 0be7f02

Browse files
committed
fix bug in recompute.py
1 parent 571341b commit 0be7f02

File tree

2 files changed

+491
-1
lines changed

2 files changed

+491
-1
lines changed

python/paddle/distributed/fleet/recompute/recompute.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import contextlib
1818
import copy
19+
import ctypes
1920
import inspect
2021
import random
2122
import weakref
@@ -223,6 +224,22 @@ def switch_rng_state_tracker(
223224
custom_set_state_func(orig_custom_state)
224225

225226

227+
def _restore_freed_closure_tensors(ctx):
    """Re-attach protected storages to closure cells whose tensors were freed.

    Counterpart of the closure-protection bookkeeping saved on ``ctx``
    (``closure_cells`` / ``closure_protected``) during forward: for every
    recorded cell whose current contents is a ``core.eager.Tensor`` that is
    no longer initialized (its data pointer was released, e.g. by
    pipeline-parallel ``_release_input`` / ``_release_output``), rebind the
    cell to the protected shared copy via the CPython ``PyCell_Set`` C API
    so recomputation in backward sees valid tensor data.

    Args:
        ctx: PyLayer context carrying parallel lists ``closure_cells`` and
            ``closure_protected`` (``None`` entries mark slots that never
            held a protected tensor).
    """
    # Cell objects are immutable from pure Python; use the C API to rebind.
    cell_set = ctypes.pythonapi.PyCell_Set
    cell_set.argtypes = [ctypes.py_object, ctypes.py_object]
    cell_set.restype = ctypes.c_int

    for cell, backup in zip(ctx.closure_cells, ctx.closure_protected):
        if cell is None or backup is None:
            # Slot was not a tensor at forward time; nothing to restore.
            continue
        try:
            current = cell.cell_contents
        except ValueError:
            # Cell has been emptied entirely; leave it alone.
            continue
        needs_restore = isinstance(
            current, core.eager.Tensor
        ) and not current._is_initialized()
        if needs_restore:
            cell_set(cell, backup)
241+
242+
226243
class RecomputeFunction(PyLayer):
227244
@staticmethod
228245
def forward(
@@ -243,6 +260,32 @@ def forward(
243260
ctx.offload_indices = offload_indices
244261
ctx.kwargs = kwargs
245262

263+
# Protect tensor-type closure variables of run_function against
264+
# pipeline-parallel _release_input/_release_output calling _clear_dataptr().
265+
# Explicit args are already protected by _protect_tensors(); here we cover
266+
# any tensors captured in the function's __closure__ (e.g. grid_thw).
267+
ctx.closure_cells = []
268+
ctx.closure_protected = []
269+
fn = (
270+
run_function.forward
271+
if isinstance(run_function, paddle.nn.Layer)
272+
else run_function
273+
)
274+
if hasattr(fn, '__closure__') and fn.__closure__:
275+
for cell in fn.__closure__:
276+
try:
277+
val = cell.cell_contents
278+
except ValueError: # empty cell
279+
ctx.closure_cells.append(None)
280+
ctx.closure_protected.append(None)
281+
continue
282+
if isinstance(val, core.eager.Tensor):
283+
ctx.closure_cells.append(cell)
284+
ctx.closure_protected.append(val._new_shared_tensor())
285+
else:
286+
ctx.closure_cells.append(None)
287+
ctx.closure_protected.append(None)
288+
246289
# NOTE the number of outputs of backward() should be equal to the number of tensors in forward()'s input
247290
# the order of tensors in backward()'s output should be the same as tensors in forward()'s input
248291
# None tensor inputs will be filtered in backward inputs.
@@ -385,6 +428,7 @@ def backward(ctx, *args):
385428
dtype=ctx.amp_dtype,
386429
),
387430
):
431+
_restore_freed_closure_tensors(ctx)
388432
detached_inputs = detach_variable(tuple(inputs))
389433
outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)
390434
else:
@@ -395,6 +439,7 @@ def backward(ctx, *args):
395439
level=ctx.amp_level,
396440
dtype=ctx.amp_dtype,
397441
):
442+
_restore_freed_closure_tensors(ctx)
398443
detached_inputs = detach_variable(tuple(inputs))
399444
outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)
400445

0 commit comments

Comments
 (0)