diff --git a/wasm/src/main/scala/ir2wasm/TypeTransformer.scala b/wasm/src/main/scala/ir2wasm/TypeTransformer.scala index b2d448f7..8976c6dd 100644 --- a/wasm/src/main/scala/ir2wasm/TypeTransformer.scala +++ b/wasm/src/main/scala/ir2wasm/TypeTransformer.scala @@ -35,9 +35,8 @@ object TypeTransformer { t: IRTypes.Type )(implicit ctx: ReadOnlyWasmContext): List[Types.WasmType] = t match { - case IRTypes.NoType => Nil - case IRTypes.ClassType(className) if className == IRNames.BoxedUnitClass => Nil - case _ => List(transformType(t)) + case IRTypes.NoType => Nil + case _ => List(transformType(t)) } def transformType(t: IRTypes.Type)(implicit ctx: ReadOnlyWasmContext): Types.WasmType = t match { @@ -64,6 +63,10 @@ object TypeTransformer { Types.WasmRefNullType( Types.WasmHeapType.Type(Names.WasmTypeName.WasmStructTypeName.string) ) + case IRNames.BoxedUnitClass => + Types.WasmRefNullType( + Types.WasmHeapType.Type(Names.WasmTypeName.WasmStructTypeName.undef) + ) case _ => if (ctx.getClassInfo(clazz.className).isInterface) Types.WasmRefNullType(Types.WasmHeapType.ObjectType) diff --git a/wasm/src/main/scala/ir2wasm/WasmBuilder.scala b/wasm/src/main/scala/ir2wasm/WasmBuilder.scala index ec8b46f8..7a269766 100644 --- a/wasm/src/main/scala/ir2wasm/WasmBuilder.scala +++ b/wasm/src/main/scala/ir2wasm/WasmBuilder.scala @@ -34,7 +34,6 @@ class WasmBuilder { def transformTopLevelExport(export: LinkedTopLevelExport)(implicit ctx: WasmContext): Unit = { implicit val fctx = WasmFunctionContext() - val expressionBuilder = new WasmExpressionBuilder(ctx, fctx) export.tree match { case d: IRTrees.TopLevelFieldExportDef => ??? case d: IRTrees.TopLevelJSClassExportDef => ??? @@ -295,7 +294,6 @@ class WasmBuilder { private def transformToplevelMethodExportDef( exportDef: IRTrees.TopLevelMethodExportDef )(implicit ctx: WasmContext, fctx: WasmFunctionContext) = { - val builder = new WasmExpressionBuilder(ctx, fctx) val method = exportDef.methodDef val methodName = method.name match { case lit: IRTrees.StringLiteral => lit @@ -363,15 +361,12 @@ class WasmBuilder { } params.foreach(fctx.locals.define) - val instrs = newBody match { - case t: IRTrees.Block => t.stats.flatMap(builder.transformTree) - case _ => builder.transformTree(newBody) - } + val expr = WasmExpressionBuilder.transformBody(newBody, resultType) val func = WasmFunction( Names.WasmFunctionName(methodName), functionType, fctx.locals.all, - WasmExpr(instrs) + expr ) ctx.addFunction(func) @@ -417,18 +412,10 @@ class WasmBuilder { }).foreach(fctx.locals.define) // build function body - val builder = new WasmExpressionBuilder(ctx, fctx) val body = method.body.getOrElse(throw new Exception("abstract method cannot be transformed")) // val prefix = // if (method.flags.namespace.isConstructor) builder.objectCreationPrefix(clazz, method) else Nil - val instrs = body match { - case t: IRTrees.Block => t.stats.flatMap(builder.transformTree) - case _ => builder.transformTree(body) - } - val expr = method.resultType match { - case IRTypes.NoType => WasmExpr(instrs) - case _ => WasmExpr(instrs :+ RETURN) - } + val expr = WasmExpressionBuilder.transformBody(body, method.resultType) val func = WasmFunction( Names.WasmFunctionName(clazz.name.name, method.name.name), diff --git a/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala b/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala index 1a5d85ba..9764ff44 100644 --- a/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala +++ b/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala @@ -3,6 +3,8 @@ package ir2wasm import scala.annotation.switch +import scala.collection.mutable + import org.scalajs.ir.{Trees => IRTrees} import org.scalajs.ir.{Types => IRTypes} import org.scalajs.ir.{Names => IRNames} @@ -15,7 +17,22 @@ import wasm4s.WasmImmediate._ import org.scalajs.ir.Types.ClassType import org.scalajs.ir.ClassKind -class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFunctionContext) { +object WasmExpressionBuilder { + def transformBody(tree: IRTrees.Tree, resultType: IRTypes.Type)( + implicit ctx: FunctionTypeWriterWasmContext, fctx: WasmFunctionContext + ): WasmExpr = { + val builder = new WasmExpressionBuilder(ctx, fctx) + WasmExpr(builder.genBody(tree, resultType)) + } +} + +private class WasmExpressionBuilder private (ctx: FunctionTypeWriterWasmContext, fctx: WasmFunctionContext) { + private val instrs = mutable.ListBuffer.empty[WasmInstr] + + def genBody(tree: IRTrees.Tree, expectedType: IRTypes.Type): List[WasmInstr] = { + genTree(tree, expectedType) + instrs.toList + } /** object creation prefix * ``` @@ -47,33 +64,34 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti // // ) // } - def transformTree(tree: IRTrees.Tree): List[WasmInstr] = { - tree match { - case t: IRTrees.Literal => transformLiteral(t) - case t: IRTrees.UnaryOp => transformUnaryOp(t) - case t: IRTrees.BinaryOp => transformBinaryOp(t) - case t: IRTrees.VarRef => List(transformVarRef(t)) - case t: IRTrees.LoadModule => transformLoadModule(t) - case t: IRTrees.StoreModule => - transformStoreModule(t) - case t: IRTrees.This => // push receiver to the stack - List(LOCAL_GET(LocalIdx(fctx.receiver.name))) - case t: IRTrees.ApplyStatic => ??? - case t: IRTrees.ApplyStatically => - transformApplyStatically(t) - case t: IRTrees.Apply => transformApply(t) - case t: IRTrees.ApplyDynamicImport => ??? - case t: IRTrees.AsInstanceOf => transformAsInstanceOf(t) - case t: IRTrees.Block => transformBlock(t) - case t: IRTrees.Labeled => transformLabeled(t) - case t: IRTrees.Return => transformReturn(t) - case t: IRTrees.Select => transformSelect(t) - case t: IRTrees.Assign => transformAssign(t) - case t: IRTrees.VarDef => transformVarDef(t) - case t: IRTrees.New => transformNew(t) - case t: IRTrees.If => transformIf(t) - case t: IRTrees.While => transformWhile(t) - case t: IRTrees.Skip => Nil + def genTrees(trees: List[IRTrees.Tree], expectedTypes: List[IRTypes.Type]): Unit = + trees.lazyZip(expectedTypes).foreach(genTree(_, _)) + + def genTreeAuto(tree: IRTrees.Tree): Unit = + genTree(tree, tree.tpe) + + def genTree(tree: IRTrees.Tree, expectedType: IRTypes.Type): Unit = { + val generatedType: IRTypes.Type = tree match { + case t: IRTrees.Literal => genLiteral(t) + case t: IRTrees.UnaryOp => genUnaryOp(t) + case t: IRTrees.BinaryOp => genBinaryOp(t) + case t: IRTrees.VarRef => genVarRef(t) + case t: IRTrees.LoadModule => genLoadModule(t) + case t: IRTrees.StoreModule => genStoreModule(t) + case t: IRTrees.This => genThis(t) + case t: IRTrees.ApplyStatically => genApplyStatically(t) + case t: IRTrees.Apply => genApply(t) + case t: IRTrees.AsInstanceOf => genAsInstanceOf(t) + case t: IRTrees.Block => genBlock(t, expectedType) + case t: IRTrees.Labeled => genLabeled(t, expectedType) + case t: IRTrees.Return => genReturn(t) + case t: IRTrees.Select => genSelect(t) + case t: IRTrees.Assign => genAssign(t) + case t: IRTrees.VarDef => genVarDef(t) + case t: IRTrees.New => genNew(t) + case t: IRTrees.If => genIf(t, expectedType) + case t: IRTrees.While => genWhile(t) + case t: IRTrees.Skip => IRTypes.NoType case _ => println(tree) ??? @@ -131,55 +149,61 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti // case IRTrees.IdentityHashCode(pos) => } + genAdapt(generatedType, expectedType) } - private def transformAssign(t: IRTrees.Assign): List[WasmInstr] = { - val wasmRHS = transformTree(t.rhs) + private def genAdapt(generatedType: IRTypes.Type, expectedType: IRTypes.Type): Unit = { + (generatedType, expectedType) match { + case _ if generatedType == expectedType => + () + case (IRTypes.NothingType, _) => + () + case (_, IRTypes.NoType) => + instrs += DROP + case _ => + () + } + } + + private def genAssign(t: IRTrees.Assign): IRTypes.Type = { t.lhs match { case sel: IRTrees.Select => val className = WasmStructTypeName(sel.className) val fieldName = WasmFieldName(sel.field.name) val idx = ctx.getClassInfo(sel.className).getFieldIdx(fieldName) - val castIfRequired = sel.qualifier match { - case _: IRTrees.This => - List( - // requires cast if the qualifier is `this` - // because receiver type is Object in wasm - REF_CAST( - HeapType(Types.WasmHeapType.Type(WasmStructTypeName(sel.className))) - ) - ) - case _ => Nil - } - transformTree(sel.qualifier) ++ castIfRequired ++ wasmRHS :+ - STRUCT_SET(TypeIdx(className), idx) + + // For Select, the receiver can never be a hijacked class, so we can use genTreeAuto + genTreeAuto(sel.qualifier) + + genTree(t.rhs, t.lhs.tpe) + instrs += STRUCT_SET(TypeIdx(className), idx) case sel: IRTrees.SelectStatic => // OK? val className = WasmStructTypeName(sel.className) val fieldName = WasmFieldName(sel.field.name) val idx = ctx.getClassInfo(sel.className).getFieldIdx(fieldName) - List( - GLOBAL_GET( - GlobalIdx(Names.WasmGlobalName.WasmModuleInstanceName.fromIR(sel.className)) - ) - ) ++ wasmRHS ++ List( - STRUCT_SET(TypeIdx(className), idx) + instrs += GLOBAL_GET( + GlobalIdx(Names.WasmGlobalName.WasmModuleInstanceName.fromIR(sel.className)) ) + genTree(t.rhs, t.lhs.tpe) + instrs += STRUCT_SET(TypeIdx(className), idx) + case assign: IRTrees.ArraySelect => ??? // array.set case assign: IRTrees.RecordSelect => ??? // struct.set case assign: IRTrees.JSPrivateSelect => ??? case assign: IRTrees.JSSelect => ??? case assign: IRTrees.JSSuperSelect => ??? case assign: IRTrees.JSGlobalRef => ??? + case ref: IRTrees.VarRef => - wasmRHS :+ LOCAL_SET(LocalIdx(Names.WasmLocalName.fromIR(ref.ident.name))) + genTree(t.rhs, t.lhs.tpe) + instrs += LOCAL_SET(LocalIdx(Names.WasmLocalName.fromIR(ref.ident.name))) } - } - private def transformApply(t: IRTrees.Apply): List[WasmInstr] = { - val pushReceiver = transformTree(t.receiver) - val wasmArgs = t.args.flatMap(transformTree) + IRTypes.NoType + } + private def genApply(t: IRTrees.Apply): IRTypes.Type = { val receiverClassName = t.receiver.tpe match { case ClassType(className) => className case prim: IRTypes.PrimType => IRTypes.PrimTypeToBoxedClass(prim) @@ -187,6 +211,13 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti } val receiverClassInfo = ctx.getClassInfo(receiverClassName) + def genReceiverArgsReceiver(): Unit = { + assert(!IRNames.HijackedClasses.contains(receiverClassName)) + genTreeAuto(t.receiver) + genArgs(t.args, t.method.name) + genTreeAuto(t.receiver) // TODO Reuse the receiver computed above + } + if (receiverClassInfo.isInterface) { // interface dispatch val itables = ctx.calculateClassItables(clazz = receiverClassName) @@ -207,41 +238,34 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti // case _ => throw new Error(s"Invalid receiver type ${t.receiver.tpe}") // } - pushReceiver ++ wasmArgs ++ pushReceiver ++ - List( - STRUCT_GET( - // receiver type should be upcasted into `Object` if it's interface - // by TypeTransformer#transformType - TypeIdx(WasmStructTypeName(IRNames.ObjectClass)), - StructFieldIdx(1) - ), - I32_CONST(I32(itableIdx)), - ARRAY_GET( - TypeIdx(WasmArrayType.itables.name) - ), - REF_CAST( - HeapType(Types.WasmHeapType.Type(WasmITableTypeName(targetClass.name))) - ), - STRUCT_GET( - TypeIdx(WasmITableTypeName(targetClass.name)), - StructFieldIdx(methodIdx) - ), - CALL_REF( - TypeIdx(method.toWasmFunctionType()(ctx).name) - ) - ) + genReceiverArgsReceiver() + instrs += STRUCT_GET( + // receiver type should be upcasted into `Object` if it's interface + // by TypeTransformer#transformType + TypeIdx(WasmStructTypeName(IRNames.ObjectClass)), + StructFieldIdx(1) + ) + instrs += I32_CONST(I32(itableIdx)) + instrs += ARRAY_GET( + TypeIdx(WasmArrayType.itables.name) + ) + instrs += REF_CAST( + HeapType(Types.WasmHeapType.Type(WasmITableTypeName(targetClass.name))) + ) + instrs += STRUCT_GET( + TypeIdx(WasmITableTypeName(targetClass.name)), + StructFieldIdx(methodIdx) + ) + instrs += CALL_REF( + TypeIdx(method.toWasmFunctionType()(ctx).name) + ) + if (t.tpe == IRTypes.NothingType) + instrs += UNREACHABLE } else if (receiverClassInfo.kind == ClassKind.HijackedClass) { // statically resolved call - val info = receiverClassInfo.getMethodInfo(t.method.name) - val castIfNeeded = - if (receiverClassName == IRNames.BoxedStringClass && t.receiver.tpe == ClassType(IRNames.BoxedStringClass)) - List(REF_CAST(HeapType(Types.WasmHeapType.Type(WasmStructTypeName.string)))) - else - Nil - pushReceiver ++ castIfNeeded ++ wasmArgs ++ - List( - CALL(FuncIdx(info.name)) - ) + genApplyStatically( + IRTrees.ApplyStatically(t.flags, t.receiver, receiverClassName, t.method, t.args)(t.tpe)(t.pos) + ) } else { // virtual dispatch val (methodIdx, info) = ctx .calculateVtable(receiverClassName) @@ -253,191 +277,219 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti // struct.get $classType 0 ;; get vtable // struct.get $vtableType $methodIdx ;; get funcref // call.ref (type $funcType) ;; call funcref - pushReceiver ++ wasmArgs ++ pushReceiver ++ - List( - REF_CAST( - HeapType(Types.WasmHeapType.Type(WasmStructTypeName(receiverClassName))) - ), - STRUCT_GET( - TypeIdx(WasmStructTypeName(receiverClassName)), - StructFieldIdx(0) - ), - STRUCT_GET( - TypeIdx(WasmVTableTypeName.fromIR(receiverClassName)), - StructFieldIdx(methodIdx) - ), - CALL_REF( - TypeIdx(info.toWasmFunctionType()(ctx).name) - ) - ) + genReceiverArgsReceiver() + instrs += REF_CAST( + HeapType(Types.WasmHeapType.Type(WasmStructTypeName(receiverClassName))) + ) + instrs += STRUCT_GET( + TypeIdx(WasmStructTypeName(receiverClassName)), + StructFieldIdx(0) + ) + instrs += STRUCT_GET( + TypeIdx(WasmVTableTypeName.fromIR(receiverClassName)), + StructFieldIdx(methodIdx) + ) + instrs += CALL_REF( + TypeIdx(info.toWasmFunctionType()(ctx).name) + ) + if (t.tpe == IRTypes.NothingType) + instrs += UNREACHABLE } + + t.tpe } - private def transformApplyStatically(t: IRTrees.ApplyStatically): List[WasmInstr] = { - val wasmArgs = transformTree(t.receiver) ++ t.args.flatMap(transformTree) + private def genApplyStatically(t: IRTrees.ApplyStatically): IRTypes.Type = { + IRTypes.BoxedClassToPrimType.get(t.className) match { + case None => + genTree(t.receiver, IRTypes.ClassType(t.className)) + + case Some(primReceiverType) => + genTreeAuto(t.receiver) + genUnbox(t.receiver.tpe, primReceiverType) + } + + genArgs(t.args, t.method.name) val funcName = Names.WasmFunctionName(t.className, t.method.name) - wasmArgs :+ CALL(FuncIdx(funcName)) + instrs += CALL(FuncIdx(funcName)) + if (t.tpe == IRTypes.NothingType) + instrs += UNREACHABLE + t.tpe } - private def transformLiteral(l: IRTrees.Literal): List[WasmInstr] = l match { - case IRTrees.BooleanLiteral(v) => WasmInstr.I32_CONST(if (v) I32(1) else I32(0)) :: Nil - case IRTrees.ByteLiteral(v) => WasmInstr.I32_CONST(I32(v)) :: Nil - case IRTrees.ShortLiteral(v) => WasmInstr.I32_CONST(I32(v)) :: Nil - case IRTrees.IntLiteral(v) => WasmInstr.I32_CONST(I32(v)) :: Nil - case IRTrees.CharLiteral(v) => WasmInstr.I32_CONST(I32(v)) :: Nil - case IRTrees.LongLiteral(v) => WasmInstr.I64_CONST(I64(v)) :: Nil - case IRTrees.FloatLiteral(v) => WasmInstr.F32_CONST(F32(v)) :: Nil - case IRTrees.DoubleLiteral(v) => WasmInstr.F64_CONST(F64(v)) :: Nil - - case v: IRTrees.Undefined => - WasmInstr.GLOBAL_GET(GlobalIdx(WasmGlobalName.WasmUndefName)) :: Nil - case v: IRTrees.Null => ??? - - case v: IRTrees.StringLiteral => - // TODO We should allocate literal strings once and for all as globals - val str = v.value - str.toList.map(c => WasmInstr.I32_CONST(I32(c.toInt))) ::: - List( - WasmInstr - .ARRAY_NEW_FIXED(TypeIdx(WasmTypeName.WasmArrayTypeName.stringData), I32(str.length())), - WasmInstr.STRUCT_NEW(TypeIdx(WasmTypeName.WasmStructTypeName.string)) - ) + private def genUnbox(fromType: IRTypes.Type, primType: IRTypes.PrimType): Unit = { + if (fromType != primType && fromType != IRTypes.NothingType) { + primType match { + case IRTypes.StringType => + instrs += REF_CAST(HeapType(Types.WasmHeapType.Type(WasmStructTypeName.string))) + case _ => + println(s"unbox($fromType, $primType)") + ??? + } + } + } - case v: IRTrees.ClassOf => ??? + private def genArgs(args: List[IRTrees.Tree], methodName: IRNames.MethodName): Unit = { + for ((arg, paramTypeRef) <- args.lazyZip(methodName.paramTypeRefs)) { + val paramType = ctx.inferTypeFromTypeRef(paramTypeRef) + genTree(arg, paramType) + } } - private def transformSelect(sel: IRTrees.Select): List[WasmInstr] = { + private def genLiteral(l: IRTrees.Literal): IRTypes.Type = { + l match { + case IRTrees.BooleanLiteral(v) => instrs += WasmInstr.I32_CONST(if (v) I32(1) else I32(0)) + case IRTrees.ByteLiteral(v) => instrs += WasmInstr.I32_CONST(I32(v)) + case IRTrees.ShortLiteral(v) => instrs += WasmInstr.I32_CONST(I32(v)) + case IRTrees.IntLiteral(v) => instrs += WasmInstr.I32_CONST(I32(v)) + case IRTrees.CharLiteral(v) => instrs += WasmInstr.I32_CONST(I32(v)) + case IRTrees.LongLiteral(v) => instrs += WasmInstr.I64_CONST(I64(v)) + case IRTrees.FloatLiteral(v) => instrs += WasmInstr.F32_CONST(F32(v)) + case IRTrees.DoubleLiteral(v) => instrs += WasmInstr.F64_CONST(F64(v)) + + case v: IRTrees.Undefined => + instrs += WasmInstr.GLOBAL_GET(GlobalIdx(WasmGlobalName.WasmUndefName)) + case v: IRTrees.Null => ??? + + case v: IRTrees.StringLiteral => + // TODO We should allocate literal strings once and for all as globals + val str = v.value + str.foreach(c => instrs += WasmInstr.I32_CONST(I32(c.toInt))) + instrs += WasmInstr.ARRAY_NEW_FIXED(TypeIdx(WasmTypeName.WasmArrayTypeName.stringData), I32(str.length())) + instrs += WasmInstr.STRUCT_NEW(TypeIdx(WasmTypeName.WasmStructTypeName.string)) + + case v: IRTrees.ClassOf => ??? + } + + l.tpe + } + + private def genSelect(sel: IRTrees.Select): IRTypes.Type = { val className = WasmStructTypeName(sel.className) val fieldName = WasmFieldName(sel.field.name) val idx = ctx.getClassInfo(sel.className).getFieldIdx(fieldName) - val select = sel.qualifier match { - case _: IRTrees.This => - List( - // requires cast if the qualifier is `this` - // because receiver type is Object in wasm - REF_CAST( - HeapType(Types.WasmHeapType.Type(WasmStructTypeName(sel.className))) - ), - STRUCT_GET(TypeIdx(className), idx) - ) - case _ => - List(STRUCT_GET(TypeIdx(className), idx)) - } - transformTree(sel.qualifier) ++ select + // For Select, the receiver can never be a hijacked class, so we can use genTreeAuto + genTreeAuto(sel.qualifier) + + instrs += STRUCT_GET(TypeIdx(className), idx) + sel.tpe } - private def transformStoreModule(t: IRTrees.StoreModule): List[WasmInstr] = { + private def genStoreModule(t: IRTrees.StoreModule): IRTypes.Type = { val name = WasmGlobalName.WasmModuleInstanceName.fromIR(t.className) - transformTree(t.value) :+ GLOBAL_SET(GlobalIdx(name)) + genTree(t.value, IRTypes.ClassType(t.className)) + instrs += GLOBAL_SET(GlobalIdx(name)) + IRTypes.NoType } /** Push module class instance to the stack. * * see: WasmBuilder.genLoadModuleFunc */ - private def transformLoadModule(t: IRTrees.LoadModule): List[WasmInstr] = - List(CALL(FuncIdx(Names.WasmFunctionName.loadModule(t.className)))) + private def genLoadModule(t: IRTrees.LoadModule): IRTypes.Type = { + instrs += CALL(FuncIdx(Names.WasmFunctionName.loadModule(t.className))) + t.tpe + } - private def transformUnaryOp(unary: IRTrees.UnaryOp): List[WasmInstr] = { + private def genUnaryOp(unary: IRTrees.UnaryOp): IRTypes.Type = { import IRTrees.UnaryOp._ - val lhsInstrs = transformTree(unary.lhs) + genTreeAuto(unary.lhs) (unary.op: @switch) match { case Boolean_! => - lhsInstrs ++ - List( - I32_CONST(I32(1)), - I32_XOR - ) + instrs += I32_CONST(I32(1)) + instrs += I32_XOR // Widening conversions case CharToInt | ByteToInt | ShortToInt => - lhsInstrs // these are no-ops because they are all represented as i32's with the right mathematical value + () // these are no-ops because they are all represented as i32's with the right mathematical value case IntToLong => - lhsInstrs :+ I64_EXTEND32_S + instrs += I64_EXTEND32_S case IntToDouble => - lhsInstrs :+ F64_CONVERT_I32_S + instrs += F64_CONVERT_I32_S case FloatToDouble => - lhsInstrs :+ F64_PROMOTE_F32 + instrs += F64_PROMOTE_F32 // Narrowing conversions case IntToChar => - lhsInstrs ++ List(I32_CONST(I32(0xffff)), I32_AND) + instrs += I32_CONST(I32(0xffff)) + instrs += I32_AND case IntToByte => - lhsInstrs :+ I32_EXTEND8_S + instrs += I32_EXTEND8_S case IntToShort => - lhsInstrs :+ I32_EXTEND16_S + instrs += I32_EXTEND16_S case LongToInt => - lhsInstrs :+ I32_WRAP_I64 + instrs += I32_WRAP_I64 case DoubleToInt => - lhsInstrs :+ I32_TRUNC_SAT_F64_S + instrs += I32_TRUNC_SAT_F64_S case DoubleToFloat => - lhsInstrs :+ F32_DEMOTE_F64 + instrs += F32_DEMOTE_F64 // Long <-> Double (neither widening nor narrowing) case LongToDouble => - lhsInstrs :+ F64_CONVERT_I64_S + instrs += F64_CONVERT_I64_S case DoubleToLong => - lhsInstrs :+ I64_TRUNC_SAT_F64_S + instrs += I64_TRUNC_SAT_F64_S // Long -> Float (neither widening nor narrowing), introduced in 1.6 case LongToFloat => - lhsInstrs :+ F32_CONVERT_I64_S + instrs += F32_CONVERT_I64_S // String.length, introduced in 1.11 case String_length => - lhsInstrs ++ - List( - STRUCT_GET(TypeIdx(WasmStructTypeName.string), StructFieldIdx(0)), // get the array - ARRAY_LEN - ) + instrs += STRUCT_GET(TypeIdx(WasmStructTypeName.string), StructFieldIdx(0)) // get the array + instrs += ARRAY_LEN } + + unary.tpe } - private def transformBinaryOp(binary: IRTrees.BinaryOp): List[WasmInstr] = { + private def genBinaryOp(binary: IRTrees.BinaryOp): IRTypes.Type = { import IRTrees.BinaryOp - def longShiftOp(shiftInstr: WasmInstr): List[WasmInstr] = { - transformTree(binary.lhs) ++ - transformTree(binary.rhs) ++ - List( - I64_EXTEND_I32_S, - shiftInstr - ) + def genLongShiftOp(shiftInstr: WasmInstr): IRTypes.Type = { + genTree(binary.lhs, IRTypes.LongType) + genTree(binary.rhs, IRTypes.IntType) + instrs += I64_EXTEND_I32_S + instrs += shiftInstr + IRTypes.LongType } binary.op match { - case BinaryOp.String_+ => transformStringConcat(binary.lhs, binary.rhs) + case BinaryOp.=== | BinaryOp.!== => genEq(binary) + + case BinaryOp.String_+ => genStringConcat(binary.lhs, binary.rhs) - case BinaryOp.Long_<< => longShiftOp(I64_SHL) - case BinaryOp.Long_>>> => longShiftOp(I64_SHR_U) - case BinaryOp.Long_>> => longShiftOp(I64_SHR_S) + case BinaryOp.Long_<< => genLongShiftOp(I64_SHL) + case BinaryOp.Long_>>> => genLongShiftOp(I64_SHR_U) + case BinaryOp.Long_>> => genLongShiftOp(I64_SHR_S) // New in 1.11 case BinaryOp.String_charAt => - transformTree(binary.lhs) ++ // push the string - List( - STRUCT_GET(TypeIdx(WasmStructTypeName.string), StructFieldIdx(0)), // get the array - ) ++ - transformTree(binary.rhs) ++ // push the index - List( - ARRAY_GET_U(TypeIdx(WasmArrayTypeName.stringData)) // access the element of the array - ) + genTree(binary.lhs, IRTypes.StringType) // push the string + instrs += STRUCT_GET(TypeIdx(WasmStructTypeName.string), StructFieldIdx(0)) // get the array + genTree(binary.rhs, IRTypes.IntType) // push the index + instrs += ARRAY_GET_U(TypeIdx(WasmArrayTypeName.stringData)) // access the element of the array + IRTypes.CharType - case _ => transformElementaryBinaryOp(binary) + case _ => genElementaryBinaryOp(binary) } } - private def transformElementaryBinaryOp(binary: IRTrees.BinaryOp): List[WasmInstr] = { + private def genEq(binary: IRTrees.BinaryOp): IRTypes.Type = { + println(binary) + ??? + + IRTypes.BooleanType + } + + private def genElementaryBinaryOp(binary: IRTrees.BinaryOp): IRTypes.Type = { import IRTrees.BinaryOp - val lhsInstrs = transformTree(binary.lhs) - val rhsInstrs = transformTree(binary.rhs) + genTreeAuto(binary.lhs) + genTreeAuto(binary.rhs) val operation = binary.op match { - case BinaryOp.=== => ??? - case BinaryOp.!== => ??? - case BinaryOp.Boolean_== => I32_EQ case BinaryOp.Boolean_!= => I32_NE case BinaryOp.Boolean_| => I32_OR @@ -505,44 +557,44 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti case BinaryOp.Double_> => F64_GT case BinaryOp.Double_>= => F64_GE } - lhsInstrs ++ rhsInstrs :+ operation + instrs += operation + binary.tpe } - private def transformStringConcat(lhs: IRTrees.Tree, rhs: IRTrees.Tree): List[WasmInstr] = { + private def genStringConcat(lhs: IRTrees.Tree, rhs: IRTrees.Tree): IRTypes.Type = { val wasmStringType = Types.WasmRefType(Types.WasmHeapType.Type(WasmStructTypeName.string)) - def transformToString(tree: IRTrees.Tree): List[WasmInstr] = { - val valueInstrs = transformTree(tree) + def genToString(tree: IRTrees.Tree): Unit = { + genTreeAuto(tree) tree.tpe match { case IRTypes.StringType => - valueInstrs + () // no-op case IRTypes.BooleanType => - valueInstrs ++ - List(IF(BlockType.ValueType(wasmStringType))) ++ - transformLiteral(IRTrees.StringLiteral("true")(tree.pos)) ++ - List(ELSE) ++ - transformLiteral(IRTrees.StringLiteral("false")(tree.pos)) ++ - List(END) + instrs += IF(BlockType.ValueType(wasmStringType)) + genLiteral(IRTrees.StringLiteral("true")(tree.pos)) + instrs += ELSE + genLiteral(IRTrees.StringLiteral("false")(tree.pos)) + instrs += END case IRTypes.CharType => - valueInstrs ++ - List( - WasmInstr.ARRAY_NEW_FIXED(TypeIdx(WasmTypeName.WasmArrayTypeName.stringData), I32(1)), - WasmInstr.STRUCT_NEW(TypeIdx(WasmTypeName.WasmStructTypeName.string)) - ) + instrs += WasmInstr.ARRAY_NEW_FIXED(TypeIdx(WasmTypeName.WasmArrayTypeName.stringData), I32(1)) + instrs += WasmInstr.STRUCT_NEW(TypeIdx(WasmTypeName.WasmStructTypeName.string)) case IRTypes.ByteType | IRTypes.ShortType | IRTypes.IntType => // TODO Write a correct implementation - valueInstrs ++ (DROP +: transformLiteral(IRTrees.StringLiteral("0")(tree.pos))) + instrs += DROP + genLiteral(IRTrees.StringLiteral("0")(tree.pos)) case IRTypes.LongType => // TODO Write a correct implementation - valueInstrs ++ (DROP +: transformLiteral(IRTrees.StringLiteral("0")(tree.pos))) + instrs += DROP + genLiteral(IRTrees.StringLiteral("0")(tree.pos)) case IRTypes.FloatType | IRTypes.DoubleType => // TODO Write a correct implementation - valueInstrs ++ (DROP +: transformLiteral(IRTrees.StringLiteral("0.0")(tree.pos))) + instrs += DROP + genLiteral(IRTrees.StringLiteral("0.0")(tree.pos)) case _ => // TODO @@ -553,23 +605,25 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti lhs match { case IRTrees.StringLiteral("") => // Common case where we don't actually need a concatenation - transformToString(rhs) + genToString(rhs) case _ => - // TODO: transformToString(lhs) ::: transformToString(rhs) :: callHelperConcat() :: Nil + // TODO: genToString(lhs) ::: genToString(rhs) :: callHelperConcat() :: Nil ??? } + + IRTypes.StringType } - private def transformAsInstanceOf(tree: IRTrees.AsInstanceOf): List[WasmInstr] = { - val exprInstrs = transformTree(tree.expr) + private def genAsInstanceOf(tree: IRTrees.AsInstanceOf): IRTypes.Type = { + genTreeAuto(tree.expr) val sourceTpe = tree.expr.tpe val targetTpe = tree.tpe if (IRTypes.isSubtype(sourceTpe, targetTpe)(isSubclass(_, _))) { // Common case where no cast is necessary - exprInstrs + sourceTpe } else { println(tree) ??? @@ -579,39 +633,59 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti private def isSubclass(subClass: IRNames.ClassName, superClass: IRNames.ClassName): Boolean = ctx.getClassInfo(subClass).ancestors.contains(superClass) - private def transformVarRef(r: IRTrees.VarRef): LOCAL_GET = { + private def genVarRef(r: IRTrees.VarRef): IRTypes.Type = { val name = WasmLocalName.fromIR(r.ident.name) - LOCAL_GET(LocalIdx(name)) + instrs += LOCAL_GET(LocalIdx(name)) + r.tpe } - private def transformVarDef(r: IRTrees.VarDef): List[WasmInstr] = { - r.vtpe match { - // val _: Unit = rhs - case ClassType(className) if className == IRNames.BoxedUnitClass => - transformTree(r.rhs) :+ DROP + private def genThis(t: IRTrees.This): IRTypes.Type = { + instrs += LOCAL_GET(LocalIdx(fctx.receiver.name)) + + /* If the receiver is a Class/ModuleClass, its wasm type will be declared + * as j.l.Object, and therefore we must cast it down. + */ + t.tpe match { + case IRTypes.ClassType(className) => + val info = ctx.getClassInfo(className) + if (info.kind.isClass) { + instrs += REF_CAST( + HeapType(Types.WasmHeapType.Type(WasmStructTypeName(className))) + ) + } case _ => - val local = WasmLocal( - WasmLocalName.fromIR(r.name.name), - TypeTransformer.transformType(r.vtpe)(ctx), - isParameter = false - ) - fctx.locals.define(local) - - transformTree(r.rhs) :+ LOCAL_SET(LocalIdx(local.name)) + () } + + t.tpe } - private def transformIf(t: IRTrees.If): List[WasmInstr] = { - val ty = TypeTransformer.transformType(t.tpe)(ctx) - transformTree(t.cond) ++ - List(IF(BlockType.ValueType(ty))) ++ - transformTree(t.thenp) ++ - List(ELSE) ++ - transformTree(t.elsep) ++ - List(END) + private def genVarDef(r: IRTrees.VarDef): IRTypes.Type = { + val local = WasmLocal( + WasmLocalName.fromIR(r.name.name), + TypeTransformer.transformType(r.vtpe)(ctx), + isParameter = false + ) + fctx.locals.define(local) + + genTree(r.rhs, r.vtpe) + instrs += LOCAL_SET(LocalIdx(local.name)) + + IRTypes.NoType } - private def transformWhile(t: IRTrees.While): List[WasmInstr] = { + private def genIf(t: IRTrees.If, expectedType: IRTypes.Type): IRTypes.Type = { + val ty = TypeTransformer.transformType(expectedType)(ctx) + genTree(t.cond, IRTypes.BooleanType) + instrs += IF(BlockType.ValueType(ty)) + genTree(t.thenp, expectedType) + instrs += ELSE + genTree(t.elsep, expectedType) + instrs += END + expectedType + } + + private def genWhile(t: IRTrees.While): IRTypes.Type = { val label = fctx.genLabel() val noResultType = BlockType.ValueType(Types.WasmNoType) @@ -623,14 +697,12 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti // br $label // end // unreachable - List( - LOOP(noResultType, Some(label)) - ) ++ transformTree(t.body) ++ - List( - BR(label), - END, - UNREACHABLE - ) + instrs += LOOP(noResultType, Some(label)) + genTree(t.body, IRTypes.NoType) + instrs += BR(label) + instrs += END + instrs += UNREACHABLE + IRTypes.NothingType case _ => // loop $label @@ -641,39 +713,43 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti // end // end - List( - LOOP(noResultType, Some(label)) - ) ++ transformTree(t.cond) ++ - List( - IF(noResultType) - ) ++ - transformTree(t.body) ++ - List( - BR(label), - END, // IF - END // LOOP - ) + instrs += LOOP(noResultType, Some(label)) + genTree(t.cond, IRTypes.BooleanType) + instrs += IF(noResultType) + genTree(t.body, IRTypes.NoType) + instrs += BR(label) + instrs += END // IF + instrs += END // LOOP + IRTypes.NoType } } - private def transformBlock(t: IRTrees.Block): List[WasmInstr] = - t.stats.flatMap(transformTree) + private def genBlock(t: IRTrees.Block, expectedType: IRTypes.Type): IRTypes.Type = { + for (stat <- t.stats.init) + genTree(stat, IRTypes.NoType) + genTree(t.stats.last, expectedType) + expectedType + } - private def transformLabeled(t: IRTrees.Labeled): List[WasmInstr] = { - val label = fctx.getLabelFor(t.label.name) + private def genLabeled(t: IRTrees.Labeled, expectedType: IRTypes.Type): IRTypes.Type = { + val label = fctx.registerLabel(t.label.name, expectedType) val ty = TypeTransformer.transformType(t.tpe)(ctx) - BLOCK(BlockType.ValueType(ty), Some(label)) +: - transformTree(t.body) :+ - END + + instrs += BLOCK(BlockType.ValueType(ty), Some(label)) + genTree(t.body, expectedType) + instrs += END + expectedType } - private def transformReturn(t: IRTrees.Return): List[WasmInstr] = { - val label = fctx.getLabelFor(t.label.name) - transformTree(t.expr) :+ - BR(label) + private def genReturn(t: IRTrees.Return): IRTypes.Type = { + val (label, expectedType) = fctx.getLabelFor(t.label.name) + + genTree(t.expr, expectedType) + instrs += BR(label) + IRTypes.NothingType } - private def transformNew(n: IRTrees.New): List[WasmInstr] = { + private def genNew(n: IRTrees.New): IRTypes.Type = { val localInstance = WasmLocal( fctx.genSyntheticLocalName(), TypeTransformer.transformType(n.tpe)(ctx), @@ -681,15 +757,13 @@ class WasmExpressionBuilder(ctx: FunctionTypeWriterWasmContext, fctx: WasmFuncti ) fctx.locals.define(localInstance) - List( - // REF_NULL(HeapType(Types.WasmHeapType.Type(WasmTypeName.WasmStructTypeName(n.className)))), - // LOCAL_TEE(LocalIdx(localInstance.name)) - CALL(FuncIdx(WasmFunctionName.newDefault(n.className))), - LOCAL_TEE(LocalIdx(localInstance.name)) - ) ++ n.args.flatMap(transformTree) ++ - List( - CALL(FuncIdx(WasmFunctionName(n.className, n.ctor.name))), - LOCAL_GET(LocalIdx(localInstance.name)) - ) + // REF_NULL(HeapType(Types.WasmHeapType.Type(WasmTypeName.WasmStructTypeName(n.className)))), + // LOCAL_TEE(LocalIdx(localInstance.name)) + instrs += CALL(FuncIdx(WasmFunctionName.newDefault(n.className))) + instrs += LOCAL_TEE(LocalIdx(localInstance.name)) + genArgs(n.args, n.ctor.name) + instrs += CALL(FuncIdx(WasmFunctionName(n.className, n.ctor.name))) + instrs += LOCAL_GET(LocalIdx(localInstance.name)) + n.tpe } } diff --git a/wasm/src/main/scala/wasm4s/WasmFunctionContext.scala b/wasm/src/main/scala/wasm4s/WasmFunctionContext.scala index 1158c885..5ec0b36e 100644 --- a/wasm/src/main/scala/wasm4s/WasmFunctionContext.scala +++ b/wasm/src/main/scala/wasm4s/WasmFunctionContext.scala @@ -3,6 +3,7 @@ package wasm.wasm4s import scala.collection.mutable import org.scalajs.ir.{Names => IRNames} +import org.scalajs.ir.{Types => IRTypes} import Names.WasmLocalName @@ -13,7 +14,8 @@ class WasmFunctionContext private (private val _receiver: Option[WasmLocal]) { val locals = new WasmSymbolTable[WasmLocalName, WasmLocal]() def receiver = _receiver.getOrElse(throw new Error("Can access to the receiver in this context.")) - private val registeredLabels = mutable.AnyRefMap.empty[IRNames.LabelName, WasmImmediate.LabelIdx] + private val registeredLabels = + mutable.AnyRefMap.empty[IRNames.LabelName, (WasmImmediate.LabelIdx, IRTypes.Type)] def genLabel(): WasmImmediate.LabelIdx = { val label = WasmImmediate.LabelIdx(labelIdx) @@ -21,8 +23,17 @@ class WasmFunctionContext private (private val _receiver: Option[WasmLocal]) { label } - def getLabelFor(irLabelName: IRNames.LabelName): WasmImmediate.LabelIdx = - registeredLabels.getOrElseUpdate(irLabelName, genLabel()) + def registerLabel(irLabelName: IRNames.LabelName, expectedType: IRTypes.Type): WasmImmediate.LabelIdx = { + val label = genLabel() + registeredLabels(irLabelName) = (label, expectedType) + label + } + + def getLabelFor(irLabelName: IRNames.LabelName): (WasmImmediate.LabelIdx, IRTypes.Type) = { + registeredLabels.getOrElse(irLabelName, { + throw new IllegalArgumentException(s"Unknown label ${irLabelName.nameString}") + }) + } def genSyntheticLocalName(): WasmLocalName = { val name = WasmLocalName.synthetic(cnt)