Skip to content

Commit 7fab541

Browse files
[mlir][Sol] ABI encode cleanup checks and enable bytes arrays in encodePacked
This commit adds ABI encode cleanup checks following Yul behaviour and enables bytes arrays in encodePacked. Signed-off-by: Vladimir Radosavljevic <vr@matterlabs.dev>
1 parent 60a1008 commit 7fab541

File tree

2 files changed

+116
-68
lines changed

2 files changed

+116
-68
lines changed

mlir/include/mlir/Conversion/SolToStandard/EVMUtil.h

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,12 @@ class Builder {
142142
int64_t recDepth,
143143
std::optional<Location> locArg = std::nullopt);
144144

145+
/// Returns a normalized scalar value for ABI encoding and emits range checks
146+
/// to match Yul cleanup/validator behavior.
147+
Value normalizeABIScalarForEncoding(
148+
Type ty, Value val, Location loc,
149+
std::optional<sol::DataLocation> srcDataLoc = std::nullopt);
150+
145151
public:
146152
//
147153
// TODO? Should we work with the high level types + OpAdaptor for the APIs
@@ -200,10 +206,10 @@ class Builder {
200206

201207
/// Generates the tuple encoder code as per the ABI and return the new tail
202208
/// address.
203-
Value genABITupleEncoding(Type ty, Value src, Value dstAddr,
204-
bool dstAddrInTail, Value tupleStart,
205-
Value tailAddr,
206-
std::optional<Location> locArg = std::nullopt);
209+
Value genABITupleEncoding(
210+
Type ty, Value src, Value dstAddr, bool dstAddrInTail, Value tupleStart,
211+
Value tailAddr, std::optional<Location> locArg = std::nullopt,
212+
std::optional<sol::DataLocation> srcDataLoc = std::nullopt);
207213

208214
/// Generates the tuple encoder code as per the ABI and returns the address at
209215
/// the end of the tuple.

mlir/lib/Conversion/SolToStandard/EVMUtil.cpp

Lines changed: 106 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,81 @@ unsigned evm::getStorageByteSize(Type ty) {
118118
llvm_unreachable("NYI");
119119
}
120120

121+
Value evm::Builder::normalizeABIScalarForEncoding(
122+
Type ty, Value val, Location loc,
123+
std::optional<sol::DataLocation> srcDataLoc) {
124+
mlir::solgen::BuilderExt bExt(b, loc);
125+
bool fromCalldata = srcDataLoc && *srcDataLoc == sol::DataLocation::CallData;
126+
127+
if (auto intTy = dyn_cast<IntegerType>(ty)) {
128+
if (intTy.getWidth() == 256)
129+
return val;
130+
131+
assert(intTy.getWidth() < 256 &&
132+
"Expected integer types no wider than 256 bits");
133+
auto valTy = cast<IntegerType>(val.getType());
134+
assert((valTy.getWidth() == intTy.getWidth() || valTy.getWidth() == 256) &&
135+
"Expected integer value with source width or i256 width");
136+
137+
// If the value is already at source width, only widen to i256 for storing.
138+
if (valTy.getWidth() == intTy.getWidth())
139+
return bExt.genIntCast(/*width=*/256, intTy.isSigned(), val, loc);
140+
141+
Value normalized;
142+
if (intTy.getWidth() == 1)
143+
// Follow what Yul does for bool values 'iszero(iszero(x))' which is
144+
// effectively a 'x != 0'.
145+
normalized = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, val,
146+
bExt.genI256Const(0));
147+
else
148+
// Do the truncation for non-bool integers.
149+
normalized =
150+
bExt.genIntCast(intTy.getWidth(), intTy.isSigned(), val, loc);
151+
152+
// Finally, extend to 256 bits.
153+
normalized =
154+
bExt.genIntCast(/*width=*/256, intTy.isSigned(), normalized, loc);
155+
if (fromCalldata) {
156+
Value revertCond = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
157+
val, normalized);
158+
genRevert(revertCond, loc);
159+
}
160+
return normalized;
161+
}
162+
163+
if (auto enumTy = dyn_cast<sol::EnumType>(ty)) {
164+
Value normalized =
165+
bExt.genIntCast(/*width=*/256, /*isSigned=*/false, val, loc);
166+
Value revertCond =
167+
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, normalized,
168+
bExt.genI256Const(enumTy.getMax()));
169+
if (fromCalldata)
170+
genRevert(revertCond, loc);
171+
else
172+
genPanic(mlir::evm::PanicCode::EnumConversionError, revertCond, loc);
173+
return normalized;
174+
}
175+
176+
if (auto bytesTy = dyn_cast<sol::BytesType>(ty)) {
177+
Value casted = bExt.genIntCast(/*width=*/256, /*isSigned=*/false, val, loc);
178+
if (bytesTy.getSize() == 32)
179+
return casted;
180+
181+
assert(bytesTy.getSize() < 32 && "Expected fixed-bytes width <= 32");
182+
APInt mask = APInt::getHighBitsSet(256, bytesTy.getSize() * 8);
183+
Value normalized =
184+
b.create<arith::AndIOp>(loc, casted, bExt.genI256Const(mask));
185+
if (fromCalldata) {
186+
Value revertCond = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
187+
casted, normalized);
188+
genRevert(revertCond, loc);
189+
}
190+
return normalized;
191+
}
192+
193+
return val;
194+
}
195+
121196
Value evm::Builder::genHeapPtr(Value addr, std::optional<Location> locArg) {
122197
Location loc = locArg ? *locArg : defLoc;
123198

@@ -567,29 +642,30 @@ void evm::Builder::genABITupleSizeAssert(TypeRange tys, Value tupleSize,
567642
genRevert(shortTupleCond, loc);
568643
}
569644

570-
Value evm::Builder::genABITupleEncoding(Type ty, Value src, Value dstAddr,
571-
bool dstAddrInTail, Value tupleStart,
572-
Value tailAddr,
573-
std::optional<Location> locArg) {
645+
Value evm::Builder::genABITupleEncoding(
646+
Type ty, Value src, Value dstAddr, bool dstAddrInTail, Value tupleStart,
647+
Value tailAddr, std::optional<Location> locArg,
648+
std::optional<sol::DataLocation> srcDataLoc) {
574649
Location loc = locArg ? *locArg : defLoc;
575650
mlir::solgen::BuilderExt bExt(b, loc);
576651

577652
// Integer type
578653
if (auto intTy = dyn_cast<IntegerType>(ty)) {
579-
src = bExt.genIntCast(/*width=*/256, intTy.isSigned(), src);
654+
src = normalizeABIScalarForEncoding(intTy, src, loc, srcDataLoc);
580655
b.create<yul::MStoreOp>(loc, dstAddr, src);
581656
return tailAddr;
582657
}
583658

584659
// Enum type
585660
if (auto enumTy = dyn_cast<sol::EnumType>(ty)) {
586-
src = bExt.genIntCast(/*width=*/256, /*isSigned=*/false, src);
661+
src = normalizeABIScalarForEncoding(enumTy, src, loc, srcDataLoc);
587662
b.create<yul::MStoreOp>(loc, dstAddr, src);
588663
return tailAddr;
589664
}
590665

591666
// Bytes type
592667
if (auto bytesTy = dyn_cast<sol::BytesType>(ty)) {
668+
src = normalizeABIScalarForEncoding(bytesTy, src, loc, srcDataLoc);
593669
b.create<yul::MStoreOp>(loc, dstAddr, src);
594670
return tailAddr;
595671
}
@@ -657,13 +733,13 @@ Value evm::Builder::genABITupleEncoding(Type ty, Value src, Value dstAddr,
657733
loc, iDstAddr,
658734
b.create<arith::SubIOp>(loc, iTailAddr, dstArrAddr));
659735
assert(dstAddrInTail);
660-
nextTailAddr =
661-
genABITupleEncoding(arrTy.getEltType(), srcVal, iTailAddr,
662-
dstAddrInTail, tupleStart, iTailAddr, loc);
736+
nextTailAddr = genABITupleEncoding(
737+
arrTy.getEltType(), srcVal, iTailAddr, dstAddrInTail,
738+
tupleStart, iTailAddr, loc, arrTy.getDataLocation());
663739
} else {
664-
nextTailAddr =
665-
genABITupleEncoding(arrTy.getEltType(), srcVal, iDstAddr,
666-
dstAddrInTail, tupleStart, iTailAddr, loc);
740+
nextTailAddr = genABITupleEncoding(
741+
arrTy.getEltType(), srcVal, iDstAddr, dstAddrInTail, tupleStart,
742+
iTailAddr, loc, arrTy.getDataLocation());
667743
}
668744

669745
Value dstStride =
@@ -764,32 +840,30 @@ Value evm::Builder::genABIPackedEncoding(Type ty, Value val, Value addr,
764840

765841
// bool is stored as uint8 in packed encoding.
766842
unsigned byteSize = isBool ? 1 : bitWidth / 8;
767-
Value casted = bExt.genIntCast(/*width=*/256, intTy.isSigned(), val, loc);
768-
if (byteSize < 32) {
769-
unsigned shiftBits = 256 - byteSize * 8;
770-
Value shifted =
771-
b.create<arith::ShLIOp>(loc, casted, bExt.genI256Const(shiftBits));
772-
b.create<yul::MStoreOp>(loc, addr, shifted);
773-
} else {
774-
b.create<yul::MStoreOp>(loc, addr, casted);
775-
}
843+
Value normalized = normalizeABIScalarForEncoding(intTy, val, loc);
844+
if (byteSize < 32)
845+
normalized = b.create<arith::ShLIOp>(
846+
loc, normalized, bExt.genI256Const(256 - byteSize * 8));
847+
848+
b.create<yul::MStoreOp>(loc, addr, normalized);
776849
return b.create<arith::AddIOp>(loc, addr, bExt.genI256Const(byteSize));
777850
}
778851

779852
// Enum type.
780853
if (auto enumTy = dyn_cast<sol::EnumType>(ty)) {
781854
assert(enumTy.getMax() <= 255 &&
782855
"Expected enums with at most 256 elements");
783-
Value casted = bExt.genIntCast(/*width=*/256, /*isSigned=*/false, val, loc);
856+
Value normalized = normalizeABIScalarForEncoding(enumTy, val, loc);
784857
Value shifted =
785-
b.create<arith::ShLIOp>(loc, casted, bExt.genI256Const(248));
858+
b.create<arith::ShLIOp>(loc, normalized, bExt.genI256Const(248));
786859
b.create<yul::MStoreOp>(loc, addr, shifted);
787860
return b.create<arith::AddIOp>(loc, addr, bExt.genI256Const(1));
788861
}
789862

790863
// Bytes type.
791864
if (auto bytesTy = dyn_cast<sol::BytesType>(ty)) {
792-
b.create<yul::MStoreOp>(loc, addr, val);
865+
Value normalized = normalizeABIScalarForEncoding(bytesTy, val, loc);
866+
b.create<yul::MStoreOp>(loc, addr, normalized);
793867
return b.create<arith::AddIOp>(loc, addr,
794868
bExt.genI256Const(bytesTy.getSize()));
795869
}
@@ -835,47 +909,15 @@ Value evm::Builder::genABIPackedEncoding(Type ty, Value val, Value addr,
835909
Type eltTy = arrTy.getEltType();
836910
sol::DataLocation dataLoc = arrTy.getDataLocation();
837911

838-
// Normalize a src element value to i256 for packed encoding,
839-
// performing range checks if needed.
840-
auto normalizeElt = [&](Value srcVal) -> Value {
841-
if (auto intTy = dyn_cast<IntegerType>(eltTy)) {
842-
if (intTy.getWidth() == 256)
843-
return srcVal;
844-
845-
assert(intTy.getWidth() < 256 &&
846-
"Expected integer types smaller than 256 bits");
847-
848-
// Truncate then re-extend to produce the i256 normalized form,
849-
// and validate the range for calldata elements.
850-
Value trunc = bExt.genIntCast(intTy.getWidth(), intTy.isSigned(),
851-
srcVal, loc);
852-
Value ext =
853-
bExt.genIntCast(/*width=*/256, intTy.isSigned(), trunc, loc);
854-
if (dataLoc == sol::DataLocation::CallData) {
855-
auto revertCond = b.create<arith::CmpIOp>(
856-
loc, arith::CmpIPredicate::ne, srcVal, ext);
857-
genRevert(revertCond, loc);
858-
}
859-
return ext;
860-
}
861-
if (auto enumTy = dyn_cast<sol::EnumType>(eltTy)) {
862-
auto cond = b.create<arith::CmpIOp>(
863-
loc, arith::CmpIPredicate::ugt, srcVal,
864-
bExt.genI256Const(enumTy.getMax()));
865-
if (dataLoc == sol::DataLocation::CallData)
866-
genRevert(cond, loc);
867-
else
868-
genPanic(mlir::evm::PanicCode::EnumConversionError, cond, loc);
869-
return bExt.genIntCast(/*width=*/256, /*isSigned=*/false, srcVal,
870-
loc);
871-
}
872-
if (isa<sol::BytesType>(eltTy))
873-
llvm_unreachable("NYI: packed encoding of bytes arrays");
874-
llvm_unreachable("Unexpected type in packed array");
875-
};
876-
877912
Value srcVal = genLoad(iSrcAddr, dataLoc, loc);
878-
b.create<yul::MStoreOp>(loc, iDstAddr, normalizeElt(srcVal));
913+
if (!isa<IntegerType>(eltTy) && !isa<sol::EnumType>(eltTy) &&
914+
!isa<sol::BytesType>(eltTy))
915+
llvm_unreachable(
916+
"Only integer, enum, and bytes types can be packed");
917+
918+
srcVal = normalizeABIScalarForEncoding(eltTy, srcVal, loc, dataLoc);
919+
b.create<yul::MStoreOp>(loc, iDstAddr, srcVal);
920+
879921
Value stride = bExt.genI256Const(getCallDataHeadSize(eltTy));
880922
Value nextDstAddr = b.create<arith::AddIOp>(loc, iDstAddr, stride);
881923
Value nextSrcAddr = b.create<arith::AddIOp>(loc, iSrcAddr, stride);

0 commit comments

Comments
 (0)