Skip to content
This repository was archived by the owner on Jul 12, 2024. It is now read-only.

Add support for static methods. #21

Merged
merged 3 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cli/src/main/scala/TestSuites.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ object TestSuites {
TestSuite("testsuite.core.HijackedClassesDispatchTest"),
TestSuite("testsuite.core.HijackedClassesMonoTest"),
TestSuite("testsuite.core.HijackedClassesUpcastTest"),
TestSuite("testsuite.core.StaticMethodTest"),
TestSuite("testsuite.core.ToStringTest")
)
}
2 changes: 1 addition & 1 deletion run.mjs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { load } from "./loader.mjs";

const { test } = await load("./target/output.wasm");
const o = test();
const o = test(7);
console.log(o);
3 changes: 1 addition & 2 deletions sample/src/main/scala/Sample.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ import scala.scalajs.js.annotation._
//
object Main {
@JSExportTopLevel("test")
def test() = {
val i = 4
def test(i: Int): Boolean = {
val loopFib = fib(new LoopFib {}, i)
val recFib = fib(new RecFib {}, i)
val tailrecFib = fib(new TailRecFib {}, i)
Expand Down
10 changes: 10 additions & 0 deletions test-suite/src/main/scala/testsuite/core/StaticMethodTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package testsuite.core

import testsuite.Assert.ok

object StaticMethodTest {
def main(): Unit = {
ok(java.lang.Integer.sum(5, 65) == 70)
ok(java.lang.Integer.reverseBytes(0x01020304) == 0x04030201)
}
}
6 changes: 5 additions & 1 deletion wasm/src/main/scala/ir2wasm/HelperFunctions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,11 @@ object HelperFunctions {
// TODO: "isInstance", "isAssignableFrom", "checkCast", "newArrayOfThisClass"

// Call java.lang.Class::<init>(dataObject)
instrs += CALL(FuncIdx(WasmFunctionName(IRNames.ClassClass, SpecialNames.ClassCtor)))
instrs += CALL(FuncIdx(WasmFunctionName(
IRTrees.MemberNamespace.Constructor,
IRNames.ClassClass,
SpecialNames.ClassCtor
)))

// typeData.classOf := classInstance
instrs += LOCAL_GET(typeDataParam)
Expand Down
8 changes: 4 additions & 4 deletions wasm/src/main/scala/ir2wasm/Preprocessor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ object Preprocessor {
}

private def preprocess(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = {
val infos = clazz.methods.filterNot(_.flags.namespace.isConstructor).map { method =>
makeWasmFunctionInfo(clazz, method)
}
val infos = clazz.methods
.filter(_.flags.namespace == IRTrees.MemberNamespace.Public)
.map(method => makeWasmFunctionInfo(clazz, method))

ctx.putClassInfo(
clazz.name.name,
Expand All @@ -47,7 +47,7 @@ object Preprocessor {
method: IRTrees.MethodDef
): WasmFunctionInfo = {
WasmFunctionInfo(
Names.WasmFunctionName(clazz.name.name, method.name.name),
Names.WasmFunctionName(method.flags.namespace, clazz.name.name, method.name.name),
method.args.map(_.ptpe),
method.resultType,
isAbstract = method.body.isEmpty
Expand Down
93 changes: 30 additions & 63 deletions wasm/src/main/scala/ir2wasm/WasmBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,13 @@ class WasmBuilder {
.getOrElse(throw new Error(s"Module class should have a constructor, ${clazz.name}"))
val typeName = WasmTypeName.WasmStructTypeName(clazz.name.name)
val globalInstanceName = WasmGlobalName.WasmModuleInstanceName.fromIR(clazz.name.name)
val ctorName = WasmFunctionName(clazz.name.name, ctor.name.name)

val ctorName = WasmFunctionName(
ctor.flags.namespace,
clazz.name.name,
ctor.name.name
)

val body = List(
// global.get $module_name
// ref.if_null
Expand Down Expand Up @@ -355,7 +361,7 @@ class WasmBuilder {
Names.WasmTypeName.WasmITableTypeName(className),
classInfo.methods.map { m =>
WasmStructField(
Names.WasmFieldName(m.name.methodName),
Names.WasmFieldName(m.name.simpleName),
WasmRefNullType(WasmHeapType.Func(m.toWasmFunctionType().name)),
isMutable = false
)
Expand Down Expand Up @@ -397,97 +403,58 @@ class WasmBuilder {
exportDef: IRTrees.TopLevelMethodExportDef
)(implicit ctx: WasmContext): Unit = {
val method = exportDef.methodDef
val methodName = method.name match {
case lit: IRTrees.StringLiteral => lit
case _ => ???
}

// hack
// export top[moduleID="main"] static def "foo"(arg: any): any = {
// val prep0: int = arg.asInstanceOf[int];
// mod:sample.Main$.foo;I;I(prep0)
// }
// ->
// export top[moduleID="main"] static def "foo"(arg: int): int = {
// val prep0: int = arg;
// mod:sample.Main$.foo;I;I(arg)
// }
val paramTypeMap = mutable.Map[IRTrees.LocalIdent, IRTypes.Type]()
val nameMap = mutable.Map[IRTrees.LocalIdent, IRTrees.LocalIdent]()
val resultType: IRTypes.Type = method.body.tpe
def collectMapping(t: IRTrees.Tree): Unit = {
t match {
case IRTrees.Block(stats) => stats.foreach(collectMapping)
case IRTrees.VarDef(lhs, _, _, _, IRTrees.AsInstanceOf(IRTrees.VarRef(ident), tpe)) =>
paramTypeMap.update(ident, tpe) // arg -> int
nameMap.update(lhs, ident) // prep0 -> arg
case _ =>
}
}
def mutateTree(t: IRTrees.Tree): IRTrees.Tree = {
t match {
case b: IRTrees.Block => IRTrees.Block(b.stats.map(mutateTree))(b.pos)
case vdef @ IRTrees.VarDef(_, _, _, _, IRTrees.AsInstanceOf(vref, tpe)) =>
vdef.copy(rhs = vref)(vdef.pos)
case app: IRTrees.Apply =>
app.copy(args = app.args.map(a => mutateTree(a)))(app.tpe)(app.pos)
case vref: IRTrees.VarRef =>
val newName = nameMap.getOrElse(vref.ident, throw new Error("Invalid name"))
vref.copy(ident = newName)(vref.tpe)(vref.pos)
case t => t
}
}
val exportedName = exportDef.topLevelExportName

collectMapping(method.body)
val newBody = mutateTree(method.body)
val newParams = method.args.map { arg =>
paramTypeMap.get(arg.name) match {
case None => arg
case Some(newTpe) => arg.copy(ptpe = newTpe)(arg.pos)
}
if (method.restParam.isDefined) {
throw new UnsupportedOperationException(
s"Top-level export with ...rest param is unsupported at ${method.pos}: $method"
)
}

implicit val fctx = WasmFunctionContext(
enclosingClassName = None,
Names.WasmFunctionName(methodName),
Names.WasmFunctionName.forExport(exportedName),
receiverTyp = None,
newParams,
resultType
method.args,
IRTypes.AnyType
)

WasmExpressionBuilder.generateIRBody(newBody, resultType)
WasmExpressionBuilder.generateIRBody(method.body, IRTypes.AnyType)

val func = fctx.buildAndAddToContext()

val exprt = new WasmExport.Function(
methodName.value,
func
)
val exprt = new WasmExport.Function(exportedName, func)
ctx.addExport(exprt)
}

private def genFunction(
clazz: LinkedClass,
method: IRTrees.MethodDef
)(implicit ctx: WasmContext): WasmFunction = {
val functionName = Names.WasmFunctionName(clazz.name.name, method.name.name)
val functionName = Names.WasmFunctionName(
method.flags.namespace,
clazz.name.name,
method.name.name
)

// Receiver type for non-constructor methods needs to be `(ref any)` because params are invariant
// Otherwise, vtable can't be a subtype of the supertype's subtype
// Constructor can use the exact type because it won't be registered to vtables.
val receiverTyp =
if (clazz.kind == ClassKind.HijackedClass)
transformType(IRTypes.BoxedClassToPrimType(clazz.name.name))
if (method.flags.namespace.isStatic)
None
else if (clazz.kind == ClassKind.HijackedClass)
Some(transformType(IRTypes.BoxedClassToPrimType(clazz.name.name)))
else if (method.flags.namespace.isConstructor)
WasmRefNullType(WasmHeapType.Type(WasmTypeName.WasmStructTypeName(clazz.name.name)))
Some(WasmRefNullType(WasmHeapType.Type(WasmTypeName.WasmStructTypeName(clazz.name.name))))
else
WasmRefType.any
Some(WasmRefType.any)

// Prepare for function context, set receiver and parameters
implicit val fctx = WasmFunctionContext(
Some(clazz.className),
functionName,
Some(receiverTyp),
receiverTyp,
method.args,
method.resultType
)
Expand Down
32 changes: 28 additions & 4 deletions wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ private class WasmExpressionBuilder private (
case t: IRTrees.This => genThis(t)
case t: IRTrees.ApplyStatically => genApplyStatically(t)
case t: IRTrees.Apply => genApply(t)
case t: IRTrees.ApplyStatic => genApplyStatic(t)
case t: IRTrees.IsInstanceOf => genIsInstanceOf(t)
case t: IRTrees.AsInstanceOf => genAsInstanceOf(t)
case t: IRTrees.GetClass => genGetClass(t)
Expand Down Expand Up @@ -337,7 +338,11 @@ private class WasmExpressionBuilder private (
* After this code gen, the stack contains the result.
*/
def genHijackedClassCall(hijackedClass: IRNames.ClassName): Unit = {
val funcName = Names.WasmFunctionName(hijackedClass, t.method.name)
val funcName = Names.WasmFunctionName(
IRTrees.MemberNamespace.Public,
hijackedClass,
t.method.name
)
instrs += CALL(FuncIdx(funcName))
}

Expand Down Expand Up @@ -522,7 +527,11 @@ private class WasmExpressionBuilder private (

val (methodIdx, info) = ctx
.calculateVtableType(receiverClassName)
.resolveWithIdx(WasmFunctionName(receiverClassName, methodName))
.resolveWithIdx(WasmFunctionName(
IRTrees.MemberNamespace.Public,
receiverClassName,
methodName
))

// // push args to the stacks
// local.get $this ;; for accessing funcref
Expand Down Expand Up @@ -582,14 +591,25 @@ private class WasmExpressionBuilder private (
}

genArgs(t.args, t.method.name)
val funcName = Names.WasmFunctionName(t.className, t.method.name)
val namespace = IRTrees.MemberNamespace.forNonStaticCall(t.flags)
val funcName = Names.WasmFunctionName(namespace, t.className, t.method.name)
instrs += CALL(FuncIdx(funcName))
if (t.tpe == IRTypes.NothingType)
instrs += UNREACHABLE
t.tpe
}
}

private def genApplyStatic(tree: IRTrees.ApplyStatic): IRTypes.Type = {
genArgs(tree.args, tree.method.name)
val namespace = IRTrees.MemberNamespace.forStaticCall(tree.flags)
val funcName = Names.WasmFunctionName(namespace, tree.className, tree.method.name)
instrs += CALL(FuncIdx(funcName))
if (tree.tpe == IRTypes.NothingType)
instrs += UNREACHABLE
tree.tpe
}

private def genArgs(args: List[IRTrees.Tree], methodName: IRNames.MethodName): Unit = {
for ((arg, paramTypeRef) <- args.lazyZip(methodName.paramTypeRefs)) {
val paramType = ctx.inferTypeFromTypeRef(paramTypeRef)
Expand Down Expand Up @@ -1328,7 +1348,11 @@ private class WasmExpressionBuilder private (
instrs += CALL(FuncIdx(WasmFunctionName.newDefault(n.className)))
instrs += LOCAL_TEE(LocalIdx(localInstance.name))
genArgs(n.args, n.ctor.name)
instrs += CALL(FuncIdx(WasmFunctionName(n.className, n.ctor.name)))
instrs += CALL(FuncIdx(WasmFunctionName(
IRTrees.MemberNamespace.Constructor,
n.className,
n.ctor.name
)))
instrs += LOCAL_GET(LocalIdx(localInstance.name))
n.tpe
}
Expand Down
41 changes: 33 additions & 8 deletions wasm/src/main/scala/wasm4s/Names.scala
Original file line number Diff line number Diff line change
Expand Up @@ -97,21 +97,46 @@ object Names {
// }

case class WasmFunctionName private (
val className: String,
val methodName: String
) extends WasmName(s"$className#$methodName")
val namespace: String,
val simpleName: String
) extends WasmName(namespace + "#" + simpleName)

object WasmFunctionName {
def apply(clazz: IRNames.ClassName, method: IRNames.MethodName): WasmFunctionName =
new WasmFunctionName(clazz.nameString, method.nameString)
def apply(lit: IRTrees.StringLiteral): WasmFunctionName = new WasmFunctionName(lit.value, "")
def apply(
namespace: IRTrees.MemberNamespace,
clazz: IRNames.ClassName,
method: IRNames.MethodName
): WasmFunctionName = {
new WasmFunctionName(
namespaceString(namespace) + "#" + clazz.nameString,
method.nameString
)
}

private def namespaceString(namespace: IRTrees.MemberNamespace): String = {
import IRTrees.MemberNamespace._

// These strings are the same ones that the JS back-end uses
namespace match {
case Public => "f"
case Private => "p"
case PublicStatic => "s"
case PrivateStatic => "ps"
case Constructor => "ct"
case StaticConstructor => "sct"
}
}

def forExport(exportedName: String): WasmFunctionName =
new WasmFunctionName("export", exportedName)

// Adding prefix __ to avoid name clashes with user code.
// It should be safe not to add prefix to the method name
// since loadModule is a static method and it's not registered in the vtable.
def loadModule(clazz: IRNames.ClassName): WasmFunctionName =
new WasmFunctionName(s"__${clazz.nameString}", "loadModule")
new WasmFunctionName("loadModule", clazz.nameString)
def newDefault(clazz: IRNames.ClassName): WasmFunctionName =
new WasmFunctionName(s"__${clazz.nameString}", "newDefault")
new WasmFunctionName("new", clazz.nameString)

val start = new WasmFunctionName("start", "start")

Expand Down
Loading