1616
1717import contextlib
1818import copy
19+ import ctypes
1920import inspect
2021import random
2122import weakref
@@ -223,6 +224,22 @@ def switch_rng_state_tracker(
223224 custom_set_state_func (orig_custom_state )
224225
225226
227+ def _restore_freed_closure_tensors (ctx ):
228+ """..."""
229+ _PyCell_Set = ctypes .pythonapi .PyCell_Set
230+ _PyCell_Set .argtypes = [ctypes .py_object , ctypes .py_object ]
231+ _PyCell_Set .restype = ctypes .c_int
232+ for cell , protected in zip (ctx .closure_cells , ctx .closure_protected ):
233+ if cell is None or protected is None :
234+ continue
235+ try :
236+ val = cell .cell_contents
237+ except ValueError :
238+ continue
239+ if isinstance (val , core .eager .Tensor ) and not val ._is_initialized ():
240+ _PyCell_Set (cell , protected )
241+
242+
226243class RecomputeFunction (PyLayer ):
227244 @staticmethod
228245 def forward (
@@ -243,6 +260,32 @@ def forward(
243260 ctx .offload_indices = offload_indices
244261 ctx .kwargs = kwargs
245262
263+ # Protect tensor-type closure variables of run_function against
264+ # pipeline-parallel _release_input/_release_output calling _clear_dataptr().
265+ # Explicit args are already protected by _protect_tensors(); here we cover
266+ # any tensors captured in the function's __closure__ (e.g. grid_thw).
267+ ctx .closure_cells = []
268+ ctx .closure_protected = []
269+ fn = (
270+ run_function .forward
271+ if isinstance (run_function , paddle .nn .Layer )
272+ else run_function
273+ )
274+ if hasattr (fn , '__closure__' ) and fn .__closure__ :
275+ for cell in fn .__closure__ :
276+ try :
277+ val = cell .cell_contents
278+ except ValueError : # empty cell
279+ ctx .closure_cells .append (None )
280+ ctx .closure_protected .append (None )
281+ continue
282+ if isinstance (val , core .eager .Tensor ):
283+ ctx .closure_cells .append (cell )
284+ ctx .closure_protected .append (val ._new_shared_tensor ())
285+ else :
286+ ctx .closure_cells .append (None )
287+ ctx .closure_protected .append (None )
288+
246289 # NOTE the number of outputs of backward() should be equal to the number of tensors in forward()'s input
247290 # the order of tensors in backward()'s output should be the same as tensors in forward()'s input
248291 # None tensor inputs will be filtered in backward inputs.
@@ -385,6 +428,7 @@ def backward(ctx, *args):
385428 dtype = ctx .amp_dtype ,
386429 ),
387430 ):
431+ _restore_freed_closure_tensors (ctx )
388432 detached_inputs = detach_variable (tuple (inputs ))
389433 outputs = ctx .run_function (* detached_inputs , ** ctx .kwargs )
390434 else :
@@ -395,6 +439,7 @@ def backward(ctx, *args):
395439 level = ctx .amp_level ,
396440 dtype = ctx .amp_dtype ,
397441 ):
442+ _restore_freed_closure_tensors (ctx )
398443 detached_inputs = detach_variable (tuple (inputs ))
399444 outputs = ctx .run_function (* detached_inputs , ** ctx .kwargs )
400445
0 commit comments