Skip to content

Commit fc06b80

Browse files
authored
[0 SIZE] Add zero-size input validation for WarpCTC kernel (#78316)
1 parent d1a149c commit fc06b80

File tree

3 files changed

+93
-2
lines changed

3 files changed

+93
-2
lines changed

paddle/phi/infermeta/multiary.cc

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5741,6 +5741,34 @@ void WarpctcInferMeta(const MetaTensor& logits,
57415741
max_sequence_length = logits_dims[0];
57425742
num_sequences = logits_dims[1];
57435743
sequence_width = logits_dims[2];
5744+
5745+
int64_t labels_batch_size = label.dims()[0];
5746+
int64_t logits_length_batch_size = logits_length.dims()[0];
5747+
int64_t labels_length_batch_size = labels_length.dims()[0];
5748+
5749+
PADDLE_ENFORCE_EQ(
5750+
labels_batch_size,
5751+
num_sequences,
5752+
common::errors::InvalidArgument(
5753+
"Expected label to have size %lld at dimension 0, but got size %d",
5754+
num_sequences,
5755+
labels_batch_size));
5756+
5757+
PADDLE_ENFORCE_EQ(
5758+
logits_length_batch_size,
5759+
num_sequences,
5760+
common::errors::InvalidArgument("Expected logits_length to have size "
5761+
"%lld at dimension 0, but got size %d",
5762+
num_sequences,
5763+
logits_length_batch_size));
5764+
5765+
PADDLE_ENFORCE_EQ(
5766+
labels_length_batch_size,
5767+
num_sequences,
5768+
common::errors::InvalidArgument("Expected labels_length to have size "
5769+
"%lld at dimension 0, but got size %d",
5770+
num_sequences,
5771+
labels_length_batch_size));
57445772
} else {
57455773
max_sequence_length = -1;
57465774
num_sequences = -1;

test/legacy_test/test_layers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1087,7 +1087,7 @@ def test_warpctc_with_padding(self):
10871087
# TODO(minqiyang): dygraph do not support lod now
10881088
with self.static_graph():
10891089
input_length = paddle.static.data(
1090-
name='logits_length', shape=[11], dtype='int64'
1090+
name='logits_length', shape=[12], dtype='int64'
10911091
)
10921092
label_length = paddle.static.data(
10931093
name='labels_length', shape=[12], dtype='int64'
@@ -1096,7 +1096,7 @@ def test_warpctc_with_padding(self):
10961096
name='label', shape=[12, 1], dtype='int32'
10971097
)
10981098
predict = paddle.static.data(
1099-
name='predict', shape=[4, 4, 8], dtype='float32'
1099+
name='predict', shape=[4, 12, 8], dtype='float32'
11001100
)
11011101
output = paddle.nn.functional.ctc_loss(
11021102
log_probs=predict,

test/legacy_test/test_warpctc_op.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -630,6 +630,69 @@ def test_dygraph_zero_size():
630630
self.assertRaises(ValueError, test_dygraph_zero_size)
631631
paddle.enable_static()
632632

633+
def test_dygraph_zero_size_with_padding(self):
    """Test zero size inputs when using LogitsLength and LabelLength."""
    paddle.disable_static()

    batch_size = 4
    num_classes = 8
    max_sequence_length = 5

    # Valid baseline inputs; each failing case below replaces exactly one
    # of them with a zero-size tensor along dimension 0.
    logits = paddle.uniform(
        [max_sequence_length, batch_size, num_classes],
        min=0.1,
        max=1.0,
    )
    labels = paddle.randint(
        0, num_classes - 1, [batch_size, num_classes], dtype="int32"
    )
    labels_length = paddle.to_tensor([2, 1, 3, 2])
    logits_length = paddle.to_tensor([5, 4, 3, 2])

    def run_ctc_loss(lbl, in_lens, lbl_lens):
        # Single call site shared by the three failure triggers.
        paddle.nn.functional.ctc_loss(
            log_probs=logits,
            labels=lbl,
            input_lengths=in_lens,
            label_lengths=lbl_lens,
        )

    def trigger_empty_labels():
        run_ctc_loss(
            paddle.zeros([0, 3], dtype="int32"), logits_length, labels_length
        )

    def trigger_empty_logits_length():
        run_ctc_loss(
            labels, paddle.zeros([0], dtype="int32"), labels_length
        )

    def trigger_empty_labels_length():
        run_ctc_loss(
            labels, logits_length, paddle.zeros([0], dtype="int32")
        )

    # Each zero-size input must be rejected with a message naming the
    # offending tensor and the expected batch size.
    self.assertRaisesRegex(
        ValueError,
        f"Expected label to have size {batch_size} at dimension 0, but got size 0",
        trigger_empty_labels,
    )
    self.assertRaisesRegex(
        ValueError,
        f"Expected logits_length to have size {batch_size} at dimension 0, but got size 0",
        trigger_empty_logits_length,
    )
    self.assertRaisesRegex(
        ValueError,
        f"Expected labels_length to have size {batch_size} at dimension 0, but got size 0",
        trigger_empty_labels_length,
    )

    paddle.enable_static()
695+
633696

634697
class TestCTCLossAPICase(unittest.TestCase):
635698
def test_class_api(self):

0 commit comments

Comments
 (0)