Skip to content

Commit

Permalink
add stream_flatten transformation
Browse files Browse the repository at this point in the history
  • Loading branch information
cwpearson committed Mar 26, 2021
1 parent 74b6d12 commit 7d486b0
Showing 1 changed file with 54 additions and 4 deletions.
58 changes: 54 additions & 4 deletions src/internal/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -410,13 +410,13 @@ bool stream_dense_fold(Type &type) {
/* if nested streams, if the child count is 1, the parent's
elements are just the child's element
*/
bool stream_fold_child_count_one(Type &type) {
bool stream_elision(Type &type) {

bool changed = false;

// try to fold all children into their parents first
for (Type &child : type.children()) {
changed |= stream_fold_child_count_one(child);
changed |= stream_elision(child);
}

// type and child must be StreamData
Expand All @@ -442,6 +442,53 @@ bool stream_fold_child_count_one(Type &type) {
return changed;
}

/* detect, for example, where two vectors of two blocks is just one vector of
four blocks
when stride of parent is count * stride of child
parent/child replaced with something like child, except count is now
child.count * parent.count
this probably only comes in vectors of subarrays since subarrays have
padding on the end
*/
bool stream_flatten(Type &type) {

bool changed = false;

// try to fold all children into their parents first
for (Type &child : type.children()) {
changed |= stream_flatten(child);
}

// type and child must be StreamData
if (!std::holds_alternative<StreamData>(type.data)) {
return false;
}
assert(1 == type.children().size());
Type &child = type.children()[0];
if (!std::holds_alternative<StreamData>(child.data)) {
return false;
}

StreamData &pData = std::get<StreamData>(type.data);
const StreamData &cData = std::get<StreamData>(child.data);

if (pData.stride == cData.count * cData.stride) {
changed = true;

// transform parent in to child, with count multiplied by parents count
pData.count *= cData.count;
pData.stride = cData.stride;
pData.off += cData.off;
std::vector<Type> gchildren = child.children();
type.children() = gchildren;
}

return changed;
}

/* tries to convert as much of the type to subarrays as possible
*/
Type simplify(const Type &type) {
Expand All @@ -461,8 +508,11 @@ Type simplify(const Type &type) {
changed |= stream_dense_fold(simp);
LOG_SPEW("after stream_dense_fold");
LOG_SPEW("\n" + simp.str());
changed |= stream_fold_child_count_one(simp);
LOG_SPEW("after stream_fold_child_count_one");
changed |= stream_flatten(simp);
LOG_SPEW("after stream_flatten");
LOG_SPEW("\n" + simp.str());
changed |= stream_elision(simp);
LOG_SPEW("after stream_elision");
LOG_SPEW("\n" + simp.str());
}

Expand Down

0 comments on commit 7d486b0

Please sign in to comment.