Skip to content
This repository was archived by the owner on Jul 12, 2024. It is now read-only.

Commit f055d71

Browse files
committed
Implement support for Closures.
1 parent 26a1425 commit f055d71

File tree

7 files changed

+178
-5
lines changed

7 files changed

+178
-5
lines changed

Diff for: cli/src/main/scala/TestSuites.scala

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ object TestSuites {
99
TestSuite("testsuite.core.InterfaceCall"),
1010
TestSuite("testsuite.core.AsInstanceOfTest"),
1111
TestSuite("testsuite.core.ClassOfTest"),
12+
TestSuite("testsuite.core.ClosureTest"),
1213
TestSuite("testsuite.core.FieldsTest"),
1314
TestSuite("testsuite.core.GetClassTest"),
1415
TestSuite("testsuite.core.JSInteropTest"),

Diff for: loader.mjs

+3
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ const scalaJSHelpers = {
6363

6464
// Closure
6565
closure: (f, data) => f.bind(void 0, data),
66+
closureThis: (f, data) => function(...args) { return f(this, data, ...args); },
67+
closureRest: (f, data, n) => ((...args) => f(data, ...args.slice(0, n), args.slice(n))),
68+
closureThisRest: (f, data, n) => function(...args) { return f(this, data, ...args.slice(0, n), args.slice(n)); },
6669

6770
// Strings
6871
emptyString: () => "",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package testsuite.core
2+
3+
import scala.scalajs.js
4+
5+
import testsuite.Assert.ok
6+
7+
object ClosureTest {
8+
def main(): Unit = {
9+
testClosure()
10+
testClosureThis()
11+
12+
// TODO We cannot test closures with ...rest params yet because they need Seq's
13+
14+
testGiveToActualJSCode()
15+
}
16+
17+
def testClosure(): Unit = {
18+
def makeClosure(x: Int, y: String): js.Function2[Boolean, String, String] =
19+
(z, w) => s"$x $y $z $w"
20+
21+
val f = makeClosure(5, "foo")
22+
ok(f(true, "bar") == "5 foo true bar")
23+
}
24+
25+
def testClosureThis(): Unit = {
26+
def makeClosure(x: Int, y: String): js.ThisFunction2[Any, Boolean, String, String] =
27+
(ths, z, w) => s"$ths $x $y $z $w"
28+
29+
val f = makeClosure(5, "foo")
30+
ok(f(new Obj, true, "bar") == "Obj 5 foo true bar")
31+
}
32+
33+
def testGiveToActualJSCode(): Unit = {
34+
val arr = js.Array(2, 3, 5, 7, 11)
35+
val f: js.Function1[Int, Int] = x => x * 2
36+
val result = arr.asInstanceOf[js.Dynamic].map(f).asInstanceOf[js.Array[Int]]
37+
ok(result.length == 5)
38+
ok(result(0) == 4)
39+
ok(result(1) == 6)
40+
ok(result(2) == 10)
41+
ok(result(3) == 14)
42+
ok(result(4) == 22)
43+
}
44+
45+
class Obj {
46+
override def toString(): String = "Obj"
47+
}
48+
}

Diff for: wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala

+73-1
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ private class WasmExpressionBuilder private (
141141
case t: IRTrees.JSGlobalRef => genJSGlobalRef(t)
142142
case t: IRTrees.JSTypeOfGlobalRef => genJSTypeOfGlobalRef(t)
143143
case t: IRTrees.JSLinkingInfo => genJSLinkingInfo(t)
144+
case t: IRTrees.Closure => genClosure(t)
144145

145146
case _ =>
146147
println(tree)
@@ -156,7 +157,6 @@ private class WasmExpressionBuilder private (
156157
// case IRTrees.NewArray(pos) =>
157158
// case IRTrees.Match(tpe) =>
158159
// case IRTrees.Throw(pos) =>
159-
// case IRTrees.Closure(pos) =>
160160
// case IRTrees.RecordSelect(tpe) =>
161161
// case IRTrees.TryFinally(pos) =>
162162
// case IRTrees.JSImportMeta(pos) =>
@@ -1578,4 +1578,76 @@ private class WasmExpressionBuilder private (
15781578
instrs += CALL(FuncIdx(WasmFunctionName.jsLinkingInfo))
15791579
IRTypes.AnyType
15801580
}
1581+
1582+
private def genClosure(tree: IRTrees.Closure): IRTypes.Type = {
1583+
implicit val ctx = this.ctx
1584+
1585+
val hasThis = !tree.arrow
1586+
val dataStructType = ctx.getClosureDataStructType(tree.captureParams.map(_.ptpe))
1587+
1588+
// Define the function where captures are reified as a `__captureData` argument.
1589+
val closureFuncName = fctx.genInnerFuncName()
1590+
locally {
1591+
val receiverParam =
1592+
if (!hasThis) None
1593+
else Some(WasmLocal(WasmLocalName.receiver, Types.WasmAnyRef, isParameter = true))
1594+
1595+
val captureDataParam = WasmLocal(
1596+
WasmLocalName("__captureData"),
1597+
Types.WasmRefType(Types.WasmHeapType.Type(dataStructType.name)),
1598+
isParameter = true
1599+
)
1600+
1601+
val paramLocals = (tree.params ::: tree.restParam.toList).map { param =>
1602+
val typ = TypeTransformer.transformType(param.ptpe)
1603+
WasmLocal(WasmLocalName.fromIR(param.name.name), typ, isParameter = true)
1604+
}
1605+
val resultTyps = TypeTransformer.transformResultType(IRTypes.AnyType)
1606+
1607+
implicit val fctx = WasmFunctionContext(
1608+
enclosingClassName = None,
1609+
closureFuncName,
1610+
receiverParam,
1611+
captureDataParam :: paramLocals,
1612+
resultTyps
1613+
)
1614+
1615+
val captureDataLocalIdx = fctx.paramIndices.head
1616+
1617+
// Extract the fields of captureData in individual locals
1618+
for ((captureParam, index) <- tree.captureParams.zipWithIndex) {
1619+
val local = fctx.addLocal(
1620+
captureParam.name.name,
1621+
TypeTransformer.transformType(captureParam.ptpe)
1622+
)
1623+
fctx.instrs += LOCAL_GET(captureDataLocalIdx)
1624+
fctx.instrs += STRUCT_GET(TypeIdx(dataStructType.name), StructFieldIdx(index))
1625+
fctx.instrs += LOCAL_SET(local)
1626+
}
1627+
1628+
// Now transform the body
1629+
WasmExpressionBuilder.generateIRBody(tree.body, IRTypes.AnyType)
1630+
1631+
fctx.buildAndAddToContext()
1632+
}
1633+
1634+
// Put a reference to the function on the stack
1635+
instrs += ctx.refFuncWithDeclaration(closureFuncName)
1636+
1637+
// Evaluate the capture values and instantiate the capture data struct
1638+
for ((param, value) <- tree.captureParams.zip(tree.captureValues))
1639+
genTree(value, param.ptpe)
1640+
instrs += STRUCT_NEW(TypeIdx(dataStructType.name))
1641+
1642+
// Call the appropriate helper
1643+
val helper = (hasThis, tree.restParam.isDefined) match {
1644+
case (false, false) => WasmFunctionName.closure
1645+
case (true, false) => WasmFunctionName.closureThis
1646+
case (false, true) => WasmFunctionName.closureRest
1647+
case (true, true) => WasmFunctionName.closureThisRest
1648+
}
1649+
instrs += CALL(FuncIdx(helper))
1650+
1651+
IRTypes.AnyType
1652+
}
15811653
}

Diff for: wasm/src/main/scala/wasm4s/Names.scala

+8
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,9 @@ object Names {
156156
def typeTest(primRef: IRTypes.PrimRef): WasmFunctionName = helper("t" + primRef.charCode)
157157

158158
val closure = helper("closure")
159+
val closureThis = helper("closureThis")
160+
val closureRest = helper("closureRest")
161+
val closureThisRest = helper("closureThisRest")
159162

160163
val emptyString = helper("emptyString")
161164
val stringLength = helper("stringLength")
@@ -246,6 +249,9 @@ object Names {
246249
def apply(name: WasmTypeName.WasmITableTypeName) = new WasmFieldName(name.name)
247250
def apply(name: IRNames.MethodName) = new WasmFieldName(name.nameString)
248251
def apply(name: WasmFunctionName) = new WasmFieldName(name.name)
252+
253+
def captureParam(i: Int): WasmFieldName = new WasmFieldName("c" + i)
254+
249255
val vtable = new WasmFieldName("vtable")
250256
val itable = new WasmFieldName("itable")
251257
val itables = new WasmFieldName("itables")
@@ -329,6 +335,8 @@ object Names {
329335
object WasmStructTypeName {
330336
def apply(name: IRNames.ClassName) = new WasmStructTypeName(name.nameString)
331337

338+
def captureData(index: Int): WasmStructTypeName = new WasmStructTypeName("captureData__" + index)
339+
332340
val typeData = new WasmStructTypeName("typeData")
333341
}
334342

Diff for: wasm/src/main/scala/wasm4s/WasmContext.scala

+31
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,12 @@ trait ReadOnlyWasmContext {
111111
trait FunctionTypeWriterWasmContext extends ReadOnlyWasmContext { this: WasmContext =>
112112
protected val functionSignatures = LinkedHashMap.empty[WasmFunctionSignature, Int]
113113
protected val constantStringGlobals = LinkedHashMap.empty[String, WasmGlobalName]
114+
protected val closureDataTypes = LinkedHashMap.empty[List[IRTypes.Type], WasmStructType]
114115

115116
private var nextConstantStringIndex: Int = 1
117+
private var nextClosureDataTypeIndex: Int = 1
116118

119+
def addFunction(fun: WasmFunction): Unit
117120
protected def addGlobal(g: WasmGlobal): Unit
118121
protected def addFuncDeclaration(name: WasmFunctionName): Unit
119122

@@ -159,6 +162,19 @@ trait FunctionTypeWriterWasmContext extends ReadOnlyWasmContext { this: WasmCont
159162
WasmInstr.GLOBAL_GET(WasmImmediate.GlobalIdx(globalName))
160163
}
161164

165+
def getClosureDataStructType(captureParamTypes: List[IRTypes.Type]): WasmStructType = {
166+
closureDataTypes.getOrElse(captureParamTypes, {
167+
val fields: List[WasmStructField] =
168+
for ((tpe, i) <- captureParamTypes.zipWithIndex) yield
169+
WasmStructField(WasmFieldName.captureParam(i), TypeTransformer.transformType(tpe)(this), isMutable = false)
170+
val structTypeName = WasmStructTypeName.captureData(nextClosureDataTypeIndex)
171+
nextClosureDataTypeIndex += 1
172+
val structType = WasmStructType(structTypeName, fields, superType = None)
173+
addGCType(structType)
174+
structType
175+
})
176+
}
177+
162178
def refFuncWithDeclaration(name: WasmFunctionName): WasmInstr.REF_FUNC = {
163179
addFuncDeclaration(name)
164180
WasmInstr.REF_FUNC(WasmImmediate.FuncIdx(name))
@@ -227,6 +243,21 @@ class WasmContext(val module: WasmModule) extends FunctionTypeWriterWasmContext
227243
List(WasmRefType(WasmHeapType.Simple.Func), WasmAnyRef),
228244
List(WasmRefType.any)
229245
)
246+
addHelperImport(
247+
WasmFunctionName.closureThis,
248+
List(WasmRefType(WasmHeapType.Simple.Func), WasmAnyRef),
249+
List(WasmRefType.any)
250+
)
251+
addHelperImport(
252+
WasmFunctionName.closureRest,
253+
List(WasmRefType(WasmHeapType.Simple.Func), WasmAnyRef),
254+
List(WasmRefType.any)
255+
)
256+
addHelperImport(
257+
WasmFunctionName.closureThisRest,
258+
List(WasmRefType(WasmHeapType.Simple.Func), WasmAnyRef),
259+
List(WasmRefType.any)
260+
)
230261

231262
addHelperImport(WasmFunctionName.emptyString, List(), List(WasmRefType.any))
232263
addHelperImport(WasmFunctionName.stringLength, List(WasmRefType.any), List(WasmInt32))

Diff for: wasm/src/main/scala/wasm4s/WasmFunctionContext.scala

+14-4
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import wasm.wasm4s.WasmInstr._
1414
import wasm.ir2wasm.TypeTransformer
1515

1616
class WasmFunctionContext private (
17-
ctx: WasmContext,
17+
ctx: FunctionTypeWriterWasmContext,
1818
val enclosingClassName: Option[IRNames.ClassName],
1919
val functionName: WasmFunctionName,
2020
_receiver: Option[WasmLocal],
@@ -23,6 +23,7 @@ class WasmFunctionContext private (
2323
) {
2424
private var cnt = 0
2525
private var labelIdx = 0
26+
private var innerFuncIdx = 0
2627

2728
val locals = new WasmSymbolTable[WasmLocalName, WasmLocal]()
2829

@@ -80,6 +81,15 @@ class WasmFunctionContext private (
8081
def addSyntheticLocal(typ: WasmType): LocalIdx =
8182
addLocal(genSyntheticLocalName(), typ)
8283

84+
def genInnerFuncName(): WasmFunctionName = {
85+
val innerName = WasmFunctionName(
86+
functionName.namespace,
87+
functionName.simpleName + "__c" + innerFuncIdx
88+
)
89+
innerFuncIdx += 1
90+
innerName
91+
}
92+
8393
// Helpers to build structured control flow
8494

8595
def ifThenElse(blockType: BlockType)(thenp: => Unit)(elsep: => Unit): Unit = {
@@ -177,7 +187,7 @@ object WasmFunctionContext {
177187
receiver: Option[WasmLocal],
178188
params: List[WasmLocal],
179189
resultTypes: List[WasmType]
180-
)(implicit ctx: WasmContext): WasmFunctionContext = {
190+
)(implicit ctx: FunctionTypeWriterWasmContext): WasmFunctionContext = {
181191
new WasmFunctionContext(ctx, enclosingClassName, name, receiver, params, resultTypes)
182192
}
183193

@@ -187,7 +197,7 @@ object WasmFunctionContext {
187197
receiverTyp: Option[WasmType],
188198
paramDefs: List[IRTrees.ParamDef],
189199
resultType: IRTypes.Type
190-
)(implicit ctx: WasmContext): WasmFunctionContext = {
200+
)(implicit ctx: FunctionTypeWriterWasmContext): WasmFunctionContext = {
191201
val receiver = receiverTyp.map { typ =>
192202
WasmLocal(WasmLocalName.receiver, typ, isParameter = true)
193203
}
@@ -205,7 +215,7 @@ object WasmFunctionContext {
205215
name: WasmFunctionName,
206216
params: List[(String, WasmType)],
207217
resultTypes: List[WasmType]
208-
)(implicit ctx: WasmContext): WasmFunctionContext = {
218+
)(implicit ctx: FunctionTypeWriterWasmContext): WasmFunctionContext = {
209219
val paramLocals = params.map { param =>
210220
WasmLocal(WasmLocalName.fromStr(param._1), param._2, isParameter = true)
211221
}

0 commit comments

Comments
 (0)