Skip to content
This repository was archived by the owner on Jul 12, 2024. It is now read-only.

Commit 1be5a2b

Browse files
authored
Merge pull request #114 from sjrd/simpler-int-div-mod
Generate simpler code for Int and Long division and remainder.
2 parents 145f63c + db34150 commit 1be5a2b

File tree

1 file changed

+136
-74
lines changed

1 file changed

+136
-74
lines changed

Diff for: wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala

+136-74
Original file line numberDiff line numberDiff line change
@@ -977,11 +977,147 @@ private class WasmExpressionBuilder private (
977977
IRTypes.LongType
978978
}
979979

980+
def genThrowArithmeticException(): Unit = {
981+
implicit val pos = binary.pos
982+
val divisionByZeroEx = IRTrees.Throw(
983+
IRTrees.New(
984+
IRNames.ArithmeticExceptionClass,
985+
IRTrees.MethodIdent(
986+
IRNames.MethodName.constructor(List(IRTypes.ClassRef(IRNames.BoxedStringClass)))
987+
),
988+
List(IRTrees.StringLiteral("/ by zero "))
989+
)
990+
)
991+
genThrow(divisionByZeroEx)
992+
}
993+
994+
def genDivModByConstant[T](
995+
isDiv: Boolean,
996+
rhsValue: T,
997+
const: T => WasmInstr,
998+
sub: WasmInstr,
999+
mainOp: WasmInstr
1000+
)(implicit num: Numeric[T]): IRTypes.Type = {
1001+
/* When we statically know the value of the rhs, we can avoid the
1002+
* dynamic tests for division by zero and overflow. This is quite
1003+
* common in practice.
1004+
*/
1005+
1006+
val tpe = binary.tpe
1007+
1008+
if (rhsValue == num.zero) {
1009+
genTree(binary.lhs, tpe)
1010+
fctx.markPosition(binary)
1011+
genThrowArithmeticException()
1012+
IRTypes.NothingType
1013+
} else if (isDiv && rhsValue == num.fromInt(-1)) {
1014+
/* MinValue / -1 overflows; it traps in Wasm but we need to wrap.
1015+
* We rewrite as `0 - lhs` so that we do not need any test.
1016+
*/
1017+
fctx.markPosition(binary)
1018+
instrs += const(num.zero)
1019+
genTree(binary.lhs, tpe)
1020+
fctx.markPosition(binary)
1021+
instrs += sub
1022+
tpe
1023+
} else {
1024+
genTree(binary.lhs, tpe)
1025+
fctx.markPosition(binary.rhs)
1026+
instrs += const(rhsValue)
1027+
fctx.markPosition(binary)
1028+
instrs += mainOp
1029+
tpe
1030+
}
1031+
}
1032+
1033+
def genDivMod[T](
1034+
isDiv: Boolean,
1035+
const: T => WasmInstr,
1036+
eqz: WasmInstr,
1037+
eq: WasmInstr,
1038+
sub: WasmInstr,
1039+
mainOp: WasmInstr
1040+
)(implicit num: Numeric[T]): IRTypes.Type = {
1041+
/* Here we perform the same steps as in the static case, but using
1042+
* value tests at run-time.
1043+
*/
1044+
1045+
val tpe = binary.tpe
1046+
val wasmTyp = TypeTransformer.transformType(tpe)(ctx)
1047+
1048+
val lhsLocal = fctx.addSyntheticLocal(wasmTyp)
1049+
val rhsLocal = fctx.addSyntheticLocal(wasmTyp)
1050+
genTree(binary.lhs, tpe)
1051+
instrs += LOCAL_SET(lhsLocal)
1052+
genTree(binary.rhs, tpe)
1053+
instrs += LOCAL_TEE(rhsLocal)
1054+
1055+
fctx.markPosition(binary)
1056+
1057+
instrs += eqz
1058+
fctx.ifThen() {
1059+
genThrowArithmeticException()
1060+
}
1061+
if (isDiv) {
1062+
// Handle the MinValue / -1 corner case
1063+
instrs += LOCAL_GET(rhsLocal)
1064+
instrs += const(num.fromInt(-1))
1065+
instrs += eq
1066+
fctx.ifThenElse(wasmTyp) {
1067+
// 0 - lhs
1068+
instrs += const(num.zero)
1069+
instrs += LOCAL_GET(lhsLocal)
1070+
instrs += sub
1071+
} {
1072+
// lhs / rhs
1073+
instrs += LOCAL_GET(lhsLocal)
1074+
instrs += LOCAL_GET(rhsLocal)
1075+
instrs += mainOp
1076+
}
1077+
} else {
1078+
// lhs % rhs
1079+
instrs += LOCAL_GET(lhsLocal)
1080+
instrs += LOCAL_GET(rhsLocal)
1081+
instrs += mainOp
1082+
}
1083+
1084+
tpe
1085+
}
1086+
9801087
binary.op match {
9811088
case BinaryOp.=== | BinaryOp.!== => genEq(binary)
9821089

9831090
case BinaryOp.String_+ => genStringConcat(binary)
9841091

1092+
case BinaryOp.Int_/ =>
1093+
binary.rhs match {
1094+
case IRTrees.IntLiteral(rhsValue) =>
1095+
genDivModByConstant(isDiv = true, rhsValue, I32_CONST(_), I32_SUB, I32_DIV_S)
1096+
case _ =>
1097+
genDivMod(isDiv = true, I32_CONST(_), I32_EQZ, I32_EQ, I32_SUB, I32_DIV_S)
1098+
}
1099+
case BinaryOp.Int_% =>
1100+
binary.rhs match {
1101+
case IRTrees.IntLiteral(rhsValue) =>
1102+
genDivModByConstant(isDiv = false, rhsValue, I32_CONST(_), I32_SUB, I32_REM_S)
1103+
case _ =>
1104+
genDivMod(isDiv = false, I32_CONST(_), I32_EQZ, I32_EQ, I32_SUB, I32_REM_S)
1105+
}
1106+
case BinaryOp.Long_/ =>
1107+
binary.rhs match {
1108+
case IRTrees.LongLiteral(rhsValue) =>
1109+
genDivModByConstant(isDiv = true, rhsValue, I64_CONST(_), I64_SUB, I64_DIV_S)
1110+
case _ =>
1111+
genDivMod(isDiv = true, I64_CONST(_), I64_EQZ, I64_EQ, I64_SUB, I64_DIV_S)
1112+
}
1113+
case BinaryOp.Long_% =>
1114+
binary.rhs match {
1115+
case IRTrees.LongLiteral(rhsValue) =>
1116+
genDivModByConstant(isDiv = false, rhsValue, I64_CONST(_), I64_SUB, I64_REM_S)
1117+
case _ =>
1118+
genDivMod(isDiv = false, I64_CONST(_), I64_EQZ, I64_EQ, I64_SUB, I64_REM_S)
1119+
}
1120+
9851121
case BinaryOp.Long_<< => genLongShiftOp(I64_SHL)
9861122
case BinaryOp.Long_>>> => genLongShiftOp(I64_SHR_U)
9871123
case BinaryOp.Long_>> => genLongShiftOp(I64_SHR_S)
@@ -1019,80 +1155,6 @@ private class WasmExpressionBuilder private (
10191155
instrs += CALL(WasmFunctionName.stringCharAt)
10201156
IRTypes.CharType
10211157

1022-
// Check division by zero
1023-
// (Int|Long).MinValue / -1 = (Int|Long).MinValue because of overflow
1024-
case BinaryOp.Int_/ | BinaryOp.Long_/ | BinaryOp.Int_% | BinaryOp.Long_% =>
1025-
implicit val noPos = Position.NoPosition
1026-
val divisionByZeroEx = IRTrees.Throw(
1027-
IRTrees.New(
1028-
IRNames.ArithmeticExceptionClass,
1029-
IRTrees.MethodIdent(
1030-
IRNames.MethodName.constructor(List(IRTypes.ClassRef(IRNames.BoxedStringClass)))
1031-
),
1032-
List(IRTrees.StringLiteral("/ by zero "))
1033-
)
1034-
)
1035-
val resType = TypeTransformer.transformType(binary.tpe)(ctx)
1036-
1037-
val lhs = fctx.addSyntheticLocal(TypeTransformer.transformType(binary.lhs.tpe)(ctx))
1038-
val rhs = fctx.addSyntheticLocal(TypeTransformer.transformType(binary.rhs.tpe)(ctx))
1039-
genTreeAuto(binary.lhs)
1040-
instrs += LOCAL_SET(lhs)
1041-
genTreeAuto(binary.rhs)
1042-
instrs += LOCAL_SET(rhs)
1043-
1044-
fctx.markPosition(binary)
1045-
1046-
fctx.block(resType) { done =>
1047-
fctx.block() { default =>
1048-
fctx.block() { divisionByZero =>
1049-
instrs += LOCAL_GET(rhs)
1050-
binary.op match {
1051-
case BinaryOp.Int_/ | BinaryOp.Int_% => instrs += I32_EQZ
1052-
case BinaryOp.Long_/ | BinaryOp.Long_% => instrs += I64_EQZ
1053-
}
1054-
instrs += BR_IF(divisionByZero)
1055-
1056-
// Check overflow for division
1057-
if (binary.op == BinaryOp.Int_/ || binary.op == BinaryOp.Long_/) {
1058-
fctx.block() { overflow =>
1059-
instrs += LOCAL_GET(rhs)
1060-
if (binary.op == BinaryOp.Int_/) instrs ++= List(I32_CONST(-1), I32_EQ)
1061-
else instrs ++= List(I64_CONST(-1), I64_EQ)
1062-
fctx.ifThen() { // if (rhs == -1)
1063-
instrs += LOCAL_GET(lhs)
1064-
if (binary.op == BinaryOp.Int_/)
1065-
instrs ++= List(I32_CONST(Int.MinValue), I32_EQ)
1066-
else instrs ++= List(I64_CONST(Long.MinValue), I64_EQ)
1067-
instrs += BR_IF(overflow)
1068-
}
1069-
instrs += BR(default)
1070-
}
1071-
// overflow
1072-
if (binary.op == BinaryOp.Int_/) instrs += I32_CONST(Int.MinValue)
1073-
else instrs += I64_CONST(Long.MinValue)
1074-
instrs += BR(done)
1075-
}
1076-
1077-
// remainder
1078-
instrs += BR(default)
1079-
}
1080-
// division by zero
1081-
genThrow(divisionByZeroEx)
1082-
}
1083-
// default
1084-
instrs += LOCAL_GET(lhs)
1085-
instrs += LOCAL_GET(rhs)
1086-
instrs +=
1087-
(binary.op match {
1088-
case BinaryOp.Int_/ => I32_DIV_S
1089-
case BinaryOp.Int_% => I32_REM_S
1090-
case BinaryOp.Long_/ => I64_DIV_S
1091-
case BinaryOp.Long_% => I64_REM_S
1092-
})
1093-
binary.tpe
1094-
}
1095-
10961158
case _ => genElementaryBinaryOp(binary)
10971159
}
10981160
}

0 commit comments

Comments
 (0)