Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 62 additions & 17 deletions compiler/luci/pass/src/ForwardTransposeOpPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,11 @@ uint32_t cal_offset(const std::vector<uint32_t> &shape, const std::vector<uint32
return offset;
}

// Return reverse-transpose of 'node'
// i.e., Transpose(return value) = node
CircleConst *reverse_transposed(CircleConst *node, std::vector<uint32_t> &t)
template <loco::DataType DT>
CircleConst *reverse_transposed_t(CircleConst *node, std::vector<uint32_t> &t)
{
using T = typename loco::DataTypeImpl<DT>::Type;

assert(node->rank() == t.size()); // FIX_CALLER_UNLESS
assert(node->rank() == 4); // FIX_CALLER_UNLESS

Expand Down Expand Up @@ -190,8 +191,8 @@ CircleConst *reverse_transposed(CircleConst *node, std::vector<uint32_t> &t)
std::vector<uint32_t> orig_indices{new_indices[t[0]], new_indices[t[1]],
new_indices[t[2]], new_indices[t[3]]};

const auto data = node->at<loco::DataType::FLOAT32>(cal_offset(orig_shape, orig_indices));
clone_const->at<loco::DataType::FLOAT32>(cal_offset(new_shape, new_indices)) = data;
const auto v = node->at<DT>(cal_offset(orig_shape, orig_indices));
clone_const->at<DT>(cal_offset(new_shape, new_indices)) = static_cast<T>(v);
}
}
}
Expand All @@ -200,28 +201,72 @@ CircleConst *reverse_transposed(CircleConst *node, std::vector<uint32_t> &t)
return clone_const;
}

bool check_rank_four(const CircleConst *c) { return c->rank() == 4; }

bool has_single_element(const luci::CircleConst *node)
// Return reverse-transpose of 'node'
// i.e., Transpose(return value) = node
CircleConst *reverse_transposed(CircleConst *node, std::vector<uint32_t> &t)
{
bool has_single_elem = false;
switch (node->dtype())
{
case loco::DataType::FLOAT32:
has_single_elem = node->size<loco::DataType::FLOAT32>() == 1;
break;
return reverse_transposed_t<loco::DataType::FLOAT32>(node, t);
case loco::DataType::FLOAT16:
return reverse_transposed_t<loco::DataType::FLOAT16>(node, t);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it inteded line break?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, right. I just split it according to the float vs. int. It's just a preference.

case loco::DataType::S32:
return reverse_transposed_t<loco::DataType::S32>(node, t);
case loco::DataType::S16:
return reverse_transposed_t<loco::DataType::S16>(node, t);
case loco::DataType::S8:
return reverse_transposed_t<loco::DataType::S8>(node, t);
case loco::DataType::U8:
return reverse_transposed_t<loco::DataType::U8>(node, t);

default:
// NYI
break;
return nullptr;
}
}

bool check_rank_four(const CircleConst *c) { return c->rank() == 4; }

if (has_single_elem)
template <loco::DataType DT> bool has_single_element_t(const luci::CircleConst *node)
{
if (node->size<DT>() != 1)
{
for (uint32_t i = 0; i < node->rank(); i++)
assert(node->dim(i).value() == 1); // FIX_ME_UNLESS
return false;
}
for (uint32_t i = 0; i < node->rank(); i++)
{
assert(node->dim(i).value() == 1); // FIX_ME_UNLESS
}
return true;
}

bool has_single_element(const luci::CircleConst *node)
{
switch (node->dtype())
{
case loco::DataType::FLOAT32:
return has_single_element_t<loco::DataType::FLOAT32>(node);
case loco::DataType::FLOAT16:
return has_single_element_t<loco::DataType::FLOAT16>(node);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto.

case loco::DataType::S64:
return has_single_element_t<loco::DataType::S64>(node);
case loco::DataType::S32:
return has_single_element_t<loco::DataType::S32>(node);
case loco::DataType::S16:
return has_single_element_t<loco::DataType::S16>(node);
case loco::DataType::S8:
return has_single_element_t<loco::DataType::S8>(node);
case loco::DataType::U8:
return has_single_element_t<loco::DataType::U8>(node);

return has_single_elem;
case loco::DataType::BOOL:
return has_single_element_t<loco::DataType::BOOL>(node);

default:
return false;
}
}

// Elementwise Binary Operator with const
Expand Down