diff --git a/include/circt/Dialect/Synth/Transforms/SynthPasses.td b/include/circt/Dialect/Synth/Transforms/SynthPasses.td index ecb593fa17eb..07c3d671f5f5 100644 --- a/include/circt/Dialect/Synth/Transforms/SynthPasses.td +++ b/include/circt/Dialect/Synth/Transforms/SynthPasses.td @@ -89,7 +89,9 @@ def LowerVariadic : Pass<"synth-lower-variadic", "hw::HWModuleOp"> { ListOption<"opNames", "op-names", "std::string", "Specify operation names to lower (empty means all)">, Option<"timingAware", "timing-aware", "bool", "true", - "Lower operators with timing information"> + "Lower operators with timing information">, + Option<"reuseSubsets", "reuse-subsets", "bool", /*default=*/"false", + "Reuse existing logic subsets to minimize area"> ]; let dependentDialects = [ "circt::comb::CombDialect", "circt::hw::HWDialect", diff --git a/lib/Dialect/Synth/Transforms/LowerVariadic.cpp b/lib/Dialect/Synth/Transforms/LowerVariadic.cpp index 26c94d6ec461..bda940cb5379 100644 --- a/lib/Dialect/Synth/Transforms/LowerVariadic.cpp +++ b/lib/Dialect/Synth/Transforms/LowerVariadic.cpp @@ -19,7 +19,17 @@ #include "circt/Dialect/Synth/SynthOps.h" #include "circt/Dialect/Synth/Transforms/SynthPasses.h" #include "mlir/Analysis/TopologicalSortUtils.h" +#include "mlir/IR/Block.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/raw_ostream.h" +#include +#include #define DEBUG_TYPE "synth-lower-variadic" @@ -93,6 +103,109 @@ static LogicalResult replaceWithBalancedTree( return success(); } +using OperandKey = std::vector>; + +namespace llvm { +template <> +struct DenseMapInfo { + static OperandKey getEmptyKey() { + // Return a vector containing the mlir::Value empty key + return {{DenseMapInfo::getEmptyKey(), false}}; + } + + static OperandKey getTombstoneKey() { + // Return a vector containing the mlir::Value tombstone key + return {{DenseMapInfo::getTombstoneKey(), false}}; + } + + static unsigned getHashValue(const OperandKey &val) { + llvm::hash_code hash = 0; + // Iteratively combine the hash of each pair in the vector + for (const auto &pair : val) { + hash = llvm::hash_combine( + hash, DenseMapInfo::getHashValue(pair.first), + pair.second); + } + return static_cast(hash); + } + + static bool isEqual(const OperandKey &lhs, const OperandKey &rhs) { + // std::vector and std::pair already implement operator==, + // which does a deep equality check of the elements. + return lhs == rhs; + } +}; +} // namespace llvm + +// Struct for ordering the andInverterOp operations we have already seen +struct OperandPairLess { + bool operator()(const std::pair &lhs, + const std::pair &rhs) const { + if (lhs.first != rhs.first) { + auto lhsArg = llvm::dyn_cast(lhs.first); + auto rhsArg = llvm::dyn_cast(rhs.first); + if (lhsArg && rhsArg) + return lhsArg.getArgNumber() < rhsArg.getArgNumber(); + if (lhsArg) + return true; + if (rhsArg) + return false; + + auto *lhsOp = lhs.first.getDefiningOp(); + auto *rhsOp = rhs.first.getDefiningOp(); + return lhsOp->isBeforeInBlock(rhsOp); + } + return lhs.second < rhs.second; + } +}; + +static OperandKey getSortedOperandKey(aig::AndInverterOp op) { + OperandKey key; + for (size_t i = 0, e = op.getNumOperands(); i < e; ++i) { + key.emplace_back(op.getOperand(i), op.isInverted(i)); + } + std::sort(key.begin(), key.end(), OperandPairLess()); + return key; +} + +static void simplifyWithExistingOperations( + aig::AndInverterOp op, mlir::IRRewriter &rewriter, + llvm::DenseMap &seenExpressions) { + + if (op.getNumOperands() <= 2) + return; + + OperandKey allOperands = getSortedOperandKey(op); + mlir::SmallVector newValues; + mlir::SmallVector newInversions; + + for (auto it = allOperands.begin(); it != allOperands.end(); ++it) { + // Look at the remaining operands from 'it' to the end + OperandKey remaining(it, allOperands.end()); + + auto match = seenExpressions.find(remaining); + if (match != seenExpressions.end() && match->second != op.getResult()) { + newValues.push_back(match->second); + newInversions.push_back(false); + + // We found a match that covers everything from 'it' to the end, + // so we can stop searching. + break; + } + + // No match, add it to the new list of values and inversions. + newValues.push_back(it->first); + newInversions.push_back(it->second); + } + + if (newValues.size() < allOperands.size()) { + rewriter.modifyOpInPlace(op, [&]() { + op.getOperation()->setOperands(newValues); + op.setInverted(newInversions); + }); + } +} + void LowerVariadicPass::runOnOperation() { // Topologically sort operations in graph regions to ensure operands are // defined before uses. @@ -131,6 +244,26 @@ void LowerVariadicPass::runOnOperation() { mlir::IRRewriter rewriter(&getContext()); rewriter.setListener(analysis); + // To be used in simplifyWithExistingOperations. + llvm::DenseMap seenExpressions; + // Simplify exising andInverterOps by reusing operations. + if (reuseSubsets) { + // First collect all the andInverterOp operations in the block. + for (auto &op : moduleOp.getBodyBlock()->getOperations()) { + if (auto andInverterOp = llvm::dyn_cast(op)) { + OperandKey key = getSortedOperandKey(andInverterOp); + seenExpressions[key] = andInverterOp.getResult(); + } + } + // Now try to replace operations with subsets. + for (auto &op : moduleOp.getBodyBlock()->getOperations()) { + if (auto andInverterOp = llvm::dyn_cast(op)) { + simplifyWithExistingOperations(andInverterOp, rewriter, + seenExpressions); + } + } + } + // FIXME: Currently only top-level operations are lowered due to the lack of // topological sorting in across nested regions. for (auto &opRef : @@ -158,7 +291,6 @@ void LowerVariadicPass::runOnOperation() { }); if (failed(result)) return signalPassFailure(); - continue; } // Handle commutative operations (and, or, xor, mul, add, etc.) using diff --git a/lib/Dialect/Synth/Transforms/SynthesisPipeline.cpp b/lib/Dialect/Synth/Transforms/SynthesisPipeline.cpp index 49c023885ac7..7a11626132f5 100644 --- a/lib/Dialect/Synth/Transforms/SynthesisPipeline.cpp +++ b/lib/Dialect/Synth/Transforms/SynthesisPipeline.cpp @@ -40,10 +40,12 @@ static void addOpName(SmallVectorImpl &ops) { (ops.push_back(AllowedOpTy::getOperationName().str()), ...); } template -static std::unique_ptr createLowerVariadicPass(bool timingAware) { +static std::unique_ptr +createLowerVariadicPass(bool timingAware, bool reuseSubsets = false) { LowerVariadicOptions options; addOpName(options.opNames); options.timingAware = timingAware; + options.reuseSubsets = reuseSubsets; return createLowerVariadic(options); } void circt::synth::buildCombLoweringPipeline( @@ -81,7 +83,9 @@ void circt::synth::buildCombLoweringPipeline( if (options.targetIR.getValue() == TargetIR::AIG) { // For AIG, lower variadic XoR since AIG cannot keep variadic // representation. - pm.addPass(createLowerVariadicPass(options.timingAware)); + pm.addPass(createLowerVariadicPass( + options.timingAware, + options.synthesisStrategy == OptimizationStrategyArea)); } else if (options.targetIR.getValue() == TargetIR::MIG) { // For MIG, lower variadic And, Or, and Xor since MIG cannot keep variadic // representation. diff --git a/test/Dialect/Synth/lower-variadic.mlir b/test/Dialect/Synth/lower-variadic.mlir index b980e155c824..495e76d8ba6e 100644 --- a/test/Dialect/Synth/lower-variadic.mlir +++ b/test/Dialect/Synth/lower-variadic.mlir @@ -1,5 +1,6 @@ // RUN: circt-opt %s --synth-lower-variadic --split-input-file | FileCheck %s --check-prefixes=COMMON,TIMING // RUN: circt-opt %s --synth-lower-variadic=timing-aware=false --split-input-file | FileCheck %s --check-prefixes=COMMON,NO-TIMING +// RUN: circt-opt %s --synth-lower-variadic=reuse-subsets=true | FileCheck %s // COMMON-LABEL: hw.module @Basic hw.module @Basic(in %a: i2, in %b: i2, in %c: i2, in %d: i2, in %e: i2, out f: i2) { // COMMON-NEXT: %[[RES0:.+]] = synth.aig.and_inv not %a, %b : i2 @@ -71,3 +72,21 @@ hw.module @Issue9115(in %a : i16, in %b : i16, in %c : i16, in %d : i16, out pro // COMMON-NEXT: comb.mul %c, %[[TMP]] : i16 hw.output %0 : i16 } + +// COMMON-LABEL: hw.module @SharingHeuristic +hw.module @SharingHeuristic(in %in0 : i1, in %in1 : i1, in %in2 : i1, in %in3 : i1, in %in4 : i1, out out1 : i1, out out2 : i1) { + + // These represent the subset tree (out2) + // CHECK: %[[N0:.+]] = synth.aig.and_inv %in1, %in2 + // CHECK: %[[N1:.+]] = synth.aig.and_inv %in3, %in4 + // CHECK: %[[SUBSET_RES:.+]] = synth.aig.and_inv %[[N0]], %[[N1]] + %out2 = synth.aig.and_inv %in1, %in2, %in3, %in4 : i1 + + // out1 should now just use the SUBSET_RES directly + // CHECK: %[[OUT1_ROOT:.+]] = synth.aig.and_inv %in0, %[[SUBSET_RES]] + %out1 = synth.aig.and_inv %in0, %in1, %in2, %in3, %in4 : i1 + + // CHECK: hw.output %[[OUT1_ROOT]], %[[SUBSET_RES]] + hw.output %out1, %out2 : i1, i1 +} +