Commit 184db22
[SCFToCalyx] Support memref operations [12/12] (#1863)
This commit adds support for memory-accessing operations. All index types are converted to a fixed-width integer, which is necessary due to the lack of a bitwidth inference pass. When an index-typed value is used as a memory address input, the address value is truncated to the width of the memory port.
1 parent ef45313 commit 184db22
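
For illustration, here is a minimal sketch (not part of the commit) of the width bookkeeping described above: index values are lowered to a fixed-width i32, while a memory's address port is only as wide as its dimension requires, so the address must be sliced down to the port width. The Log2_64_Ceil call mirrors the one used in the memref::AllocOp lowering below; the concrete memref shape is an assumed example.

// Minimal sketch, assuming a memref<64xi32>: the lowering derives a 6-bit
// address port (Log2_64_Ceil(64) == 6), while index operands become i32,
// so an i32 address value must be truncated to 6 bits before driving the port.
#include "llvm/Support/MathExtras.h"

#include <cassert>
#include <cstdint>

int main() {
  int64_t dim = 64;                            // assumed memref dimension
  unsigned addrBits = llvm::Log2_64_Ceil(dim); // address port width: 6
  unsigned indexBits = 32;                     // index is lowered to i32
  assert(addrBits == 6 && indexBits > addrBits); // mismatch -> insert a slice
  return 0;
}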

File tree

2 files changed: +529 −1

lib/Conversion/SCFToCalyx/SCFToCalyx.cpp

Lines changed: 270 additions & 1 deletion
@@ -17,6 +17,7 @@
 #include "circt/Dialect/HW/HWOps.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/Matchers.h"
@@ -60,6 +61,14 @@ static bool matchConstantOp(Operation *op, APInt &value) {
   return mlir::detail::constant_int_op_binder(&value).match(op);
 }
 
+/// Returns true if there exists only a single memref::LoadOp which loads from
+/// the memory referenced by loadOp.
+static bool singleLoadFromMemory(memref::LoadOp loadOp) {
+  return llvm::count_if(loadOp.memref().getUses(), [](auto &user) {
+           return dyn_cast<memref::LoadOp>(user.getOwner());
+         }) <= 1;
+}
+
 /// Creates a DictionaryAttr containing a unit attribute 'name'. Used for
 /// defining mandatory port attributes for calyx::ComponentOp's.
 static DictionaryAttr getMandatoryPortAttr(MLIRContext *ctx, StringRef name) {
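
Note that the helper above counts only the users of the memref that are themselves memref::LoadOp's: dyn_cast yields null (i.e. false) for any other user, so stores never count against the criterion. A standalone sketch of the same counting idiom, using a plain STL container as a hypothetical stand-in for the MLIR use-list:

// Hedged sketch: count_if with a predicate that ignores non-"load" users,
// mirroring how dyn_cast<memref::LoadOp> filters the use-list above.
#include <algorithm>
#include <string>
#include <vector>

static bool singleLoad(const std::vector<std::string> &users) {
  return std::count_if(users.begin(), users.end(),
                       [](const std::string &u) { return u == "load"; }) <= 1;
}

int main() {
  // One load plus one store still satisfies the single-load criterion.
  return singleLoad({"load", "store"}) ? 0 : 1;
}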
@@ -304,6 +313,23 @@ class ComponentLoweringState {
     return it->second;
   }
 
+  /// Registers a calyx::MemoryOp as being associated with a memory identified
+  /// by 'memref'.
+  void registerMemory(Value memref, calyx::MemoryOp memoryOp) {
+    assert(memref.getType().isa<MemRefType>());
+    assert(memories.find(memref) == memories.end() &&
+           "Memory already registered for memref");
+    memories[memref] = memoryOp;
+  }
+
+  /// Returns the calyx::MemoryOp registered for the given memref.
+  calyx::MemoryOp getMemory(Value memref) {
+    assert(memref.getType().isa<MemRefType>());
+    auto it = memories.find(memref);
+    assert(it != memories.end() && "No memory registered for memref");
+    return it->second;
+  }
+
 private:
   /// A reference to the parent program lowering state.
   ProgramLoweringState &programLoweringState;
@@ -348,6 +374,9 @@ class ComponentLoweringState {
 
   /// A mapping from while ops to iteration argument registers.
   DenseMap<Operation *, DenseMap<unsigned, calyx::RegisterOp>> whileIterRegs;
+
+  /// A mapping from memrefs to their corresponding calyx memory op.
+  DenseMap<Value, calyx::MemoryOp> memories;
 };
 
 /// ProgramLoweringState handles the current state of lowering of a Calyx
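
The memories map added above follows a register-once/lookup-must-succeed contract, enforced by assertions. A simplified, self-contained sketch of that contract (the int key stands in for the memref Value; all names here are illustrative, not from the pass):

// Hedged sketch of the registration contract, using llvm::DenseMap as above.
#include "llvm/ADT/DenseMap.h"

#include <cassert>

struct MemoryHandle { int id; }; // stand-in for calyx::MemoryOp

static llvm::DenseMap<int, MemoryHandle> memories; // keyed by memref stand-in

void registerMemory(int memref, MemoryHandle mem) {
  assert(memories.find(memref) == memories.end() &&
         "Memory already registered for memref");
  memories[memref] = mem;
}

MemoryHandle getMemory(int memref) {
  auto it = memories.find(memref);
  assert(it != memories.end() && "No memory registered for memref");
  return it->second;
}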
@@ -565,10 +594,13 @@ class BuildOpGroups : public FuncOpPartialLoweringPattern {
         .template Case<ConstantOp, ReturnOp, BranchOpInterface,
                        /// SCF
                        scf::YieldOp,
+                       /// memref
+                       memref::AllocOp, memref::LoadOp, memref::StoreOp,
                        /// standard arithmetic
                        AddIOp, SubIOp, CmpIOp, ShiftLeftOp,
                        UnsignedShiftRightOp, SignedShiftRightOp, AndOp,
-                       XOrOp, OrOp, ZeroExtendIOp, TruncateIOp>(
+                       XOrOp, OrOp, ZeroExtendIOp, TruncateIOp,
+                       IndexCastOp>(
             [&](auto op) { return buildOp(rewriter, op).succeeded(); })
         .template Case<scf::WhileOp, mlir::FuncOp, scf::ConditionOp>(
             [&](auto) {
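
The Case<...> dispatch in the hunk above is llvm::TypeSwitch: each newly supported op type is routed to a matching buildOp overload. A minimal standalone sketch of the same idiom over a toy two-op hierarchy with LLVM-style RTTI (all names are illustrative, not from the pass):

// Hedged sketch of the llvm::TypeSwitch dispatch idiom used above.
#include "llvm/ADT/TypeSwitch.h"

#include <cstdio>

struct Op {
  enum Kind { LoadKind, StoreKind } kind;
  Kind getKind() const { return kind; }
};
struct LoadOp : Op {
  LoadOp() : Op{LoadKind} {}
  static bool classof(const Op *op) { return op->getKind() == LoadKind; }
};
struct StoreOp : Op {
  StoreOp() : Op{StoreKind} {}
  static bool classof(const Op *op) { return op->getKind() == StoreKind; }
};

bool build(Op *op) {
  return llvm::TypeSwitch<Op *, bool>(op)
      .Case<LoadOp, StoreOp>([](auto *) { return true; }) // handled op kinds
      .Default([](Op *) { return false; });               // unsupported op
}

int main() {
  LoadOp load;
  std::printf("%d\n", build(&load)); // prints 1
  return 0;
}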
@@ -606,6 +638,10 @@ class BuildOpGroups : public FuncOpPartialLoweringPattern {
   LogicalResult buildOp(PatternRewriter &rewriter, TruncateIOp op) const;
   LogicalResult buildOp(PatternRewriter &rewriter, ZeroExtendIOp op) const;
   LogicalResult buildOp(PatternRewriter &rewriter, ReturnOp op) const;
+  LogicalResult buildOp(PatternRewriter &rewriter, IndexCastOp op) const;
+  LogicalResult buildOp(PatternRewriter &rewriter, memref::AllocOp op) const;
+  LogicalResult buildOp(PatternRewriter &rewriter, memref::LoadOp op) const;
+  LogicalResult buildOp(PatternRewriter &rewriter, memref::StoreOp op) const;
 
   /// buildLibraryOp will build a TCalyxLibOp inside a TGroupOp based on the
   /// source operation TSrcOp.
@@ -666,8 +702,107 @@ class BuildOpGroups : public FuncOpPartialLoweringPattern {
     return createGroup<TGroupOp>(rewriter, getComponentState().getComponentOp(),
                                  block->front().getLoc(), groupName);
   }
+
+  /// Creates assignments within the provided group to the address ports of the
+  /// memoryOp based on the provided addressValues.
+  void assignAddressPorts(PatternRewriter &rewriter, Location loc,
+                          calyx::GroupInterface group, calyx::MemoryOp memoryOp,
+                          Operation::operand_range addressValues) const {
+    IRRewriter::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToEnd(group.getBody());
+    auto addrPorts = memoryOp.addrPorts();
+    assert(addrPorts.size() == addressValues.size() &&
+           "Mismatch between number of address ports of the provided memory "
+           "and address assignment values");
+    for (auto &idx : enumerate(addressValues))
+      rewriter.create<calyx::AssignOp>(loc, addrPorts[idx.index()], idx.value(),
+                                       Value());
+  }
 };
 
+LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
+                                     memref::LoadOp loadOp) const {
+  auto memoryOp = getComponentState().getMemory(loadOp.memref());
+  if (singleLoadFromMemory(loadOp)) {
+    /// Single load from memory; combinational case - we do not have to
+    /// consider adding registers in front of the memory.
+    auto combGroup = createGroupForOp<calyx::CombGroupOp>(rewriter, loadOp);
+    assignAddressPorts(rewriter, loadOp.getLoc(), combGroup, memoryOp,
+                       loadOp.getIndices());
+
+    /// We refrain from replacing the loadOp result with memoryOp.readData,
+    /// since multiple loadOp's need to be converted to a single memory's
+    /// readData. If this replacement is done now, we lose the link between
+    /// which SSA memref::LoadOp values map to which groups for loading a value
+    /// from the Calyx memory. At this point of lowering, we keep the
+    /// memref::LoadOp SSA value, and do value replacement _after_ control has
+    /// been generated (see LateSSAReplacement). This is *vital* for things such
+    /// as InlineCombGroups to be able to properly track which memory assignment
+    /// groups belong to which accesses.
+    getComponentState().registerEvaluatingGroup(loadOp.getResult(), combGroup);
+  } else {
+    auto group = createGroupForOp<calyx::GroupOp>(rewriter, loadOp);
+    assignAddressPorts(rewriter, loadOp.getLoc(), group, memoryOp,
+                       loadOp.getIndices());
+
+    /// Multiple loads from the same memory; in this case, we _may_ have a
+    /// structural hazard in the design we generate. To get around this, we
+    /// conservatively place a register in front of each load operation, and
+    /// replace all uses of the loaded value with the register output. Proper
+    /// handling of this requires the combinational group inliner/scheduler to
+    /// be aware of when a combinational expression references multiple loaded
+    /// values from the same memory, and then schedule assignments to temporary
+    /// registers to get around the structural hazard.
+    auto reg = createReg(getComponentState(), rewriter, loadOp.getLoc(),
+                         getComponentState().getUniqueName("load"),
+                         loadOp.getMemRefType().getElementTypeBitWidth());
+    buildAssignmentsForRegisterWrite(getComponentState(), rewriter, group, reg,
+                                     memoryOp.readData());
+    loadOp.getResult().replaceAllUsesWith(reg.out());
+    getComponentState().addBlockScheduleable(loadOp->getBlock(), group);
+  }
+  return success();
+}
+
+LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
+                                     memref::StoreOp storeOp) const {
+  auto memoryOp = getComponentState().getMemory(storeOp.memref());
+  auto group = createGroupForOp<calyx::GroupOp>(rewriter, storeOp);
+
+  /// This is a sequential group, so register it as being scheduleable for the
+  /// block.
+  getComponentState().addBlockScheduleable(storeOp->getBlock(),
+                                           cast<calyx::GroupOp>(group));
+  assignAddressPorts(rewriter, storeOp.getLoc(), group, memoryOp,
+                     storeOp.getIndices());
+  rewriter.setInsertionPointToEnd(group.getBody());
+  rewriter.create<calyx::AssignOp>(storeOp.getLoc(), memoryOp.writeData(),
+                                   storeOp.getValueToStore(), Value());
+  rewriter.create<calyx::AssignOp>(
+      storeOp.getLoc(), memoryOp.writeEn(),
+      getComponentState().getConstant(rewriter, storeOp.getLoc(), 1, 1),
+      Value());
+  rewriter.create<calyx::GroupDoneOp>(storeOp.getLoc(), memoryOp.done(),
+                                      Value());
+  return success();
+}
+
+LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
+                                     memref::AllocOp allocOp) const {
+  rewriter.setInsertionPointToStart(getComponent()->getBody());
+  MemRefType memtype = allocOp.getType();
+  SmallVector<int64_t> addrSizes;
+  SmallVector<int64_t> sizes;
+  for (int64_t dim : memtype.getShape()) {
+    sizes.push_back(dim);
+    addrSizes.push_back(llvm::Log2_64_Ceil(dim));
+  }
+  auto memoryOp = rewriter.create<calyx::MemoryOp>(
+      allocOp.getLoc(), getComponentState().getUniqueName("mem"),
+      memtype.getElementType().getIntOrFloatBitWidth(), sizes, addrSizes);
+  getComponentState().registerMemory(allocOp.getResult(), memoryOp);
+  return success();
+}
+
 LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
                                      scf::YieldOp yieldOp) const {
   if (yieldOp.getOperands().size() == 0)
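
assignAddressPorts in the hunk above pairs each address operand with its positional address port via llvm::enumerate. A tiny standalone sketch of that pairing, where the printf stands in for creating a calyx::AssignOp (the index values are an assumed example):

// Hedged sketch: llvm::enumerate yields (index, value) pairs, used above to
// assign the i-th address operand to the i-th memory address port.
#include "llvm/ADT/STLExtras.h"

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> addressValues = {3, 1}; // assumed 2-D access indices
  for (auto en : llvm::enumerate(addressValues))
    std::printf("addrPort[%zu] <- %d\n", en.index(), en.value());
  return 0;
}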
@@ -816,6 +951,120 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
       rewriter, op, {op.getOperand().getType()}, {op.getType()});
 }
 
+LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
+                                     IndexCastOp op) const {
+  Type sourceType = op.getOperand().getType();
+  sourceType = sourceType.isIndex() ? rewriter.getI32Type() : sourceType;
+  Type targetType = op.getResult().getType();
+  targetType = targetType.isIndex() ? rewriter.getI32Type() : targetType;
+  unsigned targetBits = targetType.getIntOrFloatBitWidth();
+  unsigned sourceBits = sourceType.getIntOrFloatBitWidth();
+  LogicalResult res = success();
+
+  if (targetBits == sourceBits) {
+    /// Drop the index cast and replace uses of the target value with the
+    /// source value.
+    op.getResult().replaceAllUsesWith(op.getOperand());
+  } else {
+    /// Pad/slice the source operand.
+    if (sourceBits > targetBits)
+      res = buildLibraryOp<calyx::CombGroupOp, calyx::SliceLibOp>(
+          rewriter, op, {sourceType}, {targetType});
+    else
+      res = buildLibraryOp<calyx::CombGroupOp, calyx::PadLibOp>(
+          rewriter, op, {sourceType}, {targetType});
+  }
+  rewriter.eraseOp(op);
+  return res;
+}
+
+/// This pass rewrites memory accesses that have a width mismatch. Such
+/// mismatches are due to index types being assumed 32-bit wide due to the lack
+/// of a width inference pass.
+class RewriteMemoryAccesses : public PartialLoweringPattern<calyx::AssignOp> {
+public:
+  RewriteMemoryAccesses(MLIRContext *context, LogicalResult &resRef,
+                        ProgramLoweringState &pls)
+      : PartialLoweringPattern(context, resRef), pls(pls) {}
+
+  LogicalResult partiallyLower(calyx::AssignOp assignOp,
+                               PatternRewriter &rewriter) const override {
+    auto dest = assignOp.dest();
+    auto destDefOp = dest.getDefiningOp();
+    /// Is this an assignment to a memory op?
+    if (!destDefOp)
+      return success();
+    auto destDefMem = dyn_cast<calyx::MemoryOp>(destDefOp);
+    if (!destDefMem)
+      return success();
+
+    /// Is this an assignment to an address port of the memory op?
+    bool isAssignToAddrPort = llvm::any_of(
+        destDefMem.addrPorts(), [&](auto port) { return port == dest; });
+
+    auto src = assignOp.src();
+    auto &state =
+        pls.compLoweringState(assignOp->getParentOfType<calyx::ComponentOp>());
+
+    unsigned srcBits = src.getType().getIntOrFloatBitWidth();
+    unsigned dstBits = dest.getType().getIntOrFloatBitWidth();
+    if (srcBits == dstBits)
+      return success();
+
+    if (isAssignToAddrPort) {
+      SmallVector<Type> types = {rewriter.getIntegerType(srcBits),
+                                 rewriter.getIntegerType(dstBits)};
+      auto sliceOp = state.getNewLibraryOpInstance<calyx::SliceLibOp>(
+          rewriter, assignOp.getLoc(), types);
+      rewriter.setInsertionPoint(assignOp->getBlock(),
+                                 assignOp->getBlock()->begin());
+      rewriter.create<calyx::AssignOp>(assignOp->getLoc(), sliceOp.getResult(0),
+                                       src, Value());
+      assignOp.setOperand(1, sliceOp.getResult(1));
+    } else
+      return assignOp.emitError()
+             << "Will only infer slice operators for assign width mismatches "
+                "to memory address ports.";
+
+    return success();
+  }
+
+private:
+  ProgramLoweringState &pls;
+};
+
+/// Converts all index-typed operations and values to i32 values.
+class ConvertIndexTypes : public FuncOpPartialLoweringPattern {
+  using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;
+
+  LogicalResult
+  PartiallyLowerFuncToComp(mlir::FuncOp funcOp,
+                           PatternRewriter &rewriter) const override {
+    funcOp.walk([&](Block *block) {
+      for (auto arg : block->getArguments())
+        if (arg.getType().isIndex())
+          arg.setType(rewriter.getI32Type());
+    });
+
+    funcOp.walk([&](Operation *op) {
+      for (auto res : op->getResults()) {
+        if (!res.getType().isIndex())
+          continue;
+
+        res.setType(rewriter.getI32Type());
+        if (auto constOp = dyn_cast<ConstantOp>(op)) {
+          APInt value;
+          matchConstantOp(constOp, value);
+          rewriter.setInsertionPoint(constOp);
+          rewriter.replaceOpWithNewOp<ConstantOp>(
+              constOp, rewriter.getI32IntegerAttr(value.getSExtValue()));
+        }
+      }
+    });
+    return success();
+  }
+};
+
 /// Inlines Calyx ExecuteRegionOp operations within their parent blocks.
 /// An execution region op (ERO) is inlined by:
 ///   i : add a sink basic block for all yield operations inside the
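
Both the IndexCastOp lowering and RewriteMemoryAccesses in the hunk above reduce to the same width comparison: equal widths forward the value, a narrower destination gets a slice, and a wider one gets a pad. A compact sketch of just that decision rule (the helper name is hypothetical, not from the pass):

// Hedged sketch of the width-resolution rule used by the IndexCastOp lowering
// (slice or pad) and by RewriteMemoryAccesses (slice only, for address ports).
#include <cstdio>

const char *resolveWidth(unsigned srcBits, unsigned dstBits) {
  if (srcBits == dstBits)
    return "forward";                // drop the cast, reuse the source value
  return srcBits > dstBits ? "slice" // e.g. i32 index -> 6-bit address port
                           : "pad";  // zero-extend narrower sources
}

int main() {
  std::printf("%s\n", resolveWidth(32, 6)); // prints "slice"
  return 0;
}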
@@ -1330,6 +1579,18 @@ class LateSSAReplacement : public FuncOpPartialLoweringPattern {
       for (auto res : getComponentState().getWhileIterRegs(whileOp))
         whileOp.getResults()[res.first].replaceAllUsesWith(res.second.out());
     });
+
+    funcOp.walk([&](memref::LoadOp loadOp) {
+      if (singleLoadFromMemory(loadOp)) {
+        /// In buildOpGroups we did not replace loadOp's results, to ensure a
+        /// link between evaluating groups (which fix the input addresses of a
+        /// memory op) and a readData result. Now, we may replace these SSA
+        /// values with their memoryOp readData output.
+        loadOp.getResult().replaceAllUsesWith(
+            getComponentState().getMemory(loadOp.memref()).readData());
+      }
+    });
+
     return success();
   }
 };
@@ -1719,6 +1980,10 @@ void SCFToCalyxPass::runOnOperation() {
   /// This pass inlines scf.ExecuteRegionOp's by adding control-flow.
   addGreedyPattern<InlineExecuteRegionOpPattern>(loweringPatterns);
 
+  /// This pattern converts all index types to a predefined width (currently
+  /// i32).
+  addOncePattern<ConvertIndexTypes>(loweringPatterns, funcMap, *loweringState);
+
   /// This pattern creates registers for all basic-block arguments.
   addOncePattern<BuildBBRegs>(loweringPatterns, funcMap, *loweringState);
 
@@ -1752,6 +2017,10 @@ void SCFToCalyxPass::runOnOperation() {
   /// after control generation.
   addOncePattern<LateSSAReplacement>(loweringPatterns, funcMap, *loweringState);
 
+  /// This pattern rewrites accesses to memories which are too wide due to
+  /// index types being converted to a fixed-width integer type.
+  addOncePattern<RewriteMemoryAccesses>(loweringPatterns, *loweringState);
+
   /// This pattern removes the source FuncOp which has now been converted into
   /// a Calyx component.
   addOncePattern<CleanupFuncOps>(loweringPatterns, funcMap, *loweringState);
