Skip to content
This repository was archived by the owner on Jul 12, 2024. It is now read-only.

Commit 3bd0321

Browse files
committed
Compress interface tables with packed encoding
This commit implements the "packed encoding" algorithm for compressing interface tables (itables). The algorithm is based on the paper "Efficient Type Inclusion Tests". https://www.researchgate.net/publication/2438441_Efficient_Type_Inclusion_Tests It compresses the itable by reusing itable indices for unrelated types, reducing the size of the itable array. For scalajs-test-suite, the itable size is reduced from 413 to 47. This change didn't reduce the size of wasm binary that much (14957kb to 14943kb), but it should be space efficient. We chose packed encoding because the compression rate is good enough among the presented methods in the paper. While hierarchical encoding is slightly better compression rate, packed encoding is better in both compile-time and runtime performance in CPU and space.
1 parent 5128aa5 commit 3bd0321

File tree

4 files changed

+175
-25
lines changed

4 files changed

+175
-25
lines changed

Diff for: wasm/src/main/scala/ir2wasm/HelperFunctions.scala

+4-4
Original file line numberDiff line numberDiff line change
@@ -1700,7 +1700,7 @@ object HelperFunctions {
17001700
val itables = fctx.addLocal("itables", WasmRefType.nullable(WasmArrayTypeName.itables))
17011701
val exprNonNullLocal = fctx.addLocal("exprNonNull", WasmRefType.any)
17021702

1703-
val itableIdx = ctx.getItableIdx(clazz.name.name)
1703+
val itableIdx = ctx.getItableIdx(classInfo)
17041704
fctx.block(WasmRefType.anyref) { testFail =>
17051705
// if expr is not an instance of Object, return false
17061706
instrs += LOCAL_GET(exprParam)
@@ -1724,9 +1724,9 @@ object HelperFunctions {
17241724
instrs += LOCAL_GET(itables)
17251725
instrs += I32_CONST(itableIdx)
17261726
instrs += ARRAY_GET(WasmTypeName.WasmArrayTypeName.itables)
1727-
instrs += BR_ON_NULL(testFail)
1728-
1729-
instrs += I32_CONST(1)
1727+
instrs += REF_TEST(
1728+
WasmRefType(WasmStructTypeName.forITable(clazz.name.name))
1729+
)
17301730
instrs += RETURN
17311731
} // test fail
17321732

Diff for: wasm/src/main/scala/ir2wasm/Preprocessor.scala

+5-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,10 @@ object Preprocessor {
2828

2929
for (clazz <- classes) {
3030
ctx.getClassInfo(clazz.className).buildMethodTable()
31+
}
32+
ctx.assignBuckets(classes)
3133

34+
for (clazz <- classes) {
3235
if (clazz.kind == ClassKind.Interface && clazz.hasInstanceTests)
3336
HelperFunctions.genInstanceTest(clazz)
3437
HelperFunctions.genCloneFunction(clazz)
@@ -101,7 +104,8 @@ object Preprocessor {
101104
!clazz.hasDirectInstances,
102105
hasRuntimeTypeInfo,
103106
clazz.jsNativeLoadSpec,
104-
clazz.jsNativeMembers.map(m => m.name.name -> m.jsNativeLoadSpec).toMap
107+
clazz.jsNativeMembers.map(m => m.name.name -> m.jsNativeLoadSpec).toMap,
108+
_itableIdx = -1
105109
)
106110
)
107111

Diff for: wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -655,7 +655,7 @@ private class WasmExpressionBuilder private (
655655
): Unit = {
656656
// Generates an itable-based dispatch.
657657
def genITableDispatch(): Unit = {
658-
val itableIdx = ctx.getItableIdx(receiverClassInfo.name)
658+
val itableIdx = ctx.getItableIdx(receiverClassInfo)
659659
val methodIdx = receiverClassInfo.tableMethodInfos(methodName).tableIndex
660660

661661
instrs += LOCAL_GET(receiverLocalForDispatch)

Diff for: wasm/src/main/scala/wasm4s/WasmContext.scala

+165-19
Original file line numberDiff line numberDiff line change
@@ -20,27 +20,27 @@ import wasm.ir2wasm.WasmExpressionBuilder
2020
import org.scalajs.linker.interface.ModuleInitializer
2121
import org.scalajs.linker.interface.unstable.ModuleInitializerImpl
2222
import org.scalajs.linker.standard.LinkedTopLevelExport
23+
import org.scalajs.linker.standard.LinkedClass
2324

2425
abstract class ReadOnlyWasmContext {
2526
import WasmContext._
2627

27-
protected val itableIdx = mutable.Map[IRNames.ClassName, Int]()
2828
protected val classInfo = mutable.Map[IRNames.ClassName, WasmClassInfo]()
29-
protected var nextItableIdx: Int
3029

3130
val cloneFunctionTypeName: WasmTypeName
3231
val isJSClassInstanceFuncTypeName: WasmTypeName
3332

34-
def itablesLength = nextItableIdx
33+
protected var _itablesLength: Int = 0
34+
def itablesLength = _itablesLength
3535

3636
/** Get an index of the itable for the given interface. The itable instance must be placed at the
3737
* index in the array of itables (whose size is `itablesLength`).
3838
*/
39-
def getItableIdx(iface: IRNames.ClassName): Int =
40-
itableIdx.getOrElse(
41-
iface,
42-
throw new IllegalArgumentException(s"Interface $iface is not registed.")
43-
)
39+
def getItableIdx(iface: WasmClassInfo): Int = {
40+
val idx = iface.itableIdx
41+
if (idx < 0) throw new IllegalArgumentException(s"Interface $iface is not registed.")
42+
idx
43+
}
4444

4545
def getClassInfoOption(name: IRNames.ClassName): Option[WasmClassInfo] =
4646
classInfo.get(name)
@@ -212,7 +212,8 @@ class WasmContext(val module: WasmModule) extends TypeDefinableWasmContext {
212212
private val _importedModules: mutable.LinkedHashSet[String] =
213213
new mutable.LinkedHashSet()
214214

215-
override protected var nextItableIdx: Int = 0
215+
def assignBuckets(classes: List[LinkedClass]): Unit =
216+
_itablesLength = assignBuckets0(classes.filterNot(_.kind.isJSType))
216217

217218
private val _jsPrivateFieldNames: mutable.ListBuffer[IRNames.FieldName] =
218219
new mutable.ListBuffer()
@@ -255,13 +256,8 @@ class WasmContext(val module: WasmModule) extends TypeDefinableWasmContext {
255256
def addFuncDeclaration(name: WasmFunctionName): Unit =
256257
_funcDeclarations += name
257258

258-
def putClassInfo(name: IRNames.ClassName, info: WasmClassInfo): Unit = {
259+
def putClassInfo(name: IRNames.ClassName, info: WasmClassInfo): Unit =
259260
classInfo.put(name, info)
260-
if (info.isInterface) {
261-
itableIdx.put(name, nextItableIdx)
262-
nextItableIdx += 1
263-
}
264-
}
265261

266262
def addJSPrivateFieldName(fieldName: IRNames.FieldName): Unit =
267263
_jsPrivateFieldNames += fieldName
@@ -567,13 +563,13 @@ class WasmContext(val module: WasmModule) extends TypeDefinableWasmContext {
567563
import fctx.instrs
568564

569565
// Initialize itables
570-
571566
for ((name, globalName) <- classItableGlobals) {
572567
val classInfo = getClassInfo(name)
573568
val interfaces = classInfo.ancestors.map(getClassInfo(_)).filter(_.isInterface)
574569
val resolvedMethodInfos = classInfo.resolvedMethodInfos
570+
575571
interfaces.foreach { iface =>
576-
val idx = getItableIdx(iface.name)
572+
val idx = getItableIdx(iface)
577573
instrs += WasmInstr.GLOBAL_GET(globalName)
578574
instrs += WasmInstr.I32_CONST(idx)
579575

@@ -595,7 +591,7 @@ class WasmContext(val module: WasmModule) extends TypeDefinableWasmContext {
595591
interfaceInfo <- getClassInfoOption(interfaceName)
596592
} {
597593
instrs += GLOBAL_GET(globalName)
598-
instrs += I32_CONST(getItableIdx(interfaceName))
594+
instrs += I32_CONST(getItableIdx(interfaceInfo))
599595

600596
for (method <- interfaceInfo.tableEntries)
601597
instrs += refFuncWithDeclaration(resolvedMethodInfos(method).tableEntryName)
@@ -704,6 +700,142 @@ class WasmContext(val module: WasmModule) extends TypeDefinableWasmContext {
704700
module.addElement(WasmElement(WasmRefType.funcref, exprs, WasmElement.Mode.Declarative))
705701
}
706702
}
703+
704+
/** Group interface types + types that implements any interfaces into buckets, where no two types
705+
* in the same bucket can have common subtypes.
706+
*
707+
* It allows compressing the itable by reusing itable's index (buckets) for unrelated types,
708+
* instead of having a 1-1 mapping from type to index. As a result, the itables' length will be
709+
* the same as the number of buckets).
710+
*
711+
* The algorithm separates the type hierarchy into three disjoint subsets,
712+
*
713+
* - join types: types with multiple parents (direct supertypes) that have only single
714+
* subtyping descendants: `join(T) = {x ∈ multis(T) | ∄ y ∈ multis(T) : y <: x}` where
715+
* multis(T) means types with multiple direct supertypes.
716+
* - spine types: all ancestors of join types: `spine(T) = {x ∈ T | ∃ y ∈ join(T) : x ∈
717+
* ancestors(y)}`
718+
* - plain types: types that are neither join nor spine types
719+
*
720+
* The bucket assignment process consists of two parts:
721+
*
722+
* **1. Assign buckets to spine types**
723+
*
724+
* Two spine types can share the same bucket only if they do not have any common join type
725+
* descendants.
726+
*
727+
* Visit spine types in reverse topological order because (from leaves to root) when assigning a
728+
* a spine type to bucket, the algorithm already has the complete information about the
729+
* join/spine type descendants of that spine type.
730+
*
731+
* Assign a bucket to a spine type if adding it doesn't violate the bucket assignment rule: two
732+
* spine types can share a bucket only if they don't have any common join type descendants. If no
733+
* existing bucket satisfies the rule, create a new bucket.
734+
*
735+
* **2. Assign buckets to non-spine types (plain and join types)**
736+
*
737+
* Visit these types in level order (from root to leaves) For each type, compute the set of
738+
* buckets already used by its ancestors. Assign the type to any available bucket not in this
739+
* set. If no available bucket exists, create a new one.
740+
*
741+
* To test if type A is a subtype of type B: load the bucket index of type B (we do this by
742+
* `getItableIdx`), load the itable at that index from A, and check if the itable is an itable
743+
* for B.
744+
*
745+
* @see
746+
* This algorithm is based on the "packed encoding" presented in the paper "Efficient Type
747+
* Inclusion Tests"
748+
* [[https://www.researchgate.net/publication/2438441_Efficient_Type_Inclusion_Tests]]
749+
*/
750+
private def assignBuckets0(classes: List[LinkedClass]): Int = {
751+
var nextIdx = 0
752+
def newBucket(): Bucket = {
753+
val idx = nextIdx
754+
nextIdx += 1
755+
new Bucket(idx)
756+
}
757+
def getAllInterfaces(info: WasmClassInfo): List[IRNames.ClassName] =
758+
info.ancestors.filter(getClassInfo(_).isInterface)
759+
760+
val buckets = new mutable.ListBuffer[Bucket]()
761+
762+
/** All join type descendants of the class */
763+
val joinsOf =
764+
new mutable.HashMap[IRNames.ClassName, mutable.HashSet[IRNames.ClassName]]()
765+
766+
/** the buckets that have been assigned to any of the ancestors of the class */
767+
val usedOf = new mutable.HashMap[IRNames.ClassName, mutable.HashSet[Bucket]]()
768+
val spines = new mutable.HashSet[IRNames.ClassName]()
769+
770+
for (clazz <- classes.reverseIterator) {
771+
val info = getClassInfo(clazz.name.name)
772+
val ifaces = getAllInterfaces(info)
773+
if (ifaces.nonEmpty) {
774+
val joins = joinsOf.getOrElse(clazz.name.name, new mutable.HashSet())
775+
776+
if (joins.nonEmpty) { // spine type
777+
var found = false
778+
val bs = buckets.iterator
779+
// look for an existing bucket to add the spine type to
780+
while (!found && bs.hasNext) {
781+
val b = bs.next()
782+
// two spine types can share a bucket only if they don't have any common join type descendants
783+
if (!b.joins.exists(joins)) {
784+
found = true
785+
b.add(info)
786+
b.joins ++= joins
787+
}
788+
}
789+
if (!found) { // there's no bucket to add, create new bucket
790+
val b = newBucket()
791+
b.add(info)
792+
buckets.append(b)
793+
b.joins ++= joins
794+
}
795+
for (iface <- ifaces) {
796+
joinsOf.getOrElseUpdate(iface, new mutable.HashSet()) ++= joins
797+
}
798+
spines.add(clazz.name.name)
799+
} else if (ifaces.length > 1) { // join type, add to joins map, bucket assignment is done later
800+
ifaces.foreach { iface =>
801+
joinsOf.getOrElseUpdate(iface, new mutable.HashSet()) += clazz.name.name
802+
}
803+
}
804+
// else: plain, do nothing
805+
}
806+
807+
}
808+
809+
for (clazz <- classes) {
810+
val info = getClassInfo(clazz.name.name)
811+
val ifaces = getAllInterfaces(info)
812+
if (ifaces.nonEmpty && !spines.contains(clazz.name.name)) {
813+
val used = usedOf.getOrElse(clazz.name.name, new mutable.HashSet())
814+
for {
815+
iface <- ifaces
816+
parentUsed <- usedOf.get(iface)
817+
} { used ++= parentUsed }
818+
819+
var found = false
820+
val bs = buckets.iterator
821+
while (!found && bs.hasNext) {
822+
val b = bs.next()
823+
if (!used.contains(b)) {
824+
found = true
825+
b.add(info)
826+
used.add(b)
827+
}
828+
}
829+
if (!found) {
830+
val b = newBucket()
831+
buckets.append(b)
832+
b.add(info)
833+
used.add(b)
834+
}
835+
}
836+
}
837+
buckets.length
838+
}
707839
}
708840

709841
object WasmContext {
@@ -723,7 +855,8 @@ object WasmContext {
723855
val isAbstract: Boolean,
724856
val hasRuntimeTypeInfo: Boolean,
725857
val jsNativeLoadSpec: Option[IRTrees.JSNativeLoadSpec],
726-
val jsNativeMembers: Map[IRNames.MethodName, IRTrees.JSNativeLoadSpec]
858+
val jsNativeMembers: Map[IRNames.MethodName, IRTrees.JSNativeLoadSpec],
859+
private var _itableIdx: Int
727860
) {
728861
private val fieldIdxByName: Map[IRNames.FieldName, Int] =
729862
allFieldDefs.map(_.name.name).zipWithIndex.map(p => p._1 -> (p._2 + classFieldOffset)).toMap
@@ -757,6 +890,12 @@ object WasmContext {
757890

758891
def hasInstances: Boolean = _hasInstances
759892

893+
def setItableIdx(idx: Int): Unit = _itableIdx = idx
894+
895+
/** Returns the index of this interface's itable in the classes' interface tables.
896+
*/
897+
def itableIdx: Int = _itableIdx
898+
760899
private var _specialInstanceTypes: Int = 0
761900

762901
def addSpecialInstanceType(jsValueType: Int): Unit =
@@ -890,4 +1029,11 @@ object WasmContext {
8901029
}
8911030

8921031
final class TableMethodInfo(val methodName: IRNames.MethodName, val tableIndex: Int)
1032+
1033+
private[WasmContext] class Bucket(idx: Int) {
1034+
def add(clazz: WasmClassInfo) = clazz.setItableIdx((idx))
1035+
1036+
/** A set of join types that are descendants of the types assigned to that bucket */
1037+
val joins = new mutable.HashSet[IRNames.ClassName]()
1038+
}
8931039
}

0 commit comments

Comments
 (0)