Skip to content

Commit 598d83d

Browse files
add utest
1 parent 1f12254 commit 598d83d

1 file changed

Lines changed: 38 additions & 0 deletions

File tree

test/collective/fleet/hybrid_parallel_pp_layers.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,44 @@ def data_generator():
342342
# The generator gets exhausted or other expected errors
343343
pass
344344

345+
def test_eval_batch_return_host_tensor(self):
346+
"""Test eval_batch with return_host_tensor=True, covering _offload_tensors."""
347+
alex_desc = get_alex_spec()
348+
pipe_model = build_spec_layer(alex_desc, num_stages=1)
349+
pipe_model = NoPipelineParallel(pipe_model, self.strategy, self.hcg)
350+
351+
input = paddle.randn([256, 3, 224, 224])
352+
label = paddle.randint(0, 10, [147, 1])
353+
data = [[input, input, input, input], [label, label, label, label]]
354+
355+
# compute_loss=False, return_host_tensor=True → triggers _offload_tensors
356+
# with a single Tensor output (lines 680, 710-718)
357+
result = pipe_model.eval_batch(
358+
data, compute_loss=False, return_host_tensor=True
359+
)
360+
self.assertIsInstance(result, list)
361+
self.assertEqual(len(result), pipe_model.accumulate_steps)
362+
363+
def test_offload_tensors_branches(self):
364+
"""Directly test _offload_tensors covering all branches (lines 702-718)."""
365+
alex_desc = get_alex_spec()
366+
pipe_model = build_spec_layer(alex_desc, num_stages=1)
367+
pipe_model = NoPipelineParallel(pipe_model, self.strategy, self.hcg)
368+
369+
t = paddle.randn([4, 4])
370+
371+
# Branch: single Tensor (lines 710-718)
372+
pipe_model._offload_tensors(t)
373+
374+
# Branch: single non-Tensor → early return (line 712)
375+
pipe_model._offload_tensors("not_a_tensor")
376+
377+
# Branch: tuple with Tensor elements (lines 702-709)
378+
pipe_model._offload_tensors((t, paddle.randn([2, 2])))
379+
380+
# Branch: tuple containing a non-Tensor element → continue (lines 703-705)
381+
pipe_model._offload_tensors((t, "not_a_tensor"))
382+
345383

346384
if __name__ == "__main__":
347385
unittest.main()

0 commit comments

Comments
 (0)