Skip to content
This repository was archived by the owner on Jul 12, 2024. It is now read-only.

Commit 620ff4a

Browse files
committed
fix: Add tests for virtual dispatch + fix bug for abstract class
We had sevelral problems with the vtable implementations - We'd been generating a global vtable even for an abstract class, but it's not needed because we can't instantiate abstract class - For calculating the vtable (both it's type and global instance), we filtered out the abstract methods. Therefore, we couldn't resolve a abstract method call from vtable. For example, ```scala class A extends B: def a = 1 class B extends C: def b: Int = 1 override def c: Int = 1 abstract class C: def c: Int ``` The vtable type for C will be an empty table, because there's no concrete methods. Therefore, when we have `x: C`, `x.c` wond't resolve the implementation of `C` because vtable type doesn't have a slot for `c`. The root cause is that we generated both of the following from one (in-memory) vtable object that doesn't have abstract methods. - vtable type (for declaring the vtable type, and resolve methods by name at compile time), and - global vtable instance (for method lookup at runtime) The former should include abstract methods like `C.c`, and the former should not.
1 parent e198109 commit 620ff4a

File tree

8 files changed

+262
-130
lines changed

8 files changed

+262
-130
lines changed

Diff for: cli/src/main/scala/TestSuites.scala

+2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ object TestSuites {
55
val suites = List(
66
TestSuite("testsuite.core.simple.Simple", "simple"),
77
TestSuite("testsuite.core.add.Add", "add"),
8+
TestSuite("testsuite.core.add.Add", "add"),
9+
TestSuite("testsuite.core.virtualdispatch.VirtualDispatch", "virtualDispatch"),
810
TestSuite("testsuite.core.asinstanceof.AsInstanceOfTest", "asInstanceOf"),
911
TestSuite("testsuite.core.hijackedclassesmono.HijackedClassesMonoTest", "hijackedClassesMono")
1012
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package testsuite.core.virtualdispatch
2+
3+
import scala.scalajs.js.annotation._
4+
5+
object VirtualDispatch {
6+
def main(): Unit = { val _ = test() }
7+
8+
@JSExportTopLevel("virtualDispatch")
9+
def test(): Boolean = {
10+
val a = new A
11+
val b = new B
12+
13+
testA(a) &&
14+
testB(a, isInstanceOfA = true) &&
15+
testB(b, isInstanceOfA = false) &&
16+
testC(a, isInstanceOfA = true) &&
17+
testC(b, isInstanceOfA = false)
18+
}
19+
20+
def testA(a: A): Boolean = {
21+
a.a == 2 && a.impl == 2 && a.b == 1 && a.c == 1
22+
}
23+
24+
def testB(b: B, isInstanceOfA: Boolean): Boolean = {
25+
if (isInstanceOfA) {
26+
b.b == 1 && b.c == 1 && b.impl == 2
27+
} else {
28+
b.b == 1 && b.c == 1 && b.impl == 0
29+
}
30+
}
31+
32+
def testC(c: C, isInstanceOfA: Boolean): Boolean = {
33+
if (isInstanceOfA) {
34+
c.c == 1 && c.impl == 2
35+
} else {
36+
c.c == 1 && c.impl == 0
37+
}
38+
}
39+
40+
class A extends B {
41+
def a: Int = 2
42+
override def impl = 2
43+
}
44+
45+
class B extends C {
46+
def b: Int = 1
47+
override def c: Int = 1
48+
}
49+
50+
abstract class C {
51+
def c: Int
52+
def impl: Int = 0
53+
}
54+
}

Diff for: wasm/src/main/scala/ir2wasm/Preprocessor.scala

+22-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ object Preprocessor {
2323

2424
private def preprocess(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = {
2525
clazz.kind match {
26-
case ClassKind.ModuleClass | ClassKind.Class | ClassKind.Interface | ClassKind.HijackedClass =>
26+
case ClassKind.ModuleClass | ClassKind.Class | ClassKind.Interface |
27+
ClassKind.HijackedClass =>
2728
collectMethods(clazz)
2829
case ClassKind.JSClass | ClassKind.JSModuleClass | ClassKind.NativeJSModuleClass |
2930
ClassKind.AbstractJSType | ClassKind.NativeJSClass =>
@@ -61,6 +62,26 @@ object Preprocessor {
6162
)
6263
}
6364

65+
/** Collect WasmFunctionInfo based on the abstract method call
66+
*
67+
* ```
68+
* class A extends B:
69+
* def a = 1
70+
*
71+
* class B extends C:
72+
* def b: Int = 1
73+
* override def c: Int = 1
74+
*
75+
* abstract class C:
76+
* def c: Int
77+
* ```
78+
*
79+
* why we need this? - The problem is that the frontend linker gets rid of abstract method
80+
* entirely.
81+
*
82+
* It keeps B.c because it's concrete and used. But because `C.c` isn't there at all anymore, if
83+
* we have val `x: C` and we call `x.c`, we don't find the method at all.
84+
*/
6485
private def collectAbstractMethodCalls(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = {
6586
object traverser extends Traversers.Traverser {
6687
import IRTrees._

Diff for: wasm/src/main/scala/ir2wasm/WasmBuilder.scala

+84-71
Original file line numberDiff line numberDiff line change
@@ -42,40 +42,97 @@ class WasmBuilder {
4242
}
4343
}
4444

45+
/** @return
46+
* Optionally returns the generated struct type for this class. If the given LinkedClass is an
47+
* abstract class, returns None
48+
*/
4549
private def transformClassCommon(
4650
clazz: LinkedClass
4751
)(implicit ctx: WasmContext): WasmStructType = {
48-
val (vtableType, vtableName) = genVTable(clazz)
52+
// gen functions
53+
clazz.methods.foreach { method =>
54+
genFunction(clazz, method)
55+
}
56+
val className = clazz.name.name
57+
58+
// generate vtable type, this should be done for both abstract and concrete classes
59+
val vtable = ctx.calculateVtableType(className)
60+
val vtableType = genVTableType(clazz, vtable.functions)
61+
ctx.addGCType(vtableType)
62+
63+
val isAbstractClass = {
64+
// If number of declared functions doesn't match number of defined functions, it must be a AbstractClass
65+
// TODO: better way to check if it's abstract class
66+
val definedFunctions = ctx.calculateGlobalVTable(className)
67+
val declaredFunctions = vtable.functions
68+
declaredFunctions.length != definedFunctions.length
69+
}
70+
71+
// we should't generate global vtable for abstract class because
72+
// - Can't generate Global vtable because we can't fill the slot for abstract methods
73+
// - We won't access vtable for abstract classes since we can't instantiate abstract classes, there's no point generating
74+
//
75+
// However, I couldn't find a way to test if the LinkedClass is abstract
76+
// "clazz.methods.exists(m => m.body.isEmpty)" doesn't work because abstract methods are removed at linker optimization
77+
// the WasmFunctionInfo of the abstract methods will be added specially in Preprocessor
78+
val (gVtable, gItable) = if (!isAbstractClass) {
79+
// Generate global vtable
80+
val functions = ctx.calculateGlobalVTable(className)
81+
val vtableInit = functions.map { method =>
82+
WasmInstr.REF_FUNC(method.name)
83+
} :+ WasmInstr.STRUCT_NEW(vtableType.name)
84+
val vtableName = Names.WasmGlobalName.WasmGlobalVTableName(clazz.name.name)
85+
val globalVTable =
86+
WasmGlobal(
87+
vtableName,
88+
WasmRefNullType(WasmHeapType.Type(vtableType.name)),
89+
WasmExpr(vtableInit),
90+
isMutable = false
91+
)
92+
ctx.addGlobal(globalVTable)
93+
94+
// Generate class itable
95+
val globalClassITable = calculateClassITable(clazz)
96+
globalClassITable.foreach(ctx.addGlobal)
97+
98+
(Some(globalVTable), globalClassITable)
99+
} else (None, None)
100+
101+
// Declare the strcut type for the class
102+
genStructNewDefault(clazz, gVtable, gItable)
49103
val vtableField = WasmStructField(
50104
Names.WasmFieldName.vtable,
51105
WasmRefNullType(WasmHeapType.Type(vtableType.name)),
52106
isMutable = false
53107
)
54-
calculateClassITable(clazz) match {
55-
case None =>
56-
genStructNewDefault(clazz, vtableName, None)
57-
case Some(globalITable) =>
58-
ctx.addGlobal(globalITable)
59-
genStructNewDefault(clazz, vtableName, Some(globalITable))
60-
}
61-
62-
// type definition
63108
val fields = clazz.fields.map(transformField)
64109
val structType = WasmStructType(
65110
Names.WasmTypeName.WasmStructTypeName(clazz.name.name),
66111
vtableField +: WasmStructField.itables +: fields,
67112
clazz.superClass.map(s => Names.WasmTypeName.WasmStructTypeName(s.name))
68113
)
69114
ctx.addGCType(structType)
70-
71-
// implementation of methods
72-
clazz.methods.foreach { method =>
73-
genFunction(clazz, method)
74-
}
75-
76115
structType
77116
}
78117

118+
private def genVTableType(clazz: LinkedClass, functions: List[WasmFunctionInfo])(implicit
119+
ctx: WasmContext
120+
): WasmStructType = {
121+
val vtableFields =
122+
functions.map { method =>
123+
WasmStructField(
124+
Names.WasmFieldName(method.name),
125+
WasmRefNullType(WasmHeapType.Func(method.toWasmFunctionType().name)),
126+
isMutable = false
127+
)
128+
}
129+
WasmStructType(
130+
Names.WasmTypeName.WasmVTableTypeName(clazz.name.name),
131+
vtableFields,
132+
clazz.superClass.map(s => Names.WasmTypeName.WasmVTableTypeName(s.name))
133+
)
134+
}
135+
79136
private def genLoadModuleFunc(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = {
80137
import WasmImmediate._
81138
assert(clazz.kind == ClassKind.ModuleClass)
@@ -118,10 +175,18 @@ class WasmBuilder {
118175

119176
private def genStructNewDefault(
120177
clazz: LinkedClass,
121-
vtable: WasmGlobalName.WasmGlobalVTableName,
178+
vtable: Option[WasmGlobal],
122179
itable: Option[WasmGlobal]
123180
)(implicit ctx: WasmContext): Unit = {
124-
val getVTable = GLOBAL_GET(WasmImmediate.GlobalIdx(vtable))
181+
val getVTable = vtable match {
182+
case None =>
183+
REF_NULL(
184+
WasmImmediate.HeapType(
185+
WasmHeapType.Type(WasmTypeName.WasmVTableTypeName(clazz.name.name))
186+
)
187+
)
188+
case Some(v) => GLOBAL_GET(WasmImmediate.GlobalIdx(v.name))
189+
}
125190
val getITable = itable match {
126191
case None => REF_NULL(WasmImmediate.HeapType(WasmHeapType.Type(WasmArrayType.itables.name)))
127192
case Some(i) => GLOBAL_GET(WasmImmediate.GlobalIdx(i.name))
@@ -156,21 +221,7 @@ class WasmBuilder {
156221
)(implicit ctx: ReadOnlyWasmContext): Option[WasmGlobal] = {
157222
val classItables = ctx.calculateClassItables(clazz.name.name)
158223
if (!classItables.isEmpty) {
159-
// val classITableTypeName = WasmTypeName.WasmITableTypeName(clazz.name.name)
160-
// val classITableType = WasmStructType(
161-
// classITableTypeName,
162-
// interfaceInfos.map { info =>
163-
// val itableTypeName = WasmTypeName.WasmITableTypeName(info.name)
164-
// WasmStructField(
165-
// Names.WasmFieldName(itableTypeName),
166-
// WasmRefType(WasmHeapType.Type(itableTypeName)),
167-
// isMutable = false
168-
// )
169-
// },
170-
// None
171-
// )
172-
173-
val vtable = ctx.calculateVtable(clazz.name.name)
224+
val vtable = ctx.calculateVtableType(clazz.name.name)
174225

175226
val itablesInit: List[WasmInstr] = classItables.itables.flatMap { iface =>
176227
iface.methods.map { method =>
@@ -194,44 +245,6 @@ class WasmBuilder {
194245
} else None
195246
}
196247

197-
private def genVTable(
198-
clazz: LinkedClass
199-
)(implicit ctx: WasmContext): (WasmStructType, WasmGlobalName.WasmGlobalVTableName) = {
200-
val className = clazz.name.name
201-
def genVTableType(vtable: WasmVTable): WasmStructType = {
202-
val vtableFields =
203-
vtable.functions.map { method =>
204-
WasmStructField(
205-
Names.WasmFieldName(method.name),
206-
WasmRefNullType(WasmHeapType.Func(method.toWasmFunctionType().name)),
207-
isMutable = false
208-
)
209-
}
210-
WasmStructType(
211-
Names.WasmTypeName.WasmVTableTypeName.fromIR(clazz.name.name),
212-
vtableFields,
213-
clazz.superClass.map(s => Names.WasmTypeName.WasmVTableTypeName.fromIR(s.name))
214-
)
215-
}
216-
217-
val vtableName = Names.WasmGlobalName.WasmGlobalVTableName(clazz.name.name)
218-
219-
val vtable = ctx.calculateVtable(className)
220-
val vtableType = genVTableType(vtable)
221-
ctx.addGCType(vtableType)
222-
223-
val globalVTable =
224-
WasmGlobal(
225-
vtableName,
226-
WasmRefNullType(WasmHeapType.Type(vtableType.name)),
227-
WasmExpr(vtable.toVTableEntries(vtableType.name)),
228-
isMutable = false
229-
)
230-
ctx.addGlobal(globalVTable)
231-
232-
(vtableType, vtableName)
233-
}
234-
235248
private def transformClass(clazz: LinkedClass)(implicit ctx: WasmContext): Unit = {
236249
assert(clazz.kind == ClassKind.Class)
237250
transformClassCommon(clazz)

0 commit comments

Comments
 (0)