Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 59 additions & 28 deletions python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,9 @@ def train_batch(

return train_loss

def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
def eval_batch(
self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor=False
):
# check loss_fn_idx is valid and loss_fn exists
assert (
loss_fn_idx in range(len(self._layers._loss_fn))
Expand All @@ -637,6 +639,7 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
)

loss_list = []
output_list = []
for _ in range(self.accumulate_steps):
# data prepare
data_iter = next(micro_dataset)
Expand All @@ -648,43 +651,71 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
# forward
output_tensor = self._layers.forward(input_tensor)

# loss is loss_fn[loss_fn_idx]'s result
loss = None
# cal loss
for idx, loss_fn in enumerate(self._layers._loss_fn):
loss_tensor = loss_fn(output_tensor, label)
assert isinstance(loss_tensor, paddle.Tensor), (
"Currently, loss_fn should obtain Paddle.Tensor dtype"
)
if compute_loss:
# loss is loss_fn[loss_fn_idx]'s result
loss = None

if self.total_loss is None:
self.total_loss = []
# when self.total_loss length is less than idx, append a new tensor
if len(self.total_loss) <= idx:
self.total_loss.append([])
# cal loss
for idx, loss_fn in enumerate(self._layers._loss_fn):
loss_tensor = loss_fn(output_tensor, label)
assert isinstance(loss_tensor, paddle.Tensor), (
"Currently, loss_fn should obtain Paddle.Tensor dtype"
)
if self.total_loss is None:
self.total_loss = []
# when self.total_loss length is less than idx, append a new tensor
if len(self.total_loss) <= idx:
self.total_loss.append([])

self.total_loss[idx].append(loss_tensor.detach())
self.total_loss[idx].append(loss_tensor.detach())

if idx == self.loss_fn_idx:
loss = loss_tensor
if idx == self.loss_fn_idx:
loss = loss_tensor

assert self.total_loss is not None, (
"train_batch() in last stage should obtain valid loss"
)
assert self.total_loss is not None, (
"train_batch() in last stage should obtain valid loss"
)
else:
if return_host_tensor:
self._offload_tensors(output_tensor)
output_list.append(output_tensor)

losses = []
return_micro_batch_loss = False
for idx in range(len(self._layers._loss_fn)):
self.total_loss[idx] = paddle.to_tensor(self.total_loss[idx])
if not return_micro_batch_loss:
if compute_loss:
losses = []
return_micro_batch_loss = False
for idx in range(len(self._layers._loss_fn)):
self.total_loss[idx] = paddle.to_tensor(self.total_loss[idx])
# if not return_micro_batch_loss:
# TODO(shenliang03): it will use mean/sum to calculate loss
tmp = paddle.zeros_like(self.total_loss[idx][0])
for loss in self.total_loss[idx]:
tmp += loss.detach()
losses.append(tmp / self.accumulate_steps)
else:
losses.append(self.total_loss[idx].detach())
return losses[0] if len(losses) == 1 else losses
# else:
# losses.append(self.total_loss[idx].detach())
res = losses[0] if len(losses) == 1 else losses
else:
res = output_list
return res

def _offload_tensors(self, output_tensor):
    """Move ``output_tensor``'s data to host memory in place.

    Each ``paddle.Tensor`` is copied to a pinned-memory host tensor when
    ``pin_memory`` is available (falling back to a plain CPU copy), and the
    host buffer is then shared back into the original tensor object via
    ``_share_buffer_to`` — so existing references stay valid while the data
    now lives on the host.  Non-tensor values are left untouched.

    NOTE(review): this presumably releases the device-side allocation once
    the buffer is shared — confirm against Paddle's ``_share_buffer_to``
    semantics.

    Args:
        output_tensor: a ``paddle.Tensor``, or a tuple/list that may mix
            ``paddle.Tensor`` elements with other objects.
    """

    def _move_to_host(tensor):
        # Prefer pinned host memory (faster future H2D transfer) when the
        # tensor supports it; otherwise fall back to an ordinary CPU copy.
        host = (
            tensor.pin_memory()
            if hasattr(tensor, "pin_memory")
            else tensor.cpu()
        )
        host._share_buffer_to(tensor)

    if isinstance(output_tensor, (tuple, list)):
        for item in output_tensor:
            if isinstance(item, paddle.Tensor):
                _move_to_host(item)
    elif isinstance(output_tensor, paddle.Tensor):
        _move_to_host(output_tensor)


class PipelineParallel(MetaParallelBase):
Expand Down
38 changes: 38 additions & 0 deletions test/collective/fleet/hybrid_parallel_pp_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,44 @@ def data_generator():
# The generator gets exhausted or other expected errors
pass

def test_eval_batch_return_host_tensor(self):
    """Test eval_batch with return_host_tensor=True, covering _offload_tensors.

    With compute_loss=False and return_host_tensor=True, eval_batch returns
    the raw forward outputs (one entry per accumulate step) instead of a
    loss, taking the _offload_tensors path.
    """
    alex_desc = get_alex_spec()
    pipe_model = build_spec_layer(alex_desc, num_stages=1)
    pipe_model = NoPipelineParallel(pipe_model, self.strategy, self.hcg)

    # Renamed from `input` so the builtin is not shadowed.
    images = paddle.randn([256, 3, 224, 224])
    label = paddle.randint(0, 10, [147, 1])
    data = [[images, images, images, images], [label, label, label, label]]

    # compute_loss=False, return_host_tensor=True → triggers _offload_tensors
    # with a single Tensor output (lines 680, 710-718)
    result = pipe_model.eval_batch(
        data, compute_loss=False, return_host_tensor=True
    )
    self.assertIsInstance(result, list)
    self.assertEqual(len(result), pipe_model.accumulate_steps)

def test_offload_tensors_branches(self):
    """Call _offload_tensors directly so every branch (lines 702-718) runs."""
    alex_desc = get_alex_spec()
    pipe_model = build_spec_layer(alex_desc, num_stages=1)
    pipe_model = NoPipelineParallel(pipe_model, self.strategy, self.hcg)

    dense = paddle.randn([4, 4])
    non_tensor = "not_a_tensor"

    # Single Tensor input (lines 710-718).
    pipe_model._offload_tensors(dense)

    # Single non-Tensor input → early return (line 712).
    pipe_model._offload_tensors(non_tensor)

    # Tuple whose elements are all Tensors (lines 702-709).
    pipe_model._offload_tensors((dense, paddle.randn([2, 2])))

    # Tuple with a non-Tensor element → skipped via continue (lines 703-705).
    pipe_model._offload_tensors((dense, non_tensor))


if __name__ == "__main__":
    # Run the test suite when this file is executed directly.
    unittest.main()
Loading