Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions lib/Dialect/FIRRTL/FIRRTLFolds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1692,6 +1692,69 @@ LogicalResult MultibitMuxOp::canonicalize(MultibitMuxOp op,
}
}

// Eliminate unneeded bits in the index. These arise from duplicate values in
// the mux. This is done by slicing the mux into a mux tree.
// multibit_mux(index, {a, b, a, c}) -> multibit_mux(index[0],
// multibit_mux(index[1], a,a), multibit_mux(index[1]}, {b,c})) Search for
// identities in specific bit slices. This is robust to unknown width
// indexes.
for (uint64_t bit = 0,
lastbit = op.getIndex().getType().getBitWidthOrSentinel();
bit < lastbit; ++bit) {
for (int curval = 0; curval <= 1; ++curval) {
// We don't collect values here as the normal case is we don't find a
// match, so we don't want to move data around and do allocations.
Value v;
uint64_t count = 0;
for (uint64_t i = 0, e = op.getInputs().size(); i < e; ++i) {
if (((i >> bit) & 1) != curval)
continue;
++count;
if (!v)
v = op.getInputs()[e - i - 1];
if (v != op.getInputs()[e - i - 1]) {
v = {};
break;
}
}
if (!v || count == 1)
continue;
// Found match, collect varying side of the future mux
SmallVector<Value> nonSimple;
for (uint64_t i = 0, e = op.getInputs().size(); i < e; ++i) {
if (((i >> bit) & 1) != curval)
nonSimple.push_back(op.getInputs()[e - i - 1]);
}
std::reverse(nonSimple.begin(), nonSimple.end());
Value indBit = rewriter.createOrFold<BitsPrimOp>(op.getLoc(),
op.getIndex(), bit, bit);
Value indBitRemLow;
if (bit)
indBitRemLow = rewriter.createOrFold<BitsPrimOp>(
op.getLoc(), op.getIndex(), bit - 1, 0);
else
indBitRemLow = rewriter.create<ConstantOp>(
op.getLoc(), IntType::get(op.getContext(), false, 0),
APInt(0U, 0UL));
Value indBitRemHigh;
if (bit == lastbit - 1)
indBitRemHigh = rewriter.create<ConstantOp>(
op.getLoc(), IntType::get(op.getContext(), false, 0),
APInt(0U, 0UL));
else
indBitRemHigh = rewriter.createOrFold<BitsPrimOp>(
op.getLoc(), op.getIndex(), lastbit - 1, bit + 1);
Value indBitRem = rewriter.createOrFold<CatPrimOp>(
op.getLoc(), indBitRemHigh, indBitRemLow);
Value otherSide =
rewriter.create<MultibitMuxOp>(op.getLoc(), indBitRem, nonSimple);
Value high = curval ? v : otherSide;
Value low = curval ? otherSide : v;
replaceOpWithNewOpAndCopyName<MuxPrimOp>(rewriter, op, indBit, high, low);
return success();
}
}

// If the size is 2, canonicalize into a normal mux to introduce more folds.
if (op.getInputs().size() != 2)
return failure();
Expand Down
19 changes: 19 additions & 0 deletions test/Dialect/FIRRTL/canonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3309,4 +3309,23 @@ firrtl.module @Whens(in %clock: !firrtl.clock, in %a: !firrtl.uint<1>, in %reset
}
}

// CHECK-LABEL: firrtl.module @UselessIndexBit
firrtl.module @UselessIndexBit(in %a: !firrtl.uint<3>, out %b: !firrtl.uint<4>, in %c: !firrtl.uint<4>, in %d: !firrtl.uint<4>, in %e: !firrtl.uint<4>) attributes {convention = #firrtl<convention scalarized>} {
%c0_ui4 = firrtl.constant 0 : !firrtl.uint<4> {name = "ttable_2"}
%0 = firrtl.multibit_mux %a, %c0_ui4, %c0_ui4, %c0_ui4, %c0_ui4, %c0_ui4, %c0_ui4, %d, %c : !firrtl.uint<3>, !firrtl.uint<4>
firrtl.strictconnect %b, %0 : !firrtl.uint<4>
// CHECK: firrtl.mux({{.*}}, %d, %c)
// CHECK: firrtl.mux({{.*}}, %c0_ui4, {{.*}})
}

// CHECK-LABEL: firrtl.module @UselessIndexBit2
firrtl.module @UselessIndexBit2(in %a: !firrtl.uint<3>, out %b: !firrtl.uint<4>, in %c: !firrtl.uint<4>, in %d: !firrtl.uint<4>, in %e: !firrtl.uint<4>) {
%0 = firrtl.multibit_mux %a, %c, %c, %e, %c, %c, %c : !firrtl.uint<3>, !firrtl.uint<4>
firrtl.strictconnect %b, %0 : !firrtl.uint<4>
// CHECK: %0 = firrtl.bits %a 0 to 0
// CHECK: %1 = firrtl.bits %a 1 to 1 : (!firrtl.uint<3>) -> !firrtl.uint<1>
// CHECK: %2 = firrtl.and %0, %1 : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
// CHECK: firrtl.mux(%2, %e, %c) : (!firrtl.uint<1>, !firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<4>
}

}