@@ -96,6 +96,109 @@ def _get_align_mode_scale():
9696 )
9797
9898
99+ def _can_free (t ):
100+ """
101+ Check if a tensor can be freed.
102+
103+ A tensor can be freed only if all of the following conditions are met:
104+ 1. Tensor is not None
105+ 2. Is a paddle.Tensor type
106+ 3. Has been initialized
107+ 4. inplace_version is 0 (not using in-place ops) or explicitly marked as freeable
108+
109+ Args:
110+ t: The tensor to check
111+
112+ Returns:
113+ bool: True if the tensor can be freed, False otherwise
114+ """
115+ return (
116+ t is not None
117+ and isinstance (t , paddle .Tensor )
118+ and t ._is_initialized ()
119+ and (t .inplace_version == 0 or getattr (t , "pp_can_free" , False ))
120+ )
121+
122+
123+ def _collect_all_tensors (obj , tensor_set ):
124+ """
125+ Recursively collect all tensors from a complex object.
126+
127+ This function traverses nested data structures (tuple, list, dict) and finds
128+ all paddle.Tensor instances, adding them to the tensor_set. Used in Pipeline
129+ Parallel to identify all tensors that need to be managed.
130+
131+ Args:
132+ obj: Any complex object that may contain nested tuple, list, dict and paddle.Tensor
133+ tensor_set: A set to store the collected tensors
134+ """
135+ visited = set ()
136+ stack = [obj ]
137+
138+ while stack :
139+ current = stack .pop ()
140+ obj_id = id (current )
141+ if obj_id in visited :
142+ continue
143+ visited .add (obj_id )
144+
145+ if isinstance (current , (tuple , list )):
146+ stack .extend (current )
147+ elif isinstance (current , dict ):
148+ stack .extend (current .values ())
149+ elif isinstance (current , paddle .Tensor ):
150+ # Check for duplicate addition
151+ if current in tensor_set :
152+ logger .debug (f"Duplicate tensor detected: { current } " )
153+ tensor_set .add (current )
154+
155+
def _release_output(output):
    """
    Clear the data pointers of all releasable tensors inside ``output``.

    Gathers every tensor reachable from ``output`` and frees the storage of
    those satisfying ``_can_free``. Used by Pipeline Parallel after forward
    propagation on non-last stages, where the output has already been sent
    downstream and its local copy is dead weight.

    Args:
        output: A tensor, or a nested tuple/list/dict containing tensors.
    """
    candidates = set()
    _collect_all_tensors(output, candidates)
    for tensor in candidates:
        if not _can_free(tensor):
            continue
        tensor._clear_dataptr()
172+
173+
def _release_input(input, output):
    """
    Clear the data pointers of input tensors that do not reappear in output.

    In Pipeline Parallel an input tensor that also shows up in the output
    (e.g. a residual connection) must survive for downstream computation, so
    only tensors present in ``input`` but absent from ``output`` — and that
    additionally satisfy ``_can_free`` — have their storage released.

    Args:
        input: A tensor, or nested tuple/list/dict containing the step inputs.
        output: The step outputs; tensors found here are never freed.
    """
    kept = set()
    _collect_all_tensors(output, kept)

    consumed = set()
    _collect_all_tensors(input, consumed)

    # Set difference keeps any tensor shared with the output alive.
    for tensor in consumed - kept:
        if _can_free(tensor):
            tensor._clear_dataptr()
200+
201+
99202# assume only the first stage and last stage need data, and data consumption is ordered
100203# to be replaced by real micro dataset from reader
101204class FakeMicroDataset :
@@ -1126,7 +1229,7 @@ def forward_backward_pipeline(
11261229 output_buffers .append (output_tensor_tuple )
11271230
11281231 if not self .is_pipeline_last_stage ():
1129- self . _release_output (output_tensor_tuple )
1232+ _release_output (output_tensor_tuple )
11301233
11311234 if steady_steps > 0 and not static_scheduler :
11321235 input_tensor = self ._p2p_helper .recv_forward (
@@ -1175,7 +1278,7 @@ def forward_backward_pipeline(
11751278 output_buffers .append (output_tensor_tuple )
11761279
11771280 if not self .is_pipeline_last_stage ():
1178- self . _release_output (output_tensor_tuple )
1281+ _release_output (output_tensor_tuple )
11791282
11801283 input_tensor , output_tensor = (
11811284 input_buffers .pop (0 ),
@@ -1426,7 +1529,7 @@ def eval_batch(
14261529 batch_p2p_comm = self ._use_batch_p2p_comm ,
14271530 )
14281531 if not self .is_pipeline_last_stage ():
1429- self . _release_output (output_tensor_tuple )
1532+ _release_output (output_tensor_tuple )
14301533 else :
14311534 self ._offload_tensors (output_tensor_tuple )
14321535
@@ -1456,7 +1559,7 @@ def eval_batch(
14561559 batch_p2p_comm = self ._use_batch_p2p_comm ,
14571560 )
14581561 if not self .is_pipeline_last_stage ():
1459- self . _release_output (output_tensor_tuple )
1562+ _release_output (output_tensor_tuple )
14601563 else :
14611564 self ._offload_tensors (output_tensor_tuple )
14621565
@@ -1567,6 +1670,7 @@ def _forward_step(
15671670 # Only increase micro batch id at virtual first/last pp stage.
15681671 # The micro batch id is used to load data, therefore, only increase it when load data.
15691672 self .micro_batch_id += 1
1673+ _release_input (input_tensor , output_tensor )
15701674 if self ._enable_timer :
15711675 self .timers ("forward_step" ).stop ()
15721676 if self .processed_steps < g_profile_pipeline_details_steps :
@@ -2726,7 +2830,7 @@ def _process_bwd_buffer(step_id, tensor):
27262830
27272831 # append input_tensor no matter none or not
27282832 self .input_tensors [next_virtual_pp_rank ].append (input_tensor )
2729- self . _release_output (output_tensor )
2833+ _release_output (output_tensor )
27302834
27312835 # run 1f1b steady steps
27322836 for micro_step in range (steady_steps ):
@@ -2766,11 +2870,10 @@ def _process_bwd_buffer(step_id, tensor):
27662870 if self ._overlap_p2p_comm :
27672871 backward_micro_step_id = micro_step
27682872
2769- def forward_handle_wait (fwd_wait_handles , output_tensor ):
2873+ def forward_handle_wait (fwd_wait_handles ):
27702874 if fwd_wait_handles is not None :
27712875 for req in fwd_wait_handles :
27722876 req .wait ()
2773- self ._release_output (output_tensor )
27742877
27752878 def forward_async_comm (forward_micro_step_id , output_tensor ):
27762879 forward_virtual_pp_rank = self ._get_virtual_pp_rank (
@@ -2816,6 +2919,7 @@ def forward_async_comm(forward_micro_step_id, output_tensor):
28162919 overlap_p2p_comm = True ,
28172920 skip_check_meta = not self .training ,
28182921 )
2922+ _release_output (output_tensor )
28192923 return (
28202924 next_forward_virtual_pp_rank ,
28212925 input_tensor ,
@@ -2905,9 +3009,7 @@ def backward_async_comm(
29053009 # structure to simplify function parameter passing
29063010 p2p_async_handle = P2PAsyncHandle (
29073011 partial (
2908- forward_handle_wait ,
2909- fwd_wait_handles = fwd_wait_handles ,
2910- output_tensor = output_tensor ,
3012+ forward_handle_wait , fwd_wait_handles = fwd_wait_handles
29113013 ),
29123014 partial (
29133015 forward_async_comm ,
@@ -3077,11 +3179,11 @@ def backward_async_comm(
30773179 output_tensor_grad
30783180 )
30793181
3080- self . _release_output (output_tensor )
3182+ _release_output (output_tensor )
30813183
30823184 assert fwd_buffer_queue .empty (), "forward buffer should be empty"
30833185 if not static_scheduler :
3084- self . _release_output (output_tensor )
3186+ _release_output (output_tensor )
30853187
30863188 # remaining backward steps
30873189 if not forward_only :
@@ -3502,7 +3604,7 @@ def forward_backward_pipeline(
35023604 )
35033605 self .input_tensors [next_virtual_pp_rank ].append (input_tensor )
35043606
3505- self . _release_output (output_tensor )
3607+ _release_output (output_tensor )
35063608
35073609 assert send_recv_buffer_queue .empty (), (
35083610 "send_recv buffer should be empty"
@@ -3756,7 +3858,7 @@ def forward_backward_pipeline(
37563858 self .input_tensors [next_forward_virtual_pp_rank ].append (
37573859 input_tensor
37583860 )
3759- self . _release_output (output_tensor )
3861+ _release_output (output_tensor )
37603862
37613863 if self .is_pipeline_first_stage (ignore_virtual = True ):
37623864 assert (
0 commit comments