diff --git a/src/internal/types.cpp b/src/internal/types.cpp index cb53999..131b3a4 100644 --- a/src/internal/types.cpp +++ b/src/internal/types.cpp @@ -410,13 +410,13 @@ bool stream_dense_fold(Type &type) { /* if nested streams, if the child count is 1, the parent's elements are just the child's element */ -bool stream_fold_child_count_one(Type &type) { +bool stream_elision(Type &type) { bool changed = false; // try to fold all children into their parents first for (Type &child : type.children()) { - changed |= stream_fold_child_count_one(child); + changed |= stream_elision(child); } // type and child must be StreamData @@ -442,6 +442,53 @@ bool stream_fold_child_count_one(Type &type) { return changed; } +/* detect, for example, where two vectors of two blocks is just one vector of + four blocks + + when stride of parent is count * stride of child + + parent/child replaced with something like child, except count is now + child.count * parent.count + + this probably only comes in vectors of subarrays since subarrays have + padding on the end +*/ +bool stream_flatten(Type &type) { + + bool changed = false; + + // try to fold all children into their parents first + for (Type &child : type.children()) { + changed |= stream_flatten(child); + } + + // type and child must be StreamData + if (!std::holds_alternative(type.data)) { + return false; + } + assert(1 == type.children().size()); + Type &child = type.children()[0]; + if (!std::holds_alternative(child.data)) { + return false; + } + + StreamData &pData = std::get(type.data); + const StreamData &cData = std::get(child.data); + + if (pData.stride == cData.count * cData.stride) { + changed = true; + + // transform parent in to child, with count multiplied by parents count + pData.count *= cData.count; + pData.stride = cData.stride; + pData.off += cData.off; + std::vector gchildren = child.children(); + type.children() = gchildren; + } + + return changed; +} + /* tries to convert as much of the type to subarrays as possible */ Type simplify(const Type &type) { @@ -461,8 +508,11 @@ Type simplify(const Type &type) { changed |= stream_dense_fold(simp); LOG_SPEW("after stream_dense_fold"); LOG_SPEW("\n" + simp.str()); - changed |= stream_fold_child_count_one(simp); - LOG_SPEW("after stream_fold_child_count_one"); + changed |= stream_flatten(simp); + LOG_SPEW("after stream_flatten"); + LOG_SPEW("\n" + simp.str()); + changed |= stream_elision(simp); + LOG_SPEW("after stream_elision"); LOG_SPEW("\n" + simp.str()); }