Skip to content

Commit fb4e3d2

Browse files
committed
reject input rank mismatch in transpose reshape
1 parent 697543f commit fb4e3d2

2 files changed

Lines changed: 31 additions & 1 deletion

File tree

src/subgraph/static-transpose.c

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,15 @@ static enum xnn_status reshape_transpose_operator(
7474
assert(output_id < num_values);
7575

7676
const size_t num_dims = opdata->shape1.num_dims;
77-
assert(input->shape.num_dims == num_dims);
77+
if (input->shape.num_dims != num_dims) {
78+
xnn_log_error(
79+
"failed to reshape %s operator with input ID #%" PRIu32
80+
": number of input dimensions (%zu) does not match the number of "
81+
"permutation dimensions (%zu)",
82+
xnn_node_type_to_string(xnn_node_type_static_transpose), input_id,
83+
input->shape.num_dims, num_dims);
84+
return xnn_status_invalid_parameter;
85+
}
7886

7987
switch (opdata->operator_objects[0]->type) {
8088
case xnn_operator_type_transpose_nd_x16: {

test/subgraph/static-transpose.cc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,4 +85,26 @@ INSTANTIATE_TEST_SUITE_P(Transpose, TransposeQU8, rank_params);
8585
INSTANTIATE_TEST_SUITE_P(Transpose, TransposeF16, rank_params);
8686
INSTANTIATE_TEST_SUITE_P(Transpose, TransposeF32, rank_params);
8787

88+
TEST(Transpose, reshape_rejects_input_rank_mismatch) {
89+
ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
90+
91+
// The permutation fixes the operator rank at define time. An external input
92+
// may be reshaped to a different rank at runtime; the reshape must reject the
93+
// mismatch instead of transposing with a stale dimension count and reading
94+
// past the input buffer.
95+
const std::vector<size_t> perm = {2, 0, 1};
96+
SubgraphTester subgraph(2);
97+
subgraph.AddInputTensor(3, xnn_datatype_fp32, 0)
98+
.AddOutputTensor(3, xnn_datatype_fp32, 1)
99+
.AddTranspose(perm, 0, 1);
100+
if (subgraph.CreateRuntime() == xnn_status_unsupported_hardware) {
101+
GTEST_SKIP();
102+
}
103+
104+
std::vector<float> data(6, 0.0f);
105+
subgraph.ReshapeExternalTensor(std::vector<size_t>({2, 3}), data.data(), 0)
106+
.ReshapeRuntime();
107+
EXPECT_EQ(subgraph.Status(), xnn_status_invalid_parameter);
108+
}
109+
88110
} // namespace xnnpack

0 commit comments

Comments
 (0)