Skip to content

Commit 31362cd

Browse files
committed
release input and output tensors in pipeline parallel
1 parent 879657d commit 31362cd

File tree

5 files changed

+563
-21
lines changed

5 files changed

+563
-21
lines changed

python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py

Lines changed: 116 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,109 @@ def _get_align_mode_scale():
9696
)
9797

9898

99+
def _can_free(t):
100+
"""
101+
Check if a tensor can be freed.
102+
103+
A tensor can be freed only if all of the following conditions are met:
104+
1. Tensor is not None
105+
2. Is a paddle.Tensor type
106+
3. Has been initialized
107+
4. inplace_version is 0 (not using in-place ops) or explicitly marked as freeable
108+
109+
Args:
110+
t: The tensor to check
111+
112+
Returns:
113+
bool: True if the tensor can be freed, False otherwise
114+
"""
115+
return (
116+
t is not None
117+
and isinstance(t, paddle.Tensor)
118+
and t._is_initialized()
119+
and (t.inplace_version == 0 or getattr(t, "pp_can_free", False))
120+
)
121+
122+
123+
def _collect_all_tensors(obj, tensor_set):
124+
"""
125+
Recursively collect all tensors from a complex object.
126+
127+
This function traverses nested data structures (tuple, list, dict) and finds
128+
all paddle.Tensor instances, adding them to the tensor_set. Used in Pipeline
129+
Parallel to identify all tensors that need to be managed.
130+
131+
Args:
132+
obj: Any complex object that may contain nested tuple, list, dict and paddle.Tensor
133+
tensor_set: A set to store the collected tensors
134+
"""
135+
visited = set()
136+
stack = [obj]
137+
138+
while stack:
139+
current = stack.pop()
140+
obj_id = id(current)
141+
if obj_id in visited:
142+
continue
143+
visited.add(obj_id)
144+
145+
if isinstance(current, (tuple, list)):
146+
stack.extend(current)
147+
elif isinstance(current, dict):
148+
stack.extend(current.values())
149+
elif isinstance(current, paddle.Tensor):
150+
# Check for duplicate addition
151+
if current in tensor_set:
152+
logger.debug(f"Duplicate tensor detected: {current}")
153+
tensor_set.add(current)
154+
155+
156+
def _release_output(output):
    """
    Free the data pointers of releasable tensors inside ``output``.

    Gathers every tensor reachable from ``output`` and clears the data
    pointer of each one that satisfies the release criteria. Used in
    Pipeline Parallel to drop output tensor memory after forward
    propagation and avoid unnecessary memory usage.

    Args:
        output: The output object; may be a tensor, tuple, list, or dict.
    """
    collected = set()
    _collect_all_tensors(output, collected)
    for tensor in filter(_can_free, collected):
        tensor._clear_dataptr()
172+
173+
174+
def _release_input(input, output):
    """
    Free the data pointers of input tensors not referenced by the output.

    In Pipeline Parallel an input tensor that also appears in the output
    (e.g. through a residual connection) must survive for subsequent
    computation, so only input tensors absent from ``output`` are released.

    Args:
        input: The input object; may be a tensor, tuple, list, or dict.
        output: The output object, used to decide which input tensors must
            be kept alive.
    """
    protected = set()
    _collect_all_tensors(output, protected)

    candidates = set()
    _collect_all_tensors(input, candidates)

    for tensor in candidates:
        # Same order as the original check: releasability first, then
        # membership in the protected output set.
        if _can_free(tensor) and tensor not in protected:
            tensor._clear_dataptr()
200+
201+
99202
# assume only the first stage and last stage need data, and data consumption is ordered
100203
# to be replaced by real micro dataset from reader
101204
class FakeMicroDataset:
@@ -1126,7 +1229,7 @@ def forward_backward_pipeline(
11261229
output_buffers.append(output_tensor_tuple)
11271230

11281231
if not self.is_pipeline_last_stage():
1129-
self._release_output(output_tensor_tuple)
1232+
_release_output(output_tensor_tuple)
11301233

11311234
if steady_steps > 0 and not static_scheduler:
11321235
input_tensor = self._p2p_helper.recv_forward(
@@ -1175,7 +1278,7 @@ def forward_backward_pipeline(
11751278
output_buffers.append(output_tensor_tuple)
11761279

11771280
if not self.is_pipeline_last_stage():
1178-
self._release_output(output_tensor_tuple)
1281+
_release_output(output_tensor_tuple)
11791282

11801283
input_tensor, output_tensor = (
11811284
input_buffers.pop(0),
@@ -1426,7 +1529,7 @@ def eval_batch(
14261529
batch_p2p_comm=self._use_batch_p2p_comm,
14271530
)
14281531
if not self.is_pipeline_last_stage():
1429-
self._release_output(output_tensor_tuple)
1532+
_release_output(output_tensor_tuple)
14301533
else:
14311534
self._offload_tensors(output_tensor_tuple)
14321535

@@ -1456,7 +1559,7 @@ def eval_batch(
14561559
batch_p2p_comm=self._use_batch_p2p_comm,
14571560
)
14581561
if not self.is_pipeline_last_stage():
1459-
self._release_output(output_tensor_tuple)
1562+
_release_output(output_tensor_tuple)
14601563
else:
14611564
self._offload_tensors(output_tensor_tuple)
14621565

@@ -1567,6 +1670,7 @@ def _forward_step(
15671670
# Only increase micro batch id at virtual first/last pp stage.
15681671
# The micro batch id is used to load data, therefore, only increase it when load data.
15691672
self.micro_batch_id += 1
1673+
_release_input(input_tensor, output_tensor)
15701674
if self._enable_timer:
15711675
self.timers("forward_step").stop()
15721676
if self.processed_steps < g_profile_pipeline_details_steps:
@@ -2726,7 +2830,7 @@ def _process_bwd_buffer(step_id, tensor):
27262830

27272831
# append input_tensor no matter none or not
27282832
self.input_tensors[next_virtual_pp_rank].append(input_tensor)
2729-
self._release_output(output_tensor)
2833+
_release_output(output_tensor)
27302834

27312835
# run 1f1b steady steps
27322836
for micro_step in range(steady_steps):
@@ -2766,11 +2870,10 @@ def _process_bwd_buffer(step_id, tensor):
27662870
if self._overlap_p2p_comm:
27672871
backward_micro_step_id = micro_step
27682872

2769-
def forward_handle_wait(fwd_wait_handles, output_tensor):
2873+
def forward_handle_wait(fwd_wait_handles):
27702874
if fwd_wait_handles is not None:
27712875
for req in fwd_wait_handles:
27722876
req.wait()
2773-
self._release_output(output_tensor)
27742877

27752878
def forward_async_comm(forward_micro_step_id, output_tensor):
27762879
forward_virtual_pp_rank = self._get_virtual_pp_rank(
@@ -2816,6 +2919,7 @@ def forward_async_comm(forward_micro_step_id, output_tensor):
28162919
overlap_p2p_comm=True,
28172920
skip_check_meta=not self.training,
28182921
)
2922+
_release_output(output_tensor)
28192923
return (
28202924
next_forward_virtual_pp_rank,
28212925
input_tensor,
@@ -2905,9 +3009,7 @@ def backward_async_comm(
29053009
# structure to simplify function parameter passing
29063010
p2p_async_handle = P2PAsyncHandle(
29073011
partial(
2908-
forward_handle_wait,
2909-
fwd_wait_handles=fwd_wait_handles,
2910-
output_tensor=output_tensor,
3012+
forward_handle_wait, fwd_wait_handles=fwd_wait_handles
29113013
),
29123014
partial(
29133015
forward_async_comm,
@@ -3077,11 +3179,11 @@ def backward_async_comm(
30773179
output_tensor_grad
30783180
)
30793181

3080-
self._release_output(output_tensor)
3182+
_release_output(output_tensor)
30813183

30823184
assert fwd_buffer_queue.empty(), "forward buffer should be empty"
30833185
if not static_scheduler:
3084-
self._release_output(output_tensor)
3186+
_release_output(output_tensor)
30853187

30863188
# remaining backward steps
30873189
if not forward_only:
@@ -3502,7 +3604,7 @@ def forward_backward_pipeline(
35023604
)
35033605
self.input_tensors[next_virtual_pp_rank].append(input_tensor)
35043606

3505-
self._release_output(output_tensor)
3607+
_release_output(output_tensor)
35063608

35073609
assert send_recv_buffer_queue.empty(), (
35083610
"send_recv buffer should be empty"
@@ -3756,7 +3858,7 @@ def forward_backward_pipeline(
37563858
self.input_tensors[next_forward_virtual_pp_rank].append(
37573859
input_tensor
37583860
)
3759-
self._release_output(output_tensor)
3861+
_release_output(output_tensor)
37603862

37613863
if self.is_pipeline_first_stage(ignore_virtual=True):
37623864
assert (

python/paddle/distributed/fleet/recompute/recompute.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,42 @@ def check_recompute_necessary(inputs):
115115
)
116116

117117

118+
def _protect_tensors(seq):
119+
"""For each element in seq (a list or tuple of forward args), create a new
120+
tensor Python object that shares the same underlying buffer via
121+
_new_shared_tensor(), so that when pipeline-parallel calls
122+
_release_input/_release_output (which clears the data pointer of the
123+
original tensor), the copies held by recompute for backward are not
124+
invalidated. Non-tensor elements are kept as-is.
125+
Returns a list with the same length as seq.
126+
"""
127+
result = list(seq)
128+
for idx, arg in enumerate(result):
129+
if isinstance(arg, core.eager.Tensor):
130+
# _new_shared_tensor() creates a new Python-level tensor object
131+
# that shares the same C++ storage with arg, without cloning data.
132+
shared = arg._new_shared_tensor()
133+
assert shared is not arg, (
134+
"_protect_tensors() must return a new Python object distinct from the original "
135+
"tensor, otherwise the protection against pipeline-parallel tensor "
136+
"release is ineffective."
137+
)
138+
result[idx] = shared
139+
elif isinstance(arg, tuple):
140+
# For tuple args (e.g., pipeline-parallel passes inputs as tuples),
141+
# protect each tensor element inside the tuple individually;
142+
# non-tensor elements (e.g., int, bool) are passed through unchanged.
143+
protected_tuple = []
144+
for t in arg:
145+
if isinstance(t, core.eager.Tensor):
146+
shared = t._new_shared_tensor()
147+
protected_tuple.append(shared)
148+
else:
149+
protected_tuple.append(t)
150+
result[idx] = tuple(protected_tuple)
151+
return result
152+
153+
118154
class CustomStatesManager:
119155
"""CustomStatesManager"""
120156

@@ -683,8 +719,8 @@ def recompute(function, *args, **kwargs):
683719

684720
if use_reentrant:
685721
offload_indices = kwargs.pop('offload_indices', [])
686-
input_args = []
687722
# rearrange `position-args + keyword-args` into `position-args`
723+
input_args = []
688724
if isinstance(function, paddle.nn.Layer):
689725
dyfunc_sig = inspect.signature(function.forward)
690726
else:
@@ -712,16 +748,14 @@ def recompute(function, *args, **kwargs):
712748
else:
713749
raise ValueError("Unknown parameter kind.")
714750
# Make a shallow copy of each Tensor to prevent the release of some Tensors reserved for backward in some special scenarios (such as scheduling logic of parallel pipelines)
715-
for idx, arg in enumerate(input_args):
716-
if isinstance(arg, core.eager.Tensor):
717-
input_args[idx] = arg._new_shared_tensor()
751+
protected_args = _protect_tensors(input_args)
718752
return RecomputeFunction.apply(
719753
function,
720754
preserve,
721755
offload_indices,
722756
custom_get_state_func,
723757
custom_set_state_func,
724-
*input_args,
758+
*protected_args,
725759
)
726760
else:
727761
return _recompute_without_reentrant(

python/paddle/distributed/fleet/recompute/recompute_hybrid.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from ..meta_parallel.parallel_layers.random import get_rng_state_tracker
2727
from ..meta_parallel.pp_utils import utils
2828
from .recompute import (
29+
_protect_tensors,
2930
check_recompute_necessary,
3031
custom_state_manager,
3132
detach_variable,
@@ -154,10 +155,13 @@ def forward(
154155
ctx.amp_dtype = tracer._amp_dtype
155156
ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()
156157

158+
# Protect input tensors before saving to prevent release by pipeline parallel
159+
protected_args = _protect_tensors(args)
160+
157161
with paddle.no_grad():
158-
outputs = run_function(*args, **kwargs)
162+
outputs = run_function(*protected_args, **kwargs)
159163

160-
for i, arg in enumerate(args):
164+
for i, arg in enumerate(protected_args):
161165
if paddle.is_tensor(arg):
162166
state = arg.stop_gradient
163167
if partition:

0 commit comments

Comments
 (0)