Skip to content
This repository was archived by the owner on Jul 12, 2024. It is now read-only.

Commit 4f6e195

Browse files
committed
Implement support for Closures.
1 parent 4643e4b commit 4f6e195

File tree

7 files changed

+185
-5
lines changed

7 files changed

+185
-5
lines changed

Diff for: cli/src/main/scala/TestSuites.scala

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ object TestSuites {
1010
TestSuite("testsuite.core.InterfaceCall"),
1111
TestSuite("testsuite.core.AsInstanceOfTest"),
1212
TestSuite("testsuite.core.ClassOfTest"),
13+
TestSuite("testsuite.core.ClosureTest"),
1314
TestSuite("testsuite.core.FieldsTest"),
1415
TestSuite("testsuite.core.GetClassTest"),
1516
TestSuite("testsuite.core.JSInteropTest"),

Diff for: loader.mjs

+3
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ const scalaJSHelpers = {
6363

6464
// Closure
6565
closure: (f, data) => f.bind(void 0, data),
66+
closureThis: (f, data) => function(...args) { return f(this, data, ...args); },
67+
closureRest: (f, data, n) => ((...args) => f(data, ...args.slice(0, n), args.slice(n))),
68+
closureThisRest: (f, data, n) => function(...args) { return f(this, data, ...args.slice(0, n), args.slice(n)); },
6669

6770
// Strings
6871
emptyString: () => "",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package testsuite.core
2+
3+
import scala.scalajs.js
4+
5+
import testsuite.Assert.ok
6+
7+
object ClosureTest {
8+
def main(): Unit = {
9+
testClosure()
10+
testClosureThis()
11+
12+
// TODO We cannot test closures with ...rest params yet because they need Seq's
13+
14+
testGiveToActualJSCode()
15+
}
16+
17+
def testClosure(): Unit = {
18+
def makeClosure(x: Int, y: String): js.Function2[Boolean, String, String] =
19+
(z, w) => s"$x $y $z $w"
20+
21+
val f = makeClosure(5, "foo")
22+
ok(f(true, "bar") == "5 foo true bar")
23+
}
24+
25+
def testClosureThis(): Unit = {
26+
def makeClosure(x: Int, y: String): js.ThisFunction2[Any, Boolean, String, String] =
27+
(ths, z, w) => s"$ths $x $y $z $w"
28+
29+
val f = makeClosure(5, "foo")
30+
ok(f(new Obj, true, "bar") == "Obj 5 foo true bar")
31+
}
32+
33+
def testGiveToActualJSCode(): Unit = {
34+
val arr = js.Array(2, 3, 5, 7, 11)
35+
val f: js.Function1[Int, Int] = x => x * 2
36+
val result = arr.asInstanceOf[js.Dynamic].map(f).asInstanceOf[js.Array[Int]]
37+
ok(result.length == 5)
38+
ok(result(0) == 4)
39+
ok(result(1) == 6)
40+
ok(result(2) == 10)
41+
ok(result(3) == 14)
42+
ok(result(4) == 22)
43+
}
44+
45+
class Obj {
46+
override def toString(): String = "Obj"
47+
}
48+
}

Diff for: wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala

+80-1
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ private class WasmExpressionBuilder private (
142142
case t: IRTrees.JSGlobalRef => genJSGlobalRef(t)
143143
case t: IRTrees.JSTypeOfGlobalRef => genJSTypeOfGlobalRef(t)
144144
case t: IRTrees.JSLinkingInfo => genJSLinkingInfo(t)
145+
case t: IRTrees.Closure => genClosure(t)
145146

146147
// array
147148
case t: IRTrees.ArrayLength => genArrayLength(t)
@@ -159,7 +160,6 @@ private class WasmExpressionBuilder private (
159160
// case IRTrees.SelectStatic(tpe) =>
160161
// case IRTrees.JSSuperMethodCall(pos) =>
161162
// case IRTrees.Match(tpe) =>
162-
// case IRTrees.Closure(pos) =>
163163
// case IRTrees.RecordSelect(tpe) =>
164164
// case IRTrees.TryFinally(pos) =>
165165
// case IRTrees.JSImportMeta(pos) =>
@@ -1652,4 +1652,83 @@ private class WasmExpressionBuilder private (
16521652
instrs += ARRAY_NEW_FIXED(TypeIdx(arrTy.name), I32(t.elems.size))
16531653
t.tpe
16541654
}
1655+
1656+
private def genClosure(tree: IRTrees.Closure): IRTypes.Type = {
1657+
implicit val ctx = this.ctx
1658+
1659+
val hasThis = !tree.arrow
1660+
val hasRestParam = tree.restParam.isDefined
1661+
val dataStructType = ctx.getClosureDataStructType(tree.captureParams.map(_.ptpe))
1662+
1663+
// Define the function where captures are reified as a `__captureData` argument.
1664+
val closureFuncName = fctx.genInnerFuncName()
1665+
locally {
1666+
val receiverParam =
1667+
if (!hasThis) None
1668+
else Some(WasmLocal(WasmLocalName.receiver, Types.WasmAnyRef, isParameter = true))
1669+
1670+
val captureDataParam = WasmLocal(
1671+
WasmLocalName("__captureData"),
1672+
Types.WasmRefType(Types.WasmHeapType.Type(dataStructType.name)),
1673+
isParameter = true
1674+
)
1675+
1676+
val paramLocals = (tree.params ::: tree.restParam.toList).map { param =>
1677+
val typ = TypeTransformer.transformType(param.ptpe)
1678+
WasmLocal(WasmLocalName.fromIR(param.name.name), typ, isParameter = true)
1679+
}
1680+
val resultTyps = TypeTransformer.transformResultType(IRTypes.AnyType)
1681+
1682+
implicit val fctx = WasmFunctionContext(
1683+
enclosingClassName = None,
1684+
closureFuncName,
1685+
receiverParam,
1686+
captureDataParam :: paramLocals,
1687+
resultTyps
1688+
)
1689+
1690+
val captureDataLocalIdx = fctx.paramIndices.head
1691+
1692+
// Extract the fields of captureData in individual locals
1693+
for ((captureParam, index) <- tree.captureParams.zipWithIndex) {
1694+
val local = fctx.addLocal(
1695+
captureParam.name.name,
1696+
TypeTransformer.transformType(captureParam.ptpe)
1697+
)
1698+
fctx.instrs += LOCAL_GET(captureDataLocalIdx)
1699+
fctx.instrs += STRUCT_GET(TypeIdx(dataStructType.name), StructFieldIdx(index))
1700+
fctx.instrs += LOCAL_SET(local)
1701+
}
1702+
1703+
// Now transform the body - use AnyType as result type to box potential primitives
1704+
WasmExpressionBuilder.generateIRBody(tree.body, IRTypes.AnyType)
1705+
1706+
fctx.buildAndAddToContext()
1707+
}
1708+
1709+
// Put a reference to the function on the stack
1710+
instrs += ctx.refFuncWithDeclaration(closureFuncName)
1711+
1712+
// Evaluate the capture values and instantiate the capture data struct
1713+
for ((param, value) <- tree.captureParams.zip(tree.captureValues))
1714+
genTree(value, param.ptpe)
1715+
instrs += STRUCT_NEW(TypeIdx(dataStructType.name))
1716+
1717+
/* If there is a ...rest param, the helper requires as third argument the
1718+
* number of regular arguments.
1719+
*/
1720+
if (hasRestParam)
1721+
instrs += I32_CONST(I32(tree.params.size))
1722+
1723+
// Call the appropriate helper
1724+
val helper = (hasThis, hasRestParam) match {
1725+
case (false, false) => WasmFunctionName.closure
1726+
case (true, false) => WasmFunctionName.closureThis
1727+
case (false, true) => WasmFunctionName.closureRest
1728+
case (true, true) => WasmFunctionName.closureThisRest
1729+
}
1730+
instrs += CALL(FuncIdx(helper))
1731+
1732+
IRTypes.AnyType
1733+
}
16551734
}

Diff for: wasm/src/main/scala/wasm4s/Names.scala

+8
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,9 @@ object Names {
156156
def typeTest(primRef: IRTypes.PrimRef): WasmFunctionName = helper("t" + primRef.charCode)
157157

158158
val closure = helper("closure")
159+
val closureThis = helper("closureThis")
160+
val closureRest = helper("closureRest")
161+
val closureThisRest = helper("closureThisRest")
159162

160163
val emptyString = helper("emptyString")
161164
val stringLength = helper("stringLength")
@@ -246,6 +249,9 @@ object Names {
246249
def apply(name: WasmTypeName.WasmITableTypeName) = new WasmFieldName(name.name)
247250
def apply(name: IRNames.MethodName) = new WasmFieldName(name.nameString)
248251
def apply(name: WasmFunctionName) = new WasmFieldName(name.name)
252+
253+
def captureParam(i: Int): WasmFieldName = new WasmFieldName("c" + i)
254+
249255
val vtable = new WasmFieldName("vtable")
250256
val itable = new WasmFieldName("itable")
251257
val itables = new WasmFieldName("itables")
@@ -330,6 +336,8 @@ object Names {
330336
object WasmStructTypeName {
331337
def apply(name: IRNames.ClassName) = new WasmStructTypeName(name.nameString)
332338

339+
def captureData(index: Int): WasmStructTypeName = new WasmStructTypeName("captureData__" + index)
340+
333341
val typeData = new WasmStructTypeName("typeData")
334342
}
335343

Diff for: wasm/src/main/scala/wasm4s/WasmContext.scala

+31
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,13 @@ trait ReadOnlyWasmContext {
111111
trait TypeDefinableWasmContext extends ReadOnlyWasmContext { this: WasmContext =>
112112
protected val functionSignatures = LinkedHashMap.empty[WasmFunctionSignature, Int]
113113
protected val constantStringGlobals = LinkedHashMap.empty[String, WasmGlobalName]
114+
protected val closureDataTypes = LinkedHashMap.empty[List[IRTypes.Type], WasmStructType]
114115

115116
private var nextConstantStringIndex: Int = 1
116117
private var nextArrayTypeIndex: Int = 1
118+
private var nextClosureDataTypeIndex: Int = 1
117119

120+
def addFunction(fun: WasmFunction): Unit
118121
protected def addGlobal(g: WasmGlobal): Unit
119122
protected def addFuncDeclaration(name: WasmFunctionName): Unit
120123

@@ -160,6 +163,19 @@ trait TypeDefinableWasmContext extends ReadOnlyWasmContext { this: WasmContext =
160163
WasmInstr.GLOBAL_GET(WasmImmediate.GlobalIdx(globalName))
161164
}
162165

166+
def getClosureDataStructType(captureParamTypes: List[IRTypes.Type]): WasmStructType = {
167+
closureDataTypes.getOrElse(captureParamTypes, {
168+
val fields: List[WasmStructField] =
169+
for ((tpe, i) <- captureParamTypes.zipWithIndex) yield
170+
WasmStructField(WasmFieldName.captureParam(i), TypeTransformer.transformType(tpe)(this), isMutable = false)
171+
val structTypeName = WasmStructTypeName.captureData(nextClosureDataTypeIndex)
172+
nextClosureDataTypeIndex += 1
173+
val structType = WasmStructType(structTypeName, fields, superType = None)
174+
addGCType(structType)
175+
structType
176+
})
177+
}
178+
163179
def refFuncWithDeclaration(name: WasmFunctionName): WasmInstr.REF_FUNC = {
164180
addFuncDeclaration(name)
165181
WasmInstr.REF_FUNC(WasmImmediate.FuncIdx(name))
@@ -244,6 +260,21 @@ class WasmContext(val module: WasmModule) extends TypeDefinableWasmContext {
244260
List(WasmRefType(WasmHeapType.Simple.Func), WasmAnyRef),
245261
List(WasmRefType.any)
246262
)
263+
addHelperImport(
264+
WasmFunctionName.closureThis,
265+
List(WasmRefType(WasmHeapType.Simple.Func), WasmAnyRef),
266+
List(WasmRefType.any)
267+
)
268+
addHelperImport(
269+
WasmFunctionName.closureRest,
270+
List(WasmRefType(WasmHeapType.Simple.Func), WasmAnyRef),
271+
List(WasmRefType.any)
272+
)
273+
addHelperImport(
274+
WasmFunctionName.closureThisRest,
275+
List(WasmRefType(WasmHeapType.Simple.Func), WasmAnyRef),
276+
List(WasmRefType.any)
277+
)
247278

248279
addHelperImport(WasmFunctionName.emptyString, List(), List(WasmRefType.any))
249280
addHelperImport(WasmFunctionName.stringLength, List(WasmRefType.any), List(WasmInt32))

Diff for: wasm/src/main/scala/wasm4s/WasmFunctionContext.scala

+14-4
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import wasm.wasm4s.WasmInstr._
1414
import wasm.ir2wasm.TypeTransformer
1515

1616
class WasmFunctionContext private (
17-
ctx: WasmContext,
17+
ctx: TypeDefinableWasmContext,
1818
val enclosingClassName: Option[IRNames.ClassName],
1919
val functionName: WasmFunctionName,
2020
_receiver: Option[WasmLocal],
@@ -23,6 +23,7 @@ class WasmFunctionContext private (
2323
) {
2424
private var cnt = 0
2525
private var labelIdx = 0
26+
private var innerFuncIdx = 0
2627

2728
val locals = new WasmSymbolTable[WasmLocalName, WasmLocal]()
2829

@@ -80,6 +81,15 @@ class WasmFunctionContext private (
8081
def addSyntheticLocal(typ: WasmType): LocalIdx =
8182
addLocal(genSyntheticLocalName(), typ)
8283

84+
def genInnerFuncName(): WasmFunctionName = {
85+
val innerName = WasmFunctionName(
86+
functionName.namespace,
87+
functionName.simpleName + "__c" + innerFuncIdx
88+
)
89+
innerFuncIdx += 1
90+
innerName
91+
}
92+
8393
// Helpers to build structured control flow
8494

8595
def ifThenElse(blockType: BlockType)(thenp: => Unit)(elsep: => Unit): Unit = {
@@ -177,7 +187,7 @@ object WasmFunctionContext {
177187
receiver: Option[WasmLocal],
178188
params: List[WasmLocal],
179189
resultTypes: List[WasmType]
180-
)(implicit ctx: WasmContext): WasmFunctionContext = {
190+
)(implicit ctx: TypeDefinableWasmContext): WasmFunctionContext = {
181191
new WasmFunctionContext(ctx, enclosingClassName, name, receiver, params, resultTypes)
182192
}
183193

@@ -187,7 +197,7 @@ object WasmFunctionContext {
187197
receiverTyp: Option[WasmType],
188198
paramDefs: List[IRTrees.ParamDef],
189199
resultType: IRTypes.Type
190-
)(implicit ctx: WasmContext): WasmFunctionContext = {
200+
)(implicit ctx: TypeDefinableWasmContext): WasmFunctionContext = {
191201
val receiver = receiverTyp.map { typ =>
192202
WasmLocal(WasmLocalName.receiver, typ, isParameter = true)
193203
}
@@ -205,7 +215,7 @@ object WasmFunctionContext {
205215
name: WasmFunctionName,
206216
params: List[(String, WasmType)],
207217
resultTypes: List[WasmType]
208-
)(implicit ctx: WasmContext): WasmFunctionContext = {
218+
)(implicit ctx: TypeDefinableWasmContext): WasmFunctionContext = {
209219
val paramLocals = params.map { param =>
210220
WasmLocal(WasmLocalName.fromStr(param._1), param._2, isParameter = true)
211221
}

0 commit comments

Comments
 (0)