Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/subgraph/rope.c
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,26 @@ static enum xnn_status reshape_rope_operator(
if (status != xnn_status_success) {
return status;
}

// The rotation table is addressed per token: the microkernel reads `channels`
// weight elements at offset `token * channels` for every token in
// [0, tokens). A runtime input carrying more tokens than the weights tensor
// was built for therefore walks off the end of the weights buffer. Reject it
// here, where both shapes are known, rather than reading out of bounds.
const uint32_t weights_id = opdata->inputs[1];
assert(weights_id < num_values);
const struct xnn_runtime_value* weights_value = values + weights_id;
const size_t weights_elements =
xnn_shape_multiply_all_dims(&weights_value->shape);
if (weights_elements / channels < tokens) {
xnn_log_error(
"failed to reshape %s operator with input ID #%" PRIu32
": %zu tokens of %zu channels exceed the %zu-element weights tensor",
xnn_node_type_to_string(xnn_node_type_rope),
input_id, tokens, channels, weights_elements);
return xnn_status_invalid_parameter;
}

const uint32_t output_id = opdata->outputs[0];
assert(output_id < num_values);
struct xnn_runtime_value* output_value = values + output_id;
Expand Down
32 changes: 32 additions & 0 deletions test/subgraph/rope.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,4 +117,36 @@ void TestImpl() {
// Runs the shared TestImpl body instantiated for xnn_float16.
TEST(RoPEF16, test) {
  TestImpl<xnn_float16>();
}
// Runs the shared TestImpl body instantiated for float.
TEST(RoPEF32, test) {
  TestImpl<float>();
}

// Reshaping a RoPE runtime must report xnn_status_invalid_parameter when the
// input is bound to a shape with more tokens than the weights tensor has rows.
TEST(RoPEF32, reshape_rejects_tokens_exceeding_weights) {
  // Dimensions bound at reshape time. The input deliberately carries four
  // times as many tokens as the weights tensor.
  const size_t batch_size = 1;
  const size_t input_tokens = 16;
  const size_t weights_tokens = 4;
  const size_t heads = 2;
  const size_t channels = 4;

  ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));

  // External values: 0 = input, 1 = weights, 2 = output. The leading 4 / 2
  // arguments line up with the number of dimensions bound to each value below
  // (presumably the tensor rank -- confirm against SubgraphTester).
  SubgraphTester subgraph(3);
  subgraph.AddInputTensor(4, xnn_datatype_fp32, 0);
  subgraph.AddInputTensor(2, xnn_datatype_fp32, 1);
  subgraph.AddOutputTensor(4, xnn_datatype_fp32, 2);
  subgraph.AddRoPE(0, 1, 2);

  const xnn_status create_status = subgraph.CreateRuntime();
  if (create_status == xnn_status_unsupported_hardware) {
    // GTEST_SKIP() returns from the test body on its own.
    GTEST_SKIP();
  }
  ASSERT_EQ(create_status, xnn_status_success);

  // Backing storage for the two external inputs.
  Tensor<float> input({batch_size, input_tokens, heads, channels},
                      XnnExtraBytes);
  Tensor<float> weights({weights_tokens, channels}, XnnExtraBytes);

  // Bind the mismatched shapes, then reshape the runtime; the reshape is
  // expected to be rejected rather than accepted.
  subgraph.ReshapeExternalTensor(
      {batch_size, input_tokens, heads, channels}, input.base(), 0);
  subgraph.ReshapeExternalTensor({weights_tokens, channels}, weights.base(),
                                 1);

  const xnn_status reshape_status = subgraph.ReshapeRuntime().Status();
  ASSERT_EQ(reshape_status, xnn_status_invalid_parameter);
}

} // namespace xnnpack