[SCFToCalyx] Support memref operations [12/12] (#1863)
This commit adds support for memory-accessing operations.

All index types are converted to a fixed-width integer. This is necessary due to the lack of a bitwidth inference pass. When an index-typed value is used as a memory address input, the address value is truncated to the width of the memory port.
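To make the truncation concrete, here is a minimal behavioral sketch (illustration only, not from the commit; the 100-element memory and the index value 42 are hypothetical):

```c++
#include <cstdint>
#include <cstdio>

// Indices are materialized as i32, while a memory with `numElements` slots
// exposes a ceil(log2(numElements))-bit address port; only the low address
// bits of the index survive the slice.
int main() {
  const uint64_t numElements = 100; // e.g. a hypothetical memref<100xi32>
  unsigned addrWidth = 0;
  while ((1ull << addrWidth) < numElements)
    ++addrWidth;                    // ceil(log2(100)) == 7
  const uint32_t index = 42;        // i32 index value
  const uint32_t addr = index & ((1u << addrWidth) - 1); // slice i32 -> i7
  std::printf("addrWidth=%u addr=%u\n", addrWidth, addr);
}
```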
mortbopet authored Sep 23, 2021
1 parent ef45313 commit 184db22
Showing 2 changed files with 529 additions and 1 deletion.
271 changes: 270 additions & 1 deletion lib/Conversion/SCFToCalyx/SCFToCalyx.cpp
@@ -17,6 +17,7 @@
#include "circt/Dialect/HW/HWOps.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Matchers.h"
@@ -60,6 +61,14 @@ static bool matchConstantOp(Operation *op, APInt &value) {
return mlir::detail::constant_int_op_binder(&value).match(op);
}

/// Returns true if there exists only a single memref::LoadOp which loads from
/// the memory referenced by loadOp.
static bool singleLoadFromMemory(memref::LoadOp loadOp) {
return llvm::count_if(loadOp.memref().getUses(), [](auto &user) {
return dyn_cast<memref::LoadOp>(user.getOwner());
}) <= 1;
}

/// Creates a DictionaryAttr containing a unit attribute 'name'. Used for
/// defining mandatory port attributes for calyx::ComponentOp's.
static DictionaryAttr getMandatoryPortAttr(MLIRContext *ctx, StringRef name) {
@@ -304,6 +313,23 @@ class ComponentLoweringState {
return it->second;
}

/// Registers a calyx::MemoryOp as being associated with a memory identified
/// by 'memref'.
void registerMemory(Value memref, calyx::MemoryOp memoryOp) {
assert(memref.getType().isa<MemRefType>());
assert(memories.find(memref) == memories.end() &&
"Memory already registered for memref");
memories[memref] = memoryOp;
}

/// Returns a calyx::MemoryOp registered for the given memref.
calyx::MemoryOp getMemory(Value memref) {
assert(memref.getType().isa<MemRefType>());
auto it = memories.find(memref);
assert(it != memories.end() && "No memory registered for memref");
return it->second;
}

private:
/// A reference to the parent program lowering state.
ProgramLoweringState &programLoweringState;
@@ -348,6 +374,9 @@ class ComponentLoweringState {

/// A mapping from while ops to iteration argument registers.
DenseMap<Operation *, DenseMap<unsigned, calyx::RegisterOp>> whileIterRegs;

/// A mapping from memref's to their corresponding calyx memory op.
DenseMap<Value, calyx::MemoryOp> memories;
};

/// ProgramLoweringState handles the current state of lowering of a Calyx
@@ -565,10 +594,13 @@ class BuildOpGroups : public FuncOpPartialLoweringPattern {
.template Case<ConstantOp, ReturnOp, BranchOpInterface,
/// SCF
scf::YieldOp,
/// memref
memref::AllocOp, memref::LoadOp, memref::StoreOp,
/// standard arithmetic
AddIOp, SubIOp, CmpIOp, ShiftLeftOp,
UnsignedShiftRightOp, SignedShiftRightOp, AndOp,
XOrOp, OrOp, ZeroExtendIOp, TruncateIOp>(
XOrOp, OrOp, ZeroExtendIOp, TruncateIOp,
IndexCastOp>(
[&](auto op) { return buildOp(rewriter, op).succeeded(); })
.template Case<scf::WhileOp, mlir::FuncOp, scf::ConditionOp>(
[&](auto) {
@@ -606,6 +638,10 @@ class BuildOpGroups : public FuncOpPartialLoweringPattern {
LogicalResult buildOp(PatternRewriter &rewriter, TruncateIOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, ZeroExtendIOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, ReturnOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, IndexCastOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, memref::AllocOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, memref::LoadOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, memref::StoreOp op) const;

/// buildLibraryOp will build a TCalyxLibOp inside a TGroupOp based on the
/// source operation TSrcOp.
@@ -666,8 +702,107 @@ class BuildOpGroups : public FuncOpPartialLoweringPattern {
return createGroup<TGroupOp>(rewriter, getComponentState().getComponentOp(),
block->front().getLoc(), groupName);
}

/// Creates assignments within the provided group to the address ports of the
/// memoryOp based on the provided addressValues.
void assignAddressPorts(PatternRewriter &rewriter, Location loc,
calyx::GroupInterface group, calyx::MemoryOp memoryOp,
Operation::operand_range addressValues) const {
IRRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToEnd(group.getBody());
auto addrPorts = memoryOp.addrPorts();
assert(addrPorts.size() == addressValues.size() &&
"Mismatch between number of address ports of the provided memory "
"and address assignment values");
for (auto &idx : enumerate(addressValues))
rewriter.create<calyx::AssignOp>(loc, addrPorts[idx.index()], idx.value(),
Value());
}
};

LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
memref::LoadOp loadOp) const {
auto memoryOp = getComponentState().getMemory(loadOp.memref());
if (singleLoadFromMemory(loadOp)) {
/// Single load from memory; combinational case. We do not have to consider
/// adding registers in front of the memory.
auto combGroup = createGroupForOp<calyx::CombGroupOp>(rewriter, loadOp);
assignAddressPorts(rewriter, loadOp.getLoc(), combGroup, memoryOp,
loadOp.getIndices());

/// We refrain from replacing the loadOp result with memoryOp.readData,
/// since multiple loadOps may need to be converted to a single memory's
/// readData port. If this replacement is done now, we lose the link between
/// which SSA memref::LoadOp values map to which groups for loading a value
/// from the Calyx memory. At this point of lowering, we keep the
/// memref::LoadOp SSA value, and do value replacement _after_ control has
/// been generated (see LateSSAReplacement). This is *vital* for things such
/// as InlineCombGroups to be able to properly track which memory assignment
/// groups belong to which accesses.
getComponentState().registerEvaluatingGroup(loadOp.getResult(), combGroup);
} else {
auto group = createGroupForOp<calyx::GroupOp>(rewriter, loadOp);
assignAddressPorts(rewriter, loadOp.getLoc(), group, memoryOp,
loadOp.getIndices());

/// Multiple loads from the same memory; in this case, we _may_ have a
/// structural hazard in the design we generate. To get around this, we
/// conservatively place a register in front of each load operation, and
/// replace all uses of the loaded value with the register output. Proper
/// handling of this requires the combinational group inliner/scheduler to
/// be aware of when a combinational expression references multiple loaded
/// values from the same memory, and then schedule assignments to temporary
/// registers to get around the structural hazard.
auto reg = createReg(getComponentState(), rewriter, loadOp.getLoc(),
getComponentState().getUniqueName("load"),
loadOp.getMemRefType().getElementTypeBitWidth());
buildAssignmentsForRegisterWrite(getComponentState(), rewriter, group, reg,
memoryOp.readData());
loadOp.getResult().replaceAllUsesWith(reg.out());
getComponentState().addBlockScheduleable(loadOp->getBlock(), group);
}
return success();
}

LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
memref::StoreOp storeOp) const {
auto memoryOp = getComponentState().getMemory(storeOp.memref());
auto group = createGroupForOp<calyx::GroupOp>(rewriter, storeOp);

/// This is a sequential group, so register it as being scheduleable for the
/// block.
getComponentState().addBlockScheduleable(storeOp->getBlock(),
cast<calyx::GroupOp>(group));
assignAddressPorts(rewriter, storeOp.getLoc(), group, memoryOp,
storeOp.getIndices());
rewriter.setInsertionPointToEnd(group.getBody());
rewriter.create<calyx::AssignOp>(storeOp.getLoc(), memoryOp.writeData(),
storeOp.getValueToStore(), Value());
rewriter.create<calyx::AssignOp>(
storeOp.getLoc(), memoryOp.writeEn(),
getComponentState().getConstant(rewriter, storeOp.getLoc(), 1, 1),
Value());
rewriter.create<calyx::GroupDoneOp>(storeOp.getLoc(), memoryOp.done(),
Value());
return success();
}

LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
memref::AllocOp allocOp) const {
rewriter.setInsertionPointToStart(getComponent()->getBody());
MemRefType memtype = allocOp.getType();
SmallVector<int64_t> addrSizes;
SmallVector<int64_t> sizes;
for (int64_t dim : memtype.getShape()) {
sizes.push_back(dim);
addrSizes.push_back(llvm::Log2_64_Ceil(dim));
}
auto memoryOp = rewriter.create<calyx::MemoryOp>(
allocOp.getLoc(), getComponentState().getUniqueName("mem"),
memtype.getElementType().getIntOrFloatBitWidth(), sizes, addrSizes);
getComponentState().registerMemory(allocOp.getResult(), memoryOp);
return success();
}
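The address-port widths above follow directly from the memref shape. A small sketch of the mapping (the shape {64, 8} is a hypothetical example, and the helper name is made up for illustration):

```c++
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/MathExtras.h"

// Each dimension of size N gets a Log2_64_Ceil(N)-bit address port; the
// hypothetical shape {64, 8} yields address widths {6, 3}.
llvm::SmallVector<int64_t> addrWidthsFor(llvm::ArrayRef<int64_t> shape) {
  llvm::SmallVector<int64_t> addrSizes;
  for (int64_t dim : shape)
    addrSizes.push_back(llvm::Log2_64_Ceil(dim));
  return addrSizes;
}
```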

LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
scf::YieldOp yieldOp) const {
if (yieldOp.getOperands().size() == 0)
@@ -816,6 +951,120 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
rewriter, op, {op.getOperand().getType()}, {op.getType()});
}

LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
IndexCastOp op) const {
Type sourceType = op.getOperand().getType();
sourceType = sourceType.isIndex() ? rewriter.getI32Type() : sourceType;
Type targetType = op.getResult().getType();
targetType = targetType.isIndex() ? rewriter.getI32Type() : targetType;
unsigned targetBits = targetType.getIntOrFloatBitWidth();
unsigned sourceBits = sourceType.getIntOrFloatBitWidth();
LogicalResult res = success();

if (targetBits == sourceBits) {
/// Drop the index cast and replace uses of the target value with the source
/// value.
op.getResult().replaceAllUsesWith(op.getOperand());
} else {
/// Pad or slice the source operand.
if (sourceBits > targetBits)
res = buildLibraryOp<calyx::CombGroupOp, calyx::SliceLibOp>(
rewriter, op, {sourceType}, {targetType});
else
res = buildLibraryOp<calyx::CombGroupOp, calyx::PadLibOp>(
rewriter, op, {sourceType}, {targetType});
}
rewriter.eraseOp(op);
return res;
}
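With index pinned to 32 bits, the lowering above reduces to a three-way decision. A compact sketch (the helper name is illustrative, not part of the commit):

```c++
#include <string>

// No-op when the widths already agree, a hardware slice when narrowing,
// and a pad when widening.
std::string indexCastLowering(unsigned sourceBits, unsigned targetBits) {
  if (sourceBits == targetBits)
    return "no-op: replace uses of the result with the operand";
  return sourceBits > targetBits ? "calyx::SliceLibOp" : "calyx::PadLibOp";
}
```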

/// This pass rewrites memory accesses that have a width mismatch. Such
/// mismatches arise because index types are assumed to be 32 bits wide in the
/// absence of a width inference pass.
class RewriteMemoryAccesses : public PartialLoweringPattern<calyx::AssignOp> {
public:
RewriteMemoryAccesses(MLIRContext *context, LogicalResult &resRef,
ProgramLoweringState &pls)
: PartialLoweringPattern(context, resRef), pls(pls) {}

LogicalResult partiallyLower(calyx::AssignOp assignOp,
PatternRewriter &rewriter) const override {
auto dest = assignOp.dest();
auto destDefOp = dest.getDefiningOp();
/// Is this an assignment to a memory op?
if (!destDefOp)
return success();
auto destDefMem = dyn_cast<calyx::MemoryOp>(destDefOp);
if (!destDefMem)
return success();

/// Is this an assignment to an address port of the memory op?
bool isAssignToAddrPort = llvm::any_of(
destDefMem.addrPorts(), [&](auto port) { return port == dest; });

auto src = assignOp.src();
auto &state =
pls.compLoweringState(assignOp->getParentOfType<calyx::ComponentOp>());

unsigned srcBits = src.getType().getIntOrFloatBitWidth();
unsigned dstBits = dest.getType().getIntOrFloatBitWidth();
if (srcBits == dstBits)
return success();

if (isAssignToAddrPort) {
SmallVector<Type> types = {rewriter.getIntegerType(srcBits),
rewriter.getIntegerType(dstBits)};
auto sliceOp = state.getNewLibraryOpInstance<calyx::SliceLibOp>(
rewriter, assignOp.getLoc(), types);
rewriter.setInsertionPoint(assignOp->getBlock(),
assignOp->getBlock()->begin());
rewriter.create<calyx::AssignOp>(assignOp->getLoc(), sliceOp.getResult(0),
src, Value());
assignOp.setOperand(1, sliceOp.getResult(1));
} else
return assignOp.emitError()
<< "Will only infer slice operators for assign width mismatches "
"to memory address ports.";

return success();
}

private:
ProgramLoweringState &pls;
};
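The pattern's control flow can be summarized as a small classification; a sketch with illustrative names (only address-port assignments get a slice inferred, any other width mismatch is rejected):

```c++
// Sketch of the guard logic in partiallyLower: matching widths need no fix,
// a wider source into an address port gets a std_slice inferred, and any
// other mismatch is reported as an error. Names are illustrative.
enum class MemAssignFix { None, InferSlice, Error };

MemAssignFix classify(unsigned srcBits, unsigned dstBits, bool isAddrPort) {
  if (srcBits == dstBits)
    return MemAssignFix::None;
  return isAddrPort ? MemAssignFix::InferSlice : MemAssignFix::Error;
}
```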

/// Converts all index-typed operations and values to i32 values.
class ConvertIndexTypes : public FuncOpPartialLoweringPattern {
using FuncOpPartialLoweringPattern::FuncOpPartialLoweringPattern;

LogicalResult
PartiallyLowerFuncToComp(mlir::FuncOp funcOp,
PatternRewriter &rewriter) const override {
funcOp.walk([&](Block *block) {
for (auto arg : block->getArguments())
if (arg.getType().isIndex())
arg.setType(rewriter.getI32Type());
});

funcOp.walk([&](Operation *op) {
for (auto res : op->getResults()) {
if (!res.getType().isIndex())
continue;

res.setType(rewriter.getI32Type());
if (auto constOp = dyn_cast<ConstantOp>(op)) {
APInt value;
matchConstantOp(constOp, value);
rewriter.setInsertionPoint(constOp);
rewriter.replaceOpWithNewOp<ConstantOp>(
constOp, rewriter.getI32IntegerAttr(value.getSExtValue()));
}
}
});
return success();
}
};

/// Inlines Calyx ExecuteRegionOp operations within their parent blocks.
/// An execution region op (ERO) is inlined by:
/// i : add a sink basic block for all yield operations inside the
@@ -1330,6 +1579,18 @@ class LateSSAReplacement : public FuncOpPartialLoweringPattern {
for (auto res : getComponentState().getWhileIterRegs(whileOp))
whileOp.getResults()[res.first].replaceAllUsesWith(res.second.out());
});

funcOp.walk([&](memref::LoadOp loadOp) {
if (singleLoadFromMemory(loadOp)) {
/// In buildOpGroups we did not replace loadOp's results, to ensure a
/// link between evaluating groups (which fix the input addresses of a
/// memory op) and a readData result. Now, we may replace these SSA
/// values with their memoryOp readData output.
loadOp.getResult().replaceAllUsesWith(
getComponentState().getMemory(loadOp.memref()).readData());
}
});

return success();
}
};
@@ -1719,6 +1980,10 @@ void SCFToCalyxPass::runOnOperation() {
/// This pass inlines scf.ExecuteRegionOp's by adding control-flow.
addGreedyPattern<InlineExecuteRegionOpPattern>(loweringPatterns);

/// This pattern converts all index types to a predefined width (currently
/// i32).
addOncePattern<ConvertIndexTypes>(loweringPatterns, funcMap, *loweringState);

/// This pattern creates registers for all basic-block arguments.
addOncePattern<BuildBBRegs>(loweringPatterns, funcMap, *loweringState);

@@ -1752,6 +2017,10 @@ void SCFToCalyxPass::runOnOperation() {
/// after control generation.
addOncePattern<LateSSAReplacement>(loweringPatterns, funcMap, *loweringState);

/// This pattern rewrites memory accesses whose address operands have become
/// too wide, due to index types being converted to a fixed-width integer type.
addOncePattern<RewriteMemoryAccesses>(loweringPatterns, *loweringState);

/// This pattern removes the source FuncOp which has now been converted into
/// a Calyx component.
addOncePattern<CleanupFuncOps>(loweringPatterns, funcMap, *loweringState);