Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 116 additions & 14 deletions python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,109 @@ def _get_align_mode_scale():
)


def _can_free(t):
"""
Check if a tensor can be freed.

A tensor can be freed only if all of the following conditions are met:
1. Tensor is not None
2. Is a paddle.Tensor type
3. Has been initialized
4. inplace_version is 0 (not using in-place ops) or explicitly marked as freeable

Args:
t: The tensor to check

Returns:
bool: True if the tensor can be freed, False otherwise
"""
return (
t is not None
and isinstance(t, paddle.Tensor)
and t._is_initialized()
and (t.inplace_version == 0 or getattr(t, "pp_can_free", False))
)


def _collect_all_tensors(obj, tensor_set):
"""
Recursively collect all tensors from a complex object.

This function traverses nested data structures (tuple, list, dict) and finds
all paddle.Tensor instances, adding them to the tensor_set. Used in Pipeline
Parallel to identify all tensors that need to be managed.

Args:
obj: Any complex object that may contain nested tuple, list, dict and paddle.Tensor
tensor_set: A set to store the collected tensors
"""
visited = set()
stack = [obj]

while stack:
current = stack.pop()
obj_id = id(current)
if obj_id in visited:
continue
visited.add(obj_id)

if isinstance(current, (tuple, list)):
stack.extend(current)
elif isinstance(current, dict):
stack.extend(current.values())
elif isinstance(current, paddle.Tensor):
# Check for duplicate addition
if current in tensor_set:
logger.debug(f"Duplicate tensor detected: {current}")
tensor_set.add(current)


def _release_output(output):
    """
    Free the storage of releasable tensors found in ``output``.

    Walks the (possibly nested) output structure and, for every tensor
    that passes the _can_free() safety checks, clears its data pointer.
    Pipeline Parallel calls this right after forward propagation so the
    activation memory of non-last stages is not held longer than needed.

    Args:
        output: Tensor or nested tuple/list/dict containing tensors.
    """
    collected = set()
    _collect_all_tensors(output, collected)
    for tensor in collected:
        if not _can_free(tensor):
            continue
        tensor._clear_dataptr()


def _release_input(input, output):
    """
    Free input tensors that do not reappear in the output.

    In Pipeline Parallel an input tensor may flow straight into the
    output (e.g. a residual connection); such tensors must stay alive
    for later computation. This helper collects the tensors on both
    sides and clears the data pointer only of input tensors that are
    absent from the output and pass the _can_free() safety checks.

    Args:
        input: Tensor or nested tuple/list/dict of forward inputs.
        output: Tensor or nested container of forward outputs; tensors
            appearing here are never freed.
    """
    kept = set()
    _collect_all_tensors(output, kept)

    candidates = set()
    _collect_all_tensors(input, candidates)

    for tensor in candidates:
        # Only release when safe AND not reused by the output.
        if _can_free(tensor) and tensor not in kept:
            tensor._clear_dataptr()


# assume only the first stage and last stage need data, and data consumption is ordered
# to be replaced by real micro dataset from reader
class FakeMicroDataset:
Expand Down Expand Up @@ -1126,7 +1229,7 @@ def forward_backward_pipeline(
output_buffers.append(output_tensor_tuple)

if not self.is_pipeline_last_stage():
self._release_output(output_tensor_tuple)
_release_output(output_tensor_tuple)

if steady_steps > 0 and not static_scheduler:
input_tensor = self._p2p_helper.recv_forward(
Expand Down Expand Up @@ -1175,7 +1278,7 @@ def forward_backward_pipeline(
output_buffers.append(output_tensor_tuple)

if not self.is_pipeline_last_stage():
self._release_output(output_tensor_tuple)
_release_output(output_tensor_tuple)

input_tensor, output_tensor = (
input_buffers.pop(0),
Expand Down Expand Up @@ -1426,7 +1529,7 @@ def eval_batch(
batch_p2p_comm=self._use_batch_p2p_comm,
)
if not self.is_pipeline_last_stage():
self._release_output(output_tensor_tuple)
_release_output(output_tensor_tuple)
else:
self._offload_tensors(output_tensor_tuple)

Expand Down Expand Up @@ -1456,7 +1559,7 @@ def eval_batch(
batch_p2p_comm=self._use_batch_p2p_comm,
)
if not self.is_pipeline_last_stage():
self._release_output(output_tensor_tuple)
_release_output(output_tensor_tuple)
else:
self._offload_tensors(output_tensor_tuple)

Expand Down Expand Up @@ -1567,6 +1670,7 @@ def _forward_step(
# Only increase micro batch id at virtual first/last pp stage.
# The micro batch id is used to load data, therefore, only increase it when load data.
self.micro_batch_id += 1
_release_input(input_tensor, output_tensor)
if self._enable_timer:
self.timers("forward_step").stop()
if self.processed_steps < g_profile_pipeline_details_steps:
Expand Down Expand Up @@ -2726,7 +2830,7 @@ def _process_bwd_buffer(step_id, tensor):

# append input_tensor no matter none or not
self.input_tensors[next_virtual_pp_rank].append(input_tensor)
self._release_output(output_tensor)
_release_output(output_tensor)

# run 1f1b steady steps
for micro_step in range(steady_steps):
Expand Down Expand Up @@ -2766,11 +2870,10 @@ def _process_bwd_buffer(step_id, tensor):
if self._overlap_p2p_comm:
backward_micro_step_id = micro_step

def forward_handle_wait(fwd_wait_handles, output_tensor):
def forward_handle_wait(fwd_wait_handles):
if fwd_wait_handles is not None:
for req in fwd_wait_handles:
req.wait()
self._release_output(output_tensor)

def forward_async_comm(forward_micro_step_id, output_tensor):
forward_virtual_pp_rank = self._get_virtual_pp_rank(
Expand Down Expand Up @@ -2816,6 +2919,7 @@ def forward_async_comm(forward_micro_step_id, output_tensor):
overlap_p2p_comm=True,
skip_check_meta=not self.training,
)
_release_output(output_tensor)
return (
next_forward_virtual_pp_rank,
input_tensor,
Expand Down Expand Up @@ -2905,9 +3009,7 @@ def backward_async_comm(
# structure to simplify function parameter passing
p2p_async_handle = P2PAsyncHandle(
partial(
forward_handle_wait,
fwd_wait_handles=fwd_wait_handles,
output_tensor=output_tensor,
forward_handle_wait, fwd_wait_handles=fwd_wait_handles
),
partial(
forward_async_comm,
Expand Down Expand Up @@ -3077,11 +3179,11 @@ def backward_async_comm(
output_tensor_grad
)

self._release_output(output_tensor)
_release_output(output_tensor)

assert fwd_buffer_queue.empty(), "forward buffer should be empty"
if not static_scheduler:
self._release_output(output_tensor)
_release_output(output_tensor)

# remaining backward steps
if not forward_only:
Expand Down Expand Up @@ -3502,7 +3604,7 @@ def forward_backward_pipeline(
)
self.input_tensors[next_virtual_pp_rank].append(input_tensor)

self._release_output(output_tensor)
_release_output(output_tensor)

assert send_recv_buffer_queue.empty(), (
"send_recv buffer should be empty"
Expand Down Expand Up @@ -3756,7 +3858,7 @@ def forward_backward_pipeline(
self.input_tensors[next_forward_virtual_pp_rank].append(
input_tensor
)
self._release_output(output_tensor)
_release_output(output_tensor)

if self.is_pipeline_first_stage(ignore_virtual=True):
assert (
Expand Down
44 changes: 39 additions & 5 deletions python/paddle/distributed/fleet/recompute/recompute.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,42 @@ def check_recompute_necessary(inputs):
)


def _protect_tensors(seq):
"""For each element in seq (a list or tuple of forward args), create a new
tensor Python object that shares the same underlying buffer via
_new_shared_tensor(), so that when pipeline-parallel calls
_release_input/_release_output (which clears the data pointer of the
original tensor), the copies held by recompute for backward are not
invalidated. Non-tensor elements are kept as-is.
Returns a list with the same length as seq.
"""
result = list(seq)
for idx, arg in enumerate(result):
if isinstance(arg, core.eager.Tensor):
# _new_shared_tensor() creates a new Python-level tensor object
# that shares the same C++ storage with arg, without cloning data.
shared = arg._new_shared_tensor()
assert shared is not arg, (
"_protect_tensors() must return a new Python object distinct from the original "
"tensor, otherwise the protection against pipeline-parallel tensor "
"release is ineffective."
)
result[idx] = shared
elif isinstance(arg, tuple):
# For tuple args (e.g., pipeline-parallel passes inputs as tuples),
# protect each tensor element inside the tuple individually;
# non-tensor elements (e.g., int, bool) are passed through unchanged.
protected_tuple = []
for t in arg:
if isinstance(t, core.eager.Tensor):
shared = t._new_shared_tensor()
protected_tuple.append(shared)
else:
protected_tuple.append(t)
result[idx] = tuple(protected_tuple)
return result


class CustomStatesManager:
"""CustomStatesManager"""

Expand Down Expand Up @@ -683,8 +719,8 @@ def recompute(function, *args, **kwargs):

if use_reentrant:
offload_indices = kwargs.pop('offload_indices', [])
input_args = []
# rearrange `position-args + keyword-args` into `position-args`
input_args = []
if isinstance(function, paddle.nn.Layer):
dyfunc_sig = inspect.signature(function.forward)
else:
Expand Down Expand Up @@ -712,16 +748,14 @@ def recompute(function, *args, **kwargs):
else:
raise ValueError("Unknown parameter kind.")
# Make a shallow copy of each Tensor to prevent the release of some Tensors reserved for backward in some special scenarios (such as scheduling logic of parallel pipelines)
for idx, arg in enumerate(input_args):
if isinstance(arg, core.eager.Tensor):
input_args[idx] = arg._new_shared_tensor()
protected_args = _protect_tensors(input_args)
return RecomputeFunction.apply(
function,
preserve,
offload_indices,
custom_get_state_func,
custom_set_state_func,
*input_args,
*protected_args,
)
else:
return _recompute_without_reentrant(
Expand Down
8 changes: 6 additions & 2 deletions python/paddle/distributed/fleet/recompute/recompute_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from ..meta_parallel.parallel_layers.random import get_rng_state_tracker
from ..meta_parallel.pp_utils import utils
from .recompute import (
_protect_tensors,
check_recompute_necessary,
custom_state_manager,
detach_variable,
Expand Down Expand Up @@ -154,10 +155,13 @@ def forward(
ctx.amp_dtype = tracer._amp_dtype
ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()

# Protect input tensors before saving to prevent release by pipeline parallel
protected_args = _protect_tensors(args)

with paddle.no_grad():
outputs = run_function(*args, **kwargs)
outputs = run_function(*protected_args, **kwargs)

for i, arg in enumerate(args):
for i, arg in enumerate(protected_args):
if paddle.is_tensor(arg):
state = arg.stop_gradient
if partition:
Expand Down
Loading
Loading