Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 10 additions & 21 deletions runtime/onert/core/src/compiler/ShapeValidator.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/*
* Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Comment thread
ragmani marked this conversation as resolved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -167,29 +168,17 @@ void ShapeValidator::visit(const ir::operation::BroadcastTo &node)

const auto input_index{node.getInputs().at(ir::operation::BroadcastTo::Input::INPUT)};
const auto shape_index{node.getInputs().at(ir::operation::BroadcastTo::Input::SHAPE)};
const auto &input_shape = operands.at(input_index).shape();
const auto &output_shape_vec = operands.at(shape_index).asVector<int32_t>();
int input_num_dims = input_shape.rank();
int output_num_dims = output_shape_vec.size();
OP_REQUIRES(input_num_dims <= output_num_dims);

std::vector<int32_t> input_shape = operands.at(input_index).shape().dims();
std::vector<int32_t> target_shape = operands.at(shape_index).asVector<int32_t>();

int in_len = input_shape.size();
int tgt_len = target_shape.size();
int max_len = std::max(in_len, tgt_len);

std::vector<int32_t> in_shape_padded(max_len, 1);
std::vector<int32_t> tgt_shape_padded(max_len, 1);

for (int i = 0; i < in_len; i++)
{
in_shape_padded[max_len - in_len + i] = input_shape[i];
}
for (int i = 0; i < tgt_len; i++)
{
tgt_shape_padded[max_len - tgt_len + i] = target_shape[i];
}

for (int i = max_len - 1; i >= 0; --i)
int extending_dims = output_num_dims - input_num_dims;
for (int idx = 0; idx < input_num_dims; ++idx)
{
OP_REQUIRES((in_shape_padded[i] == tgt_shape_padded[i]) || (in_shape_padded[i] == 1));
OP_REQUIRES(input_shape.dim(idx) == 1 ||
input_shape.dim(idx) == output_shape_vec.at(extending_dims + idx));
}
}

Expand Down