Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions mlir/lib/Conversion/SolToStandard/SolToYul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,23 @@ struct BytesCastOpLowering : public OpConversionPattern<sol::BytesCastOp> {
ConversionPatternRewriter &r) const override {
Location loc = op.getLoc();
mlir::solgen::BuilderExt bExt(r, loc);
Type inpTy = op.getInp().getType();
Type outTy = op.getType();

// Bytes to bytes
if (auto inpBytesTy = dyn_cast<sol::BytesType>(inpTy)) {
if (auto outBytesTy = dyn_cast<sol::BytesType>(outTy)) {
unsigned keepBytes = inpBytesTy.getSize() < outBytesTy.getSize()
? inpBytesTy.getSize()
: outBytesTy.getSize();
auto shiftAmt = bExt.genI256Const(256 - (8 * keepBytes));
auto shr = r.create<arith::ShRUIOp>(loc, adaptor.getInp(), shiftAmt);
r.replaceOpWithNewOp<arith::ShLIOp>(op, shr, shiftAmt);
return success();
}

// Bytes to int
if (auto inpBytesTy = dyn_cast<sol::BytesType>(op.getInp().getType())) {
auto outIntTy = cast<IntegerType>(op.getType());
// Bytes to int
auto outIntTy = cast<IntegerType>(outTy);
auto shiftAmt = bExt.genI256Const(256 - (8 * inpBytesTy.getSize()));
auto shr = r.create<arith::ShRUIOp>(loc, adaptor.getInp(), shiftAmt);
auto repl = bExt.genIntCast(outIntTy.getWidth(), /*isSigned=*/false, shr);
Expand All @@ -146,8 +159,8 @@ struct BytesCastOpLowering : public OpConversionPattern<sol::BytesCastOp> {
}

// Int to bytes
assert(isa<IntegerType>(adaptor.getInp().getType()));
auto outBytesTy = cast<sol::BytesType>(op.getType());
assert(isa<IntegerType>(inpTy));
auto outBytesTy = cast<sol::BytesType>(outTy);
Value inpAsI256 =
bExt.genIntCast(/*width=*/256, /*isSigned=*/false, adaptor.getInp());
auto shiftAmt = bExt.genI256Const(256 - (8 * outBytesTy.getSize()));
Expand Down
22 changes: 18 additions & 4 deletions mlir/lib/Dialect/Sol/SolOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,26 @@ bool EnumCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {

bool BytesCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
assert(inputs.size() == 1 && outputs.size() == 1);
if (auto inpIntTy = dyn_cast<IntegerType>(inputs.front())) {
auto outBytesTy = cast<BytesType>(outputs.front());
Type inpTy = inputs.front();
Type outTy = outputs.front();

// int -> int is not a bytes cast.
if (isa<IntegerType>(inpTy) && isa<IntegerType>(outTy))
return false;

// bytes -> bytes.
if (isa<BytesType>(inpTy) && isa<BytesType>(outTy))
return true;

// int -> bytes.
if (auto inpIntTy = dyn_cast<IntegerType>(inpTy)) {
auto outBytesTy = cast<BytesType>(outTy);
return inpIntTy.getWidth() == outBytesTy.getSize() * 8;
}
return cast<BytesType>(inputs.front()).getSize() * 8 ==
cast<IntegerType>(outputs.front()).getWidth();

// bytes -> int.
return cast<BytesType>(inpTy).getSize() * 8 ==
cast<IntegerType>(outTy).getWidth();
}

//===----------------------------------------------------------------------===//
Expand Down
Loading