 #include "circt/Dialect/HW/HWOps.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/Matchers.h"
@@ -60,6 +61,14 @@ static bool matchConstantOp(Operation *op, APInt &value) {
   return mlir::detail::constant_int_op_binder(&value).match(op);
 }

+/// Returns true if there exists only a single memref::LoadOp which loads from
+/// the memory referenced by loadOp.
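+/// In that case, the load can be lowered combinationally, reading directly
+/// from the memory's readData port (see the memref::LoadOp lowering below).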
+static bool singleLoadFromMemory(memref::LoadOp loadOp) {
+  return llvm::count_if(loadOp.memref().getUses(), [](auto &use) {
+           return isa<memref::LoadOp>(use.getOwner());
+         }) <= 1;
+}
+
 /// Creates a DictionaryAttr containing a unit attribute 'name'. Used for
 /// defining mandatory port attributes for calyx::ComponentOp's.
 static DictionaryAttr getMandatoryPortAttr(MLIRContext *ctx, StringRef name) {
@@ -304,6 +313,23 @@ class ComponentLoweringState {
     return it->second;
   }

+  /// Registers a calyx::MemoryOp as being associated with a memory identified
+  /// by 'memref'.
+  void registerMemory(Value memref, calyx::MemoryOp memoryOp) {
+    assert(memref.getType().isa<MemRefType>());
+    assert(memories.find(memref) == memories.end() &&
+           "Memory already registered for memref");
+    memories[memref] = memoryOp;
+  }
+
+  /// Returns the calyx::MemoryOp registered for the given memref.
+  calyx::MemoryOp getMemory(Value memref) {
+    assert(memref.getType().isa<MemRefType>());
+    auto it = memories.find(memref);
+    assert(it != memories.end() && "No memory registered for memref");
+    return it->second;
+  }
+
 private:
   /// A reference to the parent program lowering state.
   ProgramLoweringState &programLoweringState;
@@ -348,6 +374,9 @@ class ComponentLoweringState {

   /// A mapping from while ops to iteration argument registers.
   DenseMap<Operation *, DenseMap<unsigned, calyx::RegisterOp>> whileIterRegs;
+
+  /// A mapping from memrefs to their corresponding Calyx memory op.
+  DenseMap<Value, calyx::MemoryOp> memories;
 };

 /// ProgramLoweringState handles the current state of lowering of a Calyx
@@ -565,10 +594,13 @@ class BuildOpGroups : public FuncOpPartialLoweringPattern {
         .template Case<ConstantOp, ReturnOp, BranchOpInterface,
                        /// SCF
                        scf::YieldOp,
+                       /// memref
+                       memref::AllocOp, memref::LoadOp, memref::StoreOp,
                        /// standard arithmetic
                        AddIOp, SubIOp, CmpIOp, ShiftLeftOp,
                        UnsignedShiftRightOp, SignedShiftRightOp, AndOp,
-                       XOrOp, OrOp, ZeroExtendIOp, TruncateIOp>(
+                       XOrOp, OrOp, ZeroExtendIOp, TruncateIOp,
+                       IndexCastOp>(
             [&](auto op) { return buildOp(rewriter, op).succeeded(); })
         .template Case<scf::WhileOp, mlir::FuncOp, scf::ConditionOp>(
             [&](auto) {
@@ -606,6 +638,10 @@ class BuildOpGroups : public FuncOpPartialLoweringPattern {
   LogicalResult buildOp(PatternRewriter &rewriter, TruncateIOp op) const;
   LogicalResult buildOp(PatternRewriter &rewriter, ZeroExtendIOp op) const;
   LogicalResult buildOp(PatternRewriter &rewriter, ReturnOp op) const;
+  LogicalResult buildOp(PatternRewriter &rewriter, IndexCastOp op) const;
+  LogicalResult buildOp(PatternRewriter &rewriter, memref::AllocOp op) const;
+  LogicalResult buildOp(PatternRewriter &rewriter, memref::LoadOp op) const;
+  LogicalResult buildOp(PatternRewriter &rewriter, memref::StoreOp op) const;

   /// buildLibraryOp will build a TCalyxLibOp inside a TGroupOp based on the
   /// source operation TSrcOp.
@@ -666,8 +702,107 @@ class BuildOpGroups : public FuncOpPartialLoweringPattern {
     return createGroup<TGroupOp>(rewriter, getComponentState().getComponentOp(),
                                  block->front().getLoc(), groupName);
   }
+
+  /// Creates assignments within the provided group to the address ports of the
+  /// memoryOp based on the provided addressValues.
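+  /// One assignment is created per memory dimension; e.g. for a 2-D memory,
+  /// both of its address ports are assigned from the two index operands.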
+  void assignAddressPorts(PatternRewriter &rewriter, Location loc,
+                          calyx::GroupInterface group, calyx::MemoryOp memoryOp,
+                          Operation::operand_range addressValues) const {
+    IRRewriter::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToEnd(group.getBody());
+    auto addrPorts = memoryOp.addrPorts();
+    assert(addrPorts.size() == addressValues.size() &&
+           "Mismatch between the number of address ports of the provided "
+           "memory and the number of address assignment values");
+    for (auto &idx : enumerate(addressValues))
+      rewriter.create<calyx::AssignOp>(loc, addrPorts[idx.index()], idx.value(),
+                                       Value());
+  }
 };

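+/// Lowers a memref::LoadOp. A single load from a memory is lowered
+/// combinationally, while multiple loads from the same memory are buffered
+/// through registers to avoid a structural hazard on the memory's read port.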
+LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
+                                     memref::LoadOp loadOp) const {
+  auto memoryOp = getComponentState().getMemory(loadOp.memref());
+  if (singleLoadFromMemory(loadOp)) {
+    /// Single load from memory; this is the combinational case, so we do not
+    /// have to consider placing a register in front of the memory.
+    auto combGroup = createGroupForOp<calyx::CombGroupOp>(rewriter, loadOp);
+    assignAddressPorts(rewriter, loadOp.getLoc(), combGroup, memoryOp,
+                       loadOp.getIndices());
+
+    /// We refrain from replacing the loadOp result with memoryOp.readData,
+    /// since multiple loadOps may need to be converted to a single memory's
+    /// readData port. If this replacement were done now, we would lose the
+    /// link between which SSA memref::LoadOp values map to which groups for
+    /// loading a value from the Calyx memory. At this point of lowering we
+    /// keep the memref::LoadOp SSA value and do value replacement _after_
+    /// control has been generated (see LateSSAReplacement). This is *vital*
+    /// for things such as InlineCombGroups to be able to properly track which
+    /// memory assignment groups belong to which accesses.
+    getComponentState().registerEvaluatingGroup(loadOp.getResult(), combGroup);
+  } else {
+    auto group = createGroupForOp<calyx::GroupOp>(rewriter, loadOp);
+    assignAddressPorts(rewriter, loadOp.getLoc(), group, memoryOp,
+                       loadOp.getIndices());
+
+    /// Multiple loads from the same memory; in this case, we _may_ have a
+    /// structural hazard in the design we generate. To get around this, we
+    /// conservatively place a register in front of each load operation and
+    /// replace all uses of the loaded value with the register output. Proper
+    /// handling of this requires the combinational group inliner/scheduler to
+    /// be aware of when a combinational expression references multiple loaded
+    /// values from the same memory, and to schedule assignments to temporary
+    /// registers to get around the structural hazard.
+    auto reg = createReg(getComponentState(), rewriter, loadOp.getLoc(),
+                         getComponentState().getUniqueName("load"),
+                         loadOp.getMemRefType().getElementTypeBitWidth());
+    buildAssignmentsForRegisterWrite(getComponentState(), rewriter, group, reg,
+                                     memoryOp.readData());
+    loadOp.getResult().replaceAllUsesWith(reg.out());
+    getComponentState().addBlockScheduleable(loadOp->getBlock(), group);
+  }
+  return success();
+}
+
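+/// Lowers a memref::StoreOp to a sequential group which assigns the memory's
+/// address ports, drives its writeData and writeEn inputs, and completes when
+/// the memory asserts done.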
+LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
+                                     memref::StoreOp storeOp) const {
+  auto memoryOp = getComponentState().getMemory(storeOp.memref());
+  auto group = createGroupForOp<calyx::GroupOp>(rewriter, storeOp);
+
+  /// This is a sequential group, so register it as being scheduleable for the
+  /// block.
+  getComponentState().addBlockScheduleable(storeOp->getBlock(),
+                                           cast<calyx::GroupOp>(group));
+  assignAddressPorts(rewriter, storeOp.getLoc(), group, memoryOp,
+                     storeOp.getIndices());
+  rewriter.setInsertionPointToEnd(group.getBody());
+  rewriter.create<calyx::AssignOp>(storeOp.getLoc(), memoryOp.writeData(),
+                                   storeOp.getValueToStore(), Value());
+  rewriter.create<calyx::AssignOp>(
+      storeOp.getLoc(), memoryOp.writeEn(),
+      getComponentState().getConstant(rewriter, storeOp.getLoc(), 1, 1),
+      Value());
+  rewriter.create<calyx::GroupDoneOp>(storeOp.getLoc(), memoryOp.done(),
+                                      Value());
+  return success();
+}
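+
+/// Lowers a memref::AllocOp to a calyx::MemoryOp instantiated at the top of
+/// the component body.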
+LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
+                                     memref::AllocOp allocOp) const {
+  rewriter.setInsertionPointToStart(getComponent()->getBody());
+  MemRefType memtype = allocOp.getType();
+  SmallVector<int64_t> addrSizes;
+  SmallVector<int64_t> sizes;
+  for (int64_t dim : memtype.getShape()) {
+    sizes.push_back(dim);
+    addrSizes.push_back(llvm::Log2_64_Ceil(dim));
+  }
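+  /// E.g. a memref<8x16xi32> yields sizes {8, 16} and address port widths
+  /// {3, 4}: each dimension gets a ceil(log2(dim))-bit address port.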
+  auto memoryOp = rewriter.create<calyx::MemoryOp>(
+      allocOp.getLoc(), getComponentState().getUniqueName("mem"),
+      memtype.getElementType().getIntOrFloatBitWidth(), sizes, addrSizes);
+  getComponentState().registerMemory(allocOp.getResult(), memoryOp);
+  return success();
+}
+
 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
                                      scf::YieldOp yieldOp) const {
   if (yieldOp.getOperands().size() == 0)
@@ -816,6 +951,120 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
       rewriter, op, {op.getOperand().getType()}, {op.getType()});
 }

+LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
+                                     IndexCastOp op) const {
+  Type sourceType = op.getOperand().getType();
+  sourceType = sourceType.isIndex() ? rewriter.getI32Type() : sourceType;
+  Type targetType = op.getResult().getType();
+  targetType = targetType.isIndex() ? rewriter.getI32Type() : targetType;
+  unsigned targetBits = targetType.getIntOrFloatBitWidth();
+  unsigned sourceBits = sourceType.getIntOrFloatBitWidth();
+  LogicalResult res = success();
+
+  if (targetBits == sourceBits) {
+    /// Drop the index cast and replace uses of the target value with the
+    /// source value.
+    op.getResult().replaceAllUsesWith(op.getOperand());
+  } else {
+    /// Pad or slice the source operand.
+    if (sourceBits > targetBits)
+      res = buildLibraryOp<calyx::CombGroupOp, calyx::SliceLibOp>(
+          rewriter, op, {sourceType}, {targetType});
+    else
+      res = buildLibraryOp<calyx::CombGroupOp, calyx::PadLibOp>(
+          rewriter, op, {sourceType}, {targetType});
+  }
+  rewriter.eraseOp(op);
+  return res;
+}
+
+/// This pattern rewrites memory accesses that have a width mismatch. Such
+/// mismatches arise because index types are assumed to be 32 bits wide, due
+/// to the lack of a width inference pass.
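+/// E.g. a 32-bit address value assigned to the 3-bit address port of an
+/// 8-entry memory is routed through a slice library op.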
+class RewriteMemoryAccesses : public PartialLoweringPattern<calyx::AssignOp> {
+public:
+  RewriteMemoryAccesses(MLIRContext *context, LogicalResult &resRef,
+                        ProgramLoweringState &pls)
+      : PartialLoweringPattern(context, resRef), pls(pls) {}
+
+  LogicalResult partiallyLower(calyx::AssignOp assignOp,
+                               PatternRewriter &rewriter) const override {
+    auto dest = assignOp.dest();
+    auto destDefOp = dest.getDefiningOp();
+    /// Is this an assignment to a memory op?
+    if (!destDefOp)
+      return success();
+    auto destDefMem = dyn_cast<calyx::MemoryOp>(destDefOp);
+    if (!destDefMem)
+      return success();
+
+    /// Is this an assignment to an address port of the memory op?
+    bool isAssignToAddrPort = llvm::any_of(
+        destDefMem.addrPorts(), [&](auto port) { return port == dest; });
+
+    auto src = assignOp.src();
+    auto &state =
+        pls.compLoweringState(assignOp->getParentOfType<calyx::ComponentOp>());
+
+    unsigned srcBits = src.getType().getIntOrFloatBitWidth();
+    unsigned dstBits = dest.getType().getIntOrFloatBitWidth();
+    if (srcBits == dstBits)
+      return success();
+
+    if (isAssignToAddrPort) {
+      SmallVector<Type> types = {rewriter.getIntegerType(srcBits),
+                                 rewriter.getIntegerType(dstBits)};
+      auto sliceOp = state.getNewLibraryOpInstance<calyx::SliceLibOp>(
+          rewriter, assignOp.getLoc(), types);
+      rewriter.setInsertionPoint(assignOp->getBlock(),
+                                 assignOp->getBlock()->begin());
+      rewriter.create<calyx::AssignOp>(assignOp->getLoc(),
+                                       sliceOp.getResult(0), src, Value());
+      assignOp.setOperand(1, sliceOp.getResult(1));
+    } else
+      return assignOp.emitError()
+             << "Will only infer slice operators for assign width mismatches "
+                "to memory address ports.";
+
+    return success();
+  }
+
+private:
+  ProgramLoweringState &pls;
+};
+
+/// Converts all index-typed operations and values to i32 values.
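+/// The chosen width matches the i32 assumption made when lowering
+/// IndexCastOp; resulting width mismatches at memory address ports are later
+/// repaired by RewriteMemoryAccesses.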
+class ConvertIndexTypes : public FuncOpPartialLoweringPattern {
+  using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
+
+  LogicalResult
+  PartiallyLowerFuncToComp(mlir::FuncOp funcOp,
+                           PatternRewriter &rewriter) const override {
+    funcOp.walk([&](Block *block) {
+      for (auto arg : block->getArguments())
+        if (arg.getType().isIndex())
+          arg.setType(rewriter.getI32Type());
+    });
+
+    funcOp.walk([&](Operation *op) {
+      for (auto res : op->getResults()) {
+        if (!res.getType().isIndex())
+          continue;
+
+        res.setType(rewriter.getI32Type());
+        if (auto constOp = dyn_cast<ConstantOp>(op)) {
+          APInt value;
+          matchConstantOp(constOp, value);
+          rewriter.setInsertionPoint(constOp);
+          rewriter.replaceOpWithNewOp<ConstantOp>(
+              constOp, rewriter.getI32IntegerAttr(value.getSExtValue()));
+        }
+      }
+    });
+    return success();
+  }
+};
+
 /// Inlines Calyx ExecuteRegionOp operations within their parent blocks.
 /// An execution region op (ERO) is inlined by:
 /// i  : add a sink basic block for all yield operations inside the
@@ -1330,6 +1579,18 @@ class LateSSAReplacement : public FuncOpPartialLoweringPattern {
       for (auto res : getComponentState().getWhileIterRegs(whileOp))
         whileOp.getResults()[res.first].replaceAllUsesWith(res.second.out());
     });
+
+    funcOp.walk([&](memref::LoadOp loadOp) {
+      if (singleLoadFromMemory(loadOp)) {
+        /// In buildOpGroups we did not replace loadOp's results, to ensure a
+        /// link between evaluating groups (which fix the input addresses of a
+        /// memory op) and a readData result. Now we may replace these SSA
+        /// values with their memoryOp's readData output.
+        loadOp.getResult().replaceAllUsesWith(
+            getComponentState().getMemory(loadOp.memref()).readData());
+      }
+    });
+
     return success();
   }
 };
@@ -1719,6 +1980,10 @@ void SCFToCalyxPass::runOnOperation() {
   /// This pass inlines scf.ExecuteRegionOp's by adding control-flow.
   addGreedyPattern<InlineExecuteRegionOpPattern>(loweringPatterns);

+  /// This pattern converts all index types to a predefined width (currently
+  /// i32).
+  addOncePattern<ConvertIndexTypes>(loweringPatterns, funcMap, *loweringState);
+
   /// This pattern creates registers for all basic-block arguments.
   addOncePattern<BuildBBRegs>(loweringPatterns, funcMap, *loweringState);

@@ -1752,6 +2017,10 @@ void SCFToCalyxPass::runOnOperation() {
   /// after control generation.
   addOncePattern<LateSSAReplacement>(loweringPatterns, funcMap, *loweringState);

+  /// This pattern rewrites accesses to memories which are too wide due to
+  /// index types being converted to a fixed-width integer type.
+  addOncePattern<RewriteMemoryAccesses>(loweringPatterns, *loweringState);
+
   /// This pattern removes the source FuncOp which has now been converted into
   /// a Calyx component.
   addOncePattern<CleanupFuncOps>(loweringPatterns, funcMap, *loweringState);