Skip to content

Commit 7d486b0

Browse files
committed
add stream_flatten transformation
1 parent 74b6d12 commit 7d486b0

File tree

1 file changed

+54
-4
lines changed

1 file changed

+54
-4
lines changed

src/internal/types.cpp

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -410,13 +410,13 @@ bool stream_dense_fold(Type &type) {
410410
/* if nested streams, if the child count is 1, the parent's
411411
elements are just the child's element
412412
*/
413-
bool stream_fold_child_count_one(Type &type) {
413+
bool stream_elision(Type &type) {
414414

415415
bool changed = false;
416416

417417
// try to fold all children into their parents first
418418
for (Type &child : type.children()) {
419-
changed |= stream_fold_child_count_one(child);
419+
changed |= stream_elision(child);
420420
}
421421

422422
// type and child must be StreamData
@@ -442,6 +442,53 @@ bool stream_fold_child_count_one(Type &type) {
442442
return changed;
443443
}
444444

445+
/* detect, for example, where two vectors of two blocks is just one vector of
446+
four blocks
447+
448+
when stride of parent is count * stride of child
449+
450+
parent/child replaced with something like child, except count is now
451+
child.count * parent.count
452+
453+
this probably only comes in vectors of subarrays since subarrays have
454+
padding on the end
455+
*/
456+
bool stream_flatten(Type &type) {
457+
458+
bool changed = false;
459+
460+
// try to fold all children into their parents first
461+
for (Type &child : type.children()) {
462+
changed |= stream_flatten(child);
463+
}
464+
465+
// type and child must be StreamData
466+
if (!std::holds_alternative<StreamData>(type.data)) {
467+
return false;
468+
}
469+
assert(1 == type.children().size());
470+
Type &child = type.children()[0];
471+
if (!std::holds_alternative<StreamData>(child.data)) {
472+
return false;
473+
}
474+
475+
StreamData &pData = std::get<StreamData>(type.data);
476+
const StreamData &cData = std::get<StreamData>(child.data);
477+
478+
if (pData.stride == cData.count * cData.stride) {
479+
changed = true;
480+
481+
// transform parent in to child, with count multiplied by parents count
482+
pData.count *= cData.count;
483+
pData.stride = cData.stride;
484+
pData.off += cData.off;
485+
std::vector<Type> gchildren = child.children();
486+
type.children() = gchildren;
487+
}
488+
489+
return changed;
490+
}
491+
445492
/* tries to convert as much of the type to subarrays as possible
446493
*/
447494
Type simplify(const Type &type) {
@@ -461,8 +508,11 @@ Type simplify(const Type &type) {
461508
changed |= stream_dense_fold(simp);
462509
LOG_SPEW("after stream_dense_fold");
463510
LOG_SPEW("\n" + simp.str());
464-
changed |= stream_fold_child_count_one(simp);
465-
LOG_SPEW("after stream_fold_child_count_one");
511+
changed |= stream_flatten(simp);
512+
LOG_SPEW("after stream_flatten");
513+
LOG_SPEW("\n" + simp.str());
514+
changed |= stream_elision(simp);
515+
LOG_SPEW("after stream_elision");
466516
LOG_SPEW("\n" + simp.str());
467517
}
468518

0 commit comments

Comments
 (0)