Skip to content
This repository was archived by the owner on Jul 12, 2024. It is now read-only.

Commit 6e3616a

Browse files
authored
Merge pull request #121 from tanishiking/compress-itables
Compress interface tables with packed encoding
2 parents eb3b4f1 + 3bd0321 commit 6e3616a

File tree

4 files changed

+175
-25
lines changed

4 files changed

+175
-25
lines changed

Diff for: wasm/src/main/scala/ir2wasm/HelperFunctions.scala

+4-4
Original file line numberDiff line numberDiff line change
@@ -1700,7 +1700,7 @@ object HelperFunctions {
17001700
val itables = fctx.addLocal("itables", WasmRefType.nullable(WasmArrayTypeName.itables))
17011701
val exprNonNullLocal = fctx.addLocal("exprNonNull", WasmRefType.any)
17021702

1703-
val itableIdx = ctx.getItableIdx(clazz.name.name)
1703+
val itableIdx = ctx.getItableIdx(classInfo)
17041704
fctx.block(WasmRefType.anyref) { testFail =>
17051705
// if expr is not an instance of Object, return false
17061706
instrs += LOCAL_GET(exprParam)
@@ -1724,9 +1724,9 @@ object HelperFunctions {
17241724
instrs += LOCAL_GET(itables)
17251725
instrs += I32_CONST(itableIdx)
17261726
instrs += ARRAY_GET(WasmTypeName.WasmArrayTypeName.itables)
1727-
instrs += BR_ON_NULL(testFail)
1728-
1729-
instrs += I32_CONST(1)
1727+
instrs += REF_TEST(
1728+
WasmRefType(WasmStructTypeName.forITable(clazz.name.name))
1729+
)
17301730
instrs += RETURN
17311731
} // test fail
17321732

Diff for: wasm/src/main/scala/ir2wasm/Preprocessor.scala

+5-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,10 @@ object Preprocessor {
2828

2929
for (clazz <- classes) {
3030
ctx.getClassInfo(clazz.className).buildMethodTable()
31+
}
32+
ctx.assignBuckets(classes)
3133

34+
for (clazz <- classes) {
3235
if (clazz.kind == ClassKind.Interface && clazz.hasInstanceTests)
3336
HelperFunctions.genInstanceTest(clazz)
3437
HelperFunctions.genCloneFunction(clazz)
@@ -101,7 +104,8 @@ object Preprocessor {
101104
!clazz.hasDirectInstances,
102105
hasRuntimeTypeInfo,
103106
clazz.jsNativeLoadSpec,
104-
clazz.jsNativeMembers.map(m => m.name.name -> m.jsNativeLoadSpec).toMap
107+
clazz.jsNativeMembers.map(m => m.name.name -> m.jsNativeLoadSpec).toMap,
108+
_itableIdx = -1
105109
)
106110
)
107111

Diff for: wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -655,7 +655,7 @@ private class WasmExpressionBuilder private (
655655
): Unit = {
656656
// Generates an itable-based dispatch.
657657
def genITableDispatch(): Unit = {
658-
val itableIdx = ctx.getItableIdx(receiverClassInfo.name)
658+
val itableIdx = ctx.getItableIdx(receiverClassInfo)
659659
val methodIdx = receiverClassInfo.tableMethodInfos(methodName).tableIndex
660660

661661
instrs += LOCAL_GET(receiverLocalForDispatch)

Diff for: wasm/src/main/scala/wasm4s/WasmContext.scala

+165-19
Original file line numberDiff line numberDiff line change
@@ -20,27 +20,27 @@ import wasm.ir2wasm.WasmExpressionBuilder
2020
import org.scalajs.linker.interface.ModuleInitializer
2121
import org.scalajs.linker.interface.unstable.ModuleInitializerImpl
2222
import org.scalajs.linker.standard.LinkedTopLevelExport
23+
import org.scalajs.linker.standard.LinkedClass
2324

2425
abstract class ReadOnlyWasmContext {
2526
import WasmContext._
2627

27-
protected val itableIdx = mutable.Map[IRNames.ClassName, Int]()
2828
protected val classInfo = mutable.Map[IRNames.ClassName, WasmClassInfo]()
29-
protected var nextItableIdx: Int
3029

3130
val cloneFunctionTypeName: WasmTypeName
3231
val isJSClassInstanceFuncTypeName: WasmTypeName
3332

34-
def itablesLength = nextItableIdx
33+
protected var _itablesLength: Int = 0
34+
def itablesLength = _itablesLength
3535

3636
/** Get an index of the itable for the given interface. The itable instance must be placed at the
3737
* index in the array of itables (whose size is `itablesLength`).
3838
*/
39-
def getItableIdx(iface: IRNames.ClassName): Int =
40-
itableIdx.getOrElse(
41-
iface,
42-
throw new IllegalArgumentException(s"Interface $iface is not registed.")
43-
)
39+
def getItableIdx(iface: WasmClassInfo): Int = {
40+
val idx = iface.itableIdx
41+
if (idx < 0) throw new IllegalArgumentException(s"Interface $iface is not registed.")
42+
idx
43+
}
4444

4545
def getClassInfoOption(name: IRNames.ClassName): Option[WasmClassInfo] =
4646
classInfo.get(name)
@@ -212,7 +212,8 @@ class WasmContext(val module: WasmModule) extends TypeDefinableWasmContext {
212212
private val _importedModules: mutable.LinkedHashSet[String] =
213213
new mutable.LinkedHashSet()
214214

215-
override protected var nextItableIdx: Int = 0
215+
def assignBuckets(classes: List[LinkedClass]): Unit =
216+
_itablesLength = assignBuckets0(classes.filterNot(_.kind.isJSType))
216217

217218
private val _jsPrivateFieldNames: mutable.ListBuffer[IRNames.FieldName] =
218219
new mutable.ListBuffer()
@@ -255,13 +256,8 @@ class WasmContext(val module: WasmModule) extends TypeDefinableWasmContext {
255256
def addFuncDeclaration(name: WasmFunctionName): Unit =
256257
_funcDeclarations += name
257258

258-
def putClassInfo(name: IRNames.ClassName, info: WasmClassInfo): Unit = {
259+
def putClassInfo(name: IRNames.ClassName, info: WasmClassInfo): Unit =
259260
classInfo.put(name, info)
260-
if (info.isInterface) {
261-
itableIdx.put(name, nextItableIdx)
262-
nextItableIdx += 1
263-
}
264-
}
265261

266262
def addJSPrivateFieldName(fieldName: IRNames.FieldName): Unit =
267263
_jsPrivateFieldNames += fieldName
@@ -567,13 +563,13 @@ class WasmContext(val module: WasmModule) extends TypeDefinableWasmContext {
567563
import fctx.instrs
568564

569565
// Initialize itables
570-
571566
for ((name, globalName) <- classItableGlobals) {
572567
val classInfo = getClassInfo(name)
573568
val interfaces = classInfo.ancestors.map(getClassInfo(_)).filter(_.isInterface)
574569
val resolvedMethodInfos = classInfo.resolvedMethodInfos
570+
575571
interfaces.foreach { iface =>
576-
val idx = getItableIdx(iface.name)
572+
val idx = getItableIdx(iface)
577573
instrs += WasmInstr.GLOBAL_GET(globalName)
578574
instrs += WasmInstr.I32_CONST(idx)
579575

@@ -595,7 +591,7 @@ class WasmContext(val module: WasmModule) extends TypeDefinableWasmContext {
595591
interfaceInfo <- getClassInfoOption(interfaceName)
596592
} {
597593
instrs += GLOBAL_GET(globalName)
598-
instrs += I32_CONST(getItableIdx(interfaceName))
594+
instrs += I32_CONST(getItableIdx(interfaceInfo))
599595

600596
for (method <- interfaceInfo.tableEntries)
601597
instrs += refFuncWithDeclaration(resolvedMethodInfos(method).tableEntryName)
@@ -704,6 +700,142 @@ class WasmContext(val module: WasmModule) extends TypeDefinableWasmContext {
704700
module.addElement(WasmElement(WasmRefType.funcref, exprs, WasmElement.Mode.Declarative))
705701
}
706702
}
703+
704+
/** Group interface types + types that implements any interfaces into buckets, where no two types
705+
* in the same bucket can have common subtypes.
706+
*
707+
* It allows compressing the itable by reusing itable's index (buckets) for unrelated types,
708+
* instead of having a 1-1 mapping from type to index. As a result, the itables' length will be
709+
* the same as the number of buckets).
710+
*
711+
* The algorithm separates the type hierarchy into three disjoint subsets,
712+
*
713+
* - join types: types with multiple parents (direct supertypes) that have only single
714+
* subtyping descendants: `join(T) = {x ∈ multis(T) | ∄ y ∈ multis(T) : y <: x}` where
715+
* multis(T) means types with multiple direct supertypes.
716+
* - spine types: all ancestors of join types: `spine(T) = {x ∈ T | ∃ y ∈ join(T) : x ∈
717+
* ancestors(y)}`
718+
* - plain types: types that are neither join nor spine types
719+
*
720+
* The bucket assignment process consists of two parts:
721+
*
722+
* **1. Assign buckets to spine types**
723+
*
724+
* Two spine types can share the same bucket only if they do not have any common join type
725+
* descendants.
726+
*
727+
* Visit spine types in reverse topological order because (from leaves to root) when assigning a
728+
* a spine type to bucket, the algorithm already has the complete information about the
729+
* join/spine type descendants of that spine type.
730+
*
731+
* Assign a bucket to a spine type if adding it doesn't violate the bucket assignment rule: two
732+
* spine types can share a bucket only if they don't have any common join type descendants. If no
733+
* existing bucket satisfies the rule, create a new bucket.
734+
*
735+
* **2. Assign buckets to non-spine types (plain and join types)**
736+
*
737+
* Visit these types in level order (from root to leaves) For each type, compute the set of
738+
* buckets already used by its ancestors. Assign the type to any available bucket not in this
739+
* set. If no available bucket exists, create a new one.
740+
*
741+
* To test if type A is a subtype of type B: load the bucket index of type B (we do this by
742+
* `getItableIdx`), load the itable at that index from A, and check if the itable is an itable
743+
* for B.
744+
*
745+
* @see
746+
* This algorithm is based on the "packed encoding" presented in the paper "Efficient Type
747+
* Inclusion Tests"
748+
* [[https://www.researchgate.net/publication/2438441_Efficient_Type_Inclusion_Tests]]
749+
*/
750+
private def assignBuckets0(classes: List[LinkedClass]): Int = {
751+
var nextIdx = 0
752+
def newBucket(): Bucket = {
753+
val idx = nextIdx
754+
nextIdx += 1
755+
new Bucket(idx)
756+
}
757+
def getAllInterfaces(info: WasmClassInfo): List[IRNames.ClassName] =
758+
info.ancestors.filter(getClassInfo(_).isInterface)
759+
760+
val buckets = new mutable.ListBuffer[Bucket]()
761+
762+
/** All join type descendants of the class */
763+
val joinsOf =
764+
new mutable.HashMap[IRNames.ClassName, mutable.HashSet[IRNames.ClassName]]()
765+
766+
/** the buckets that have been assigned to any of the ancestors of the class */
767+
val usedOf = new mutable.HashMap[IRNames.ClassName, mutable.HashSet[Bucket]]()
768+
val spines = new mutable.HashSet[IRNames.ClassName]()
769+
770+
for (clazz <- classes.reverseIterator) {
771+
val info = getClassInfo(clazz.name.name)
772+
val ifaces = getAllInterfaces(info)
773+
if (ifaces.nonEmpty) {
774+
val joins = joinsOf.getOrElse(clazz.name.name, new mutable.HashSet())
775+
776+
if (joins.nonEmpty) { // spine type
777+
var found = false
778+
val bs = buckets.iterator
779+
// look for an existing bucket to add the spine type to
780+
while (!found && bs.hasNext) {
781+
val b = bs.next()
782+
// two spine types can share a bucket only if they don't have any common join type descendants
783+
if (!b.joins.exists(joins)) {
784+
found = true
785+
b.add(info)
786+
b.joins ++= joins
787+
}
788+
}
789+
if (!found) { // there's no bucket to add, create new bucket
790+
val b = newBucket()
791+
b.add(info)
792+
buckets.append(b)
793+
b.joins ++= joins
794+
}
795+
for (iface <- ifaces) {
796+
joinsOf.getOrElseUpdate(iface, new mutable.HashSet()) ++= joins
797+
}
798+
spines.add(clazz.name.name)
799+
} else if (ifaces.length > 1) { // join type, add to joins map, bucket assignment is done later
800+
ifaces.foreach { iface =>
801+
joinsOf.getOrElseUpdate(iface, new mutable.HashSet()) += clazz.name.name
802+
}
803+
}
804+
// else: plain, do nothing
805+
}
806+
807+
}
808+
809+
for (clazz <- classes) {
810+
val info = getClassInfo(clazz.name.name)
811+
val ifaces = getAllInterfaces(info)
812+
if (ifaces.nonEmpty && !spines.contains(clazz.name.name)) {
813+
val used = usedOf.getOrElse(clazz.name.name, new mutable.HashSet())
814+
for {
815+
iface <- ifaces
816+
parentUsed <- usedOf.get(iface)
817+
} { used ++= parentUsed }
818+
819+
var found = false
820+
val bs = buckets.iterator
821+
while (!found && bs.hasNext) {
822+
val b = bs.next()
823+
if (!used.contains(b)) {
824+
found = true
825+
b.add(info)
826+
used.add(b)
827+
}
828+
}
829+
if (!found) {
830+
val b = newBucket()
831+
buckets.append(b)
832+
b.add(info)
833+
used.add(b)
834+
}
835+
}
836+
}
837+
buckets.length
838+
}
707839
}
708840

709841
object WasmContext {
@@ -723,7 +855,8 @@ object WasmContext {
723855
val isAbstract: Boolean,
724856
val hasRuntimeTypeInfo: Boolean,
725857
val jsNativeLoadSpec: Option[IRTrees.JSNativeLoadSpec],
726-
val jsNativeMembers: Map[IRNames.MethodName, IRTrees.JSNativeLoadSpec]
858+
val jsNativeMembers: Map[IRNames.MethodName, IRTrees.JSNativeLoadSpec],
859+
private var _itableIdx: Int
727860
) {
728861
private val fieldIdxByName: Map[IRNames.FieldName, Int] =
729862
allFieldDefs.map(_.name.name).zipWithIndex.map(p => p._1 -> (p._2 + classFieldOffset)).toMap
@@ -757,6 +890,12 @@ object WasmContext {
757890

758891
def hasInstances: Boolean = _hasInstances
759892

893+
def setItableIdx(idx: Int): Unit = _itableIdx = idx
894+
895+
/** Returns the index of this interface's itable in the classes' interface tables.
896+
*/
897+
def itableIdx: Int = _itableIdx
898+
760899
private var _specialInstanceTypes: Int = 0
761900

762901
def addSpecialInstanceType(jsValueType: Int): Unit =
@@ -890,4 +1029,11 @@ object WasmContext {
8901029
}
8911030

8921031
final class TableMethodInfo(val methodName: IRNames.MethodName, val tableIndex: Int)
1032+
1033+
private[WasmContext] class Bucket(idx: Int) {
1034+
def add(clazz: WasmClassInfo) = clazz.setItableIdx((idx))
1035+
1036+
/** A set of join types that are descendants of the types assigned to that bucket */
1037+
val joins = new mutable.HashSet[IRNames.ClassName]()
1038+
}
8931039
}

0 commit comments

Comments
 (0)