diff --git a/compiler/luci/pass/src/ForwardTransposeOpPass.cpp b/compiler/luci/pass/src/ForwardTransposeOpPass.cpp index b68a27eda87..46ef5780384 100644 --- a/compiler/luci/pass/src/ForwardTransposeOpPass.cpp +++ b/compiler/luci/pass/src/ForwardTransposeOpPass.cpp @@ -153,10 +153,11 @@ uint32_t cal_offset(const std::vector &shape, const std::vector &t) +template +CircleConst *reverse_transposed_t(CircleConst *node, std::vector &t) { + using T = typename loco::DataTypeImpl
::Type; + assert(node->rank() == t.size()); // FIX_CALLER_UNLESS assert(node->rank() == 4); // FIX_CALLER_UNLESS @@ -190,8 +191,8 @@ CircleConst *reverse_transposed(CircleConst *node, std::vector &t) std::vector orig_indices{new_indices[t[0]], new_indices[t[1]], new_indices[t[2]], new_indices[t[3]]}; - const auto data = node->at(cal_offset(orig_shape, orig_indices)); - clone_const->at(cal_offset(new_shape, new_indices)) = data; + const auto v = node->at
(cal_offset(orig_shape, orig_indices)); + clone_const->at
(cal_offset(new_shape, new_indices)) = static_cast(v); } } } @@ -200,28 +201,72 @@ CircleConst *reverse_transposed(CircleConst *node, std::vector &t) return clone_const; } -bool check_rank_four(const CircleConst *c) { return c->rank() == 4; } - -bool has_single_element(const luci::CircleConst *node) +// Return reverse-transpose of 'node' +// i.e., Transpose(return value) = node +CircleConst *reverse_transposed(CircleConst *node, std::vector &t) { - bool has_single_elem = false; switch (node->dtype()) { case loco::DataType::FLOAT32: - has_single_elem = node->size() == 1; - break; + return reverse_transposed_t(node, t); + case loco::DataType::FLOAT16: + return reverse_transposed_t(node, t); + + case loco::DataType::S32: + return reverse_transposed_t(node, t); + case loco::DataType::S16: + return reverse_transposed_t(node, t); + case loco::DataType::S8: + return reverse_transposed_t(node, t); + case loco::DataType::U8: + return reverse_transposed_t(node, t); + default: - // NYI - break; + return nullptr; } +} + +bool check_rank_four(const CircleConst *c) { return c->rank() == 4; } - if (has_single_elem) +template bool has_single_element_t(const luci::CircleConst *node) +{ + if (node->size
() != 1) { - for (uint32_t i = 0; i < node->rank(); i++) - assert(node->dim(i).value() == 1); // FIX_ME_UNLESS + return false; } + for (uint32_t i = 0; i < node->rank(); i++) + { + assert(node->dim(i).value() == 1); // FIX_ME_UNLESS + } + return true; +} + +bool has_single_element(const luci::CircleConst *node) +{ + switch (node->dtype()) + { + case loco::DataType::FLOAT32: + return has_single_element_t(node); + case loco::DataType::FLOAT16: + return has_single_element_t(node); + + case loco::DataType::S64: + return has_single_element_t(node); + case loco::DataType::S32: + return has_single_element_t(node); + case loco::DataType::S16: + return has_single_element_t(node); + case loco::DataType::S8: + return has_single_element_t(node); + case loco::DataType::U8: + return has_single_element_t(node); - return has_single_elem; + case loco::DataType::BOOL: + return has_single_element_t(node); + + default: + return false; + } } // Elementwise Binary Operator with const