Skip to content
This repository was archived by the owner on Jul 12, 2024. It is now read-only.

Compress interface tables with packed encoding #121

Merged
merged 1 commit into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions wasm/src/main/scala/ir2wasm/HelperFunctions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1700,7 +1700,7 @@ object HelperFunctions {
val itables = fctx.addLocal("itables", WasmRefType.nullable(WasmArrayTypeName.itables))
val exprNonNullLocal = fctx.addLocal("exprNonNull", WasmRefType.any)

val itableIdx = ctx.getItableIdx(clazz.name.name)
val itableIdx = ctx.getItableIdx(classInfo)
fctx.block(WasmRefType.anyref) { testFail =>
// if expr is not an instance of Object, return false
instrs += LOCAL_GET(exprParam)
Expand All @@ -1724,9 +1724,9 @@ object HelperFunctions {
instrs += LOCAL_GET(itables)
instrs += I32_CONST(itableIdx)
instrs += ARRAY_GET(WasmTypeName.WasmArrayTypeName.itables)
instrs += BR_ON_NULL(testFail)

instrs += I32_CONST(1)
instrs += REF_TEST(
WasmRefType(WasmStructTypeName.forITable(clazz.name.name))
)
instrs += RETURN
} // test fail

Expand Down
6 changes: 5 additions & 1 deletion wasm/src/main/scala/ir2wasm/Preprocessor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ object Preprocessor {

for (clazz <- classes) {
ctx.getClassInfo(clazz.className).buildMethodTable()
}
ctx.assignBuckets(classes)

for (clazz <- classes) {
if (clazz.kind == ClassKind.Interface && clazz.hasInstanceTests)
HelperFunctions.genInstanceTest(clazz)
HelperFunctions.genCloneFunction(clazz)
Expand Down Expand Up @@ -101,7 +104,8 @@ object Preprocessor {
!clazz.hasDirectInstances,
hasRuntimeTypeInfo,
clazz.jsNativeLoadSpec,
clazz.jsNativeMembers.map(m => m.name.name -> m.jsNativeLoadSpec).toMap
clazz.jsNativeMembers.map(m => m.name.name -> m.jsNativeLoadSpec).toMap,
_itableIdx = -1
)
)

Expand Down
2 changes: 1 addition & 1 deletion wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,7 @@ private class WasmExpressionBuilder private (
): Unit = {
// Generates an itable-based dispatch.
def genITableDispatch(): Unit = {
val itableIdx = ctx.getItableIdx(receiverClassInfo.name)
val itableIdx = ctx.getItableIdx(receiverClassInfo)
val methodIdx = receiverClassInfo.tableMethodInfos(methodName).tableIndex

instrs += LOCAL_GET(receiverLocalForDispatch)
Expand Down
184 changes: 165 additions & 19 deletions wasm/src/main/scala/wasm4s/WasmContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,27 @@ import wasm.ir2wasm.WasmExpressionBuilder
import org.scalajs.linker.interface.ModuleInitializer
import org.scalajs.linker.interface.unstable.ModuleInitializerImpl
import org.scalajs.linker.standard.LinkedTopLevelExport
import org.scalajs.linker.standard.LinkedClass

abstract class ReadOnlyWasmContext {
import WasmContext._

protected val itableIdx = mutable.Map[IRNames.ClassName, Int]()
protected val classInfo = mutable.Map[IRNames.ClassName, WasmClassInfo]()
protected var nextItableIdx: Int

val cloneFunctionTypeName: WasmTypeName
val isJSClassInstanceFuncTypeName: WasmTypeName

def itablesLength = nextItableIdx
protected var _itablesLength: Int = 0
def itablesLength = _itablesLength

/** Get an index of the itable for the given interface. The itable instance must be placed at the
* index in the array of itables (whose size is `itablesLength`).
*/
def getItableIdx(iface: IRNames.ClassName): Int =
itableIdx.getOrElse(
iface,
throw new IllegalArgumentException(s"Interface $iface is not registed.")
)
def getItableIdx(iface: WasmClassInfo): Int = {
val idx = iface.itableIdx
if (idx < 0) throw new IllegalArgumentException(s"Interface $iface is not registed.")
idx
}

def getClassInfoOption(name: IRNames.ClassName): Option[WasmClassInfo] =
classInfo.get(name)
Expand Down Expand Up @@ -212,7 +212,8 @@ class WasmContext(val module: WasmModule) extends TypeDefinableWasmContext {
private val _importedModules: mutable.LinkedHashSet[String] =
new mutable.LinkedHashSet()

override protected var nextItableIdx: Int = 0
def assignBuckets(classes: List[LinkedClass]): Unit =
_itablesLength = assignBuckets0(classes.filterNot(_.kind.isJSType))

private val _jsPrivateFieldNames: mutable.ListBuffer[IRNames.FieldName] =
new mutable.ListBuffer()
Expand Down Expand Up @@ -255,13 +256,8 @@ class WasmContext(val module: WasmModule) extends TypeDefinableWasmContext {
def addFuncDeclaration(name: WasmFunctionName): Unit =
_funcDeclarations += name

def putClassInfo(name: IRNames.ClassName, info: WasmClassInfo): Unit = {
def putClassInfo(name: IRNames.ClassName, info: WasmClassInfo): Unit =
classInfo.put(name, info)
if (info.isInterface) {
itableIdx.put(name, nextItableIdx)
nextItableIdx += 1
}
}

def addJSPrivateFieldName(fieldName: IRNames.FieldName): Unit =
_jsPrivateFieldNames += fieldName
Expand Down Expand Up @@ -567,13 +563,13 @@ class WasmContext(val module: WasmModule) extends TypeDefinableWasmContext {
import fctx.instrs

// Initialize itables

for ((name, globalName) <- classItableGlobals) {
val classInfo = getClassInfo(name)
val interfaces = classInfo.ancestors.map(getClassInfo(_)).filter(_.isInterface)
val resolvedMethodInfos = classInfo.resolvedMethodInfos

interfaces.foreach { iface =>
val idx = getItableIdx(iface.name)
val idx = getItableIdx(iface)
instrs += WasmInstr.GLOBAL_GET(globalName)
instrs += WasmInstr.I32_CONST(idx)

Expand All @@ -595,7 +591,7 @@ class WasmContext(val module: WasmModule) extends TypeDefinableWasmContext {
interfaceInfo <- getClassInfoOption(interfaceName)
} {
instrs += GLOBAL_GET(globalName)
instrs += I32_CONST(getItableIdx(interfaceName))
instrs += I32_CONST(getItableIdx(interfaceInfo))

for (method <- interfaceInfo.tableEntries)
instrs += refFuncWithDeclaration(resolvedMethodInfos(method).tableEntryName)
Expand Down Expand Up @@ -704,6 +700,142 @@ class WasmContext(val module: WasmModule) extends TypeDefinableWasmContext {
module.addElement(WasmElement(WasmRefType.funcref, exprs, WasmElement.Mode.Declarative))
}
}

/** Group interface types + types that implements any interfaces into buckets, where no two types
* in the same bucket can have common subtypes.
*
* It allows compressing the itable by reusing itable's index (buckets) for unrelated types,
* instead of having a 1-1 mapping from type to index. As a result, the itables' length will be
* the same as the number of buckets).
*
* The algorithm separates the type hierarchy into three disjoint subsets,
*
* - join types: types with multiple parents (direct supertypes) that have only single
* subtyping descendants: `join(T) = {x ∈ multis(T) | ∄ y ∈ multis(T) : y <: x}` where
* multis(T) means types with multiple direct supertypes.
* - spine types: all ancestors of join types: `spine(T) = {x ∈ T | ∃ y ∈ join(T) : x ∈
* ancestors(y)}`
* - plain types: types that are neither join nor spine types
*
* The bucket assignment process consists of two parts:
*
* **1. Assign buckets to spine types**
*
* Two spine types can share the same bucket only if they do not have any common join type
* descendants.
*
* Visit spine types in reverse topological order because (from leaves to root) when assigning a
* a spine type to bucket, the algorithm already has the complete information about the
* join/spine type descendants of that spine type.
*
* Assign a bucket to a spine type if adding it doesn't violate the bucket assignment rule: two
* spine types can share a bucket only if they don't have any common join type descendants. If no
* existing bucket satisfies the rule, create a new bucket.
*
* **2. Assign buckets to non-spine types (plain and join types)**
*
* Visit these types in level order (from root to leaves) For each type, compute the set of
* buckets already used by its ancestors. Assign the type to any available bucket not in this
* set. If no available bucket exists, create a new one.
*
* To test if type A is a subtype of type B: load the bucket index of type B (we do this by
* `getItableIdx`), load the itable at that index from A, and check if the itable is an itable
* for B.
*
* @see
* This algorithm is based on the "packed encoding" presented in the paper "Efficient Type
* Inclusion Tests"
* [[https://www.researchgate.net/publication/2438441_Efficient_Type_Inclusion_Tests]]
*/
private def assignBuckets0(classes: List[LinkedClass]): Int = {
var nextIdx = 0
def newBucket(): Bucket = {
val idx = nextIdx
nextIdx += 1
new Bucket(idx)
}
def getAllInterfaces(info: WasmClassInfo): List[IRNames.ClassName] =
info.ancestors.filter(getClassInfo(_).isInterface)

val buckets = new mutable.ListBuffer[Bucket]()

/** All join type descendants of the class */
val joinsOf =
new mutable.HashMap[IRNames.ClassName, mutable.HashSet[IRNames.ClassName]]()

/** the buckets that have been assigned to any of the ancestors of the class */
val usedOf = new mutable.HashMap[IRNames.ClassName, mutable.HashSet[Bucket]]()
val spines = new mutable.HashSet[IRNames.ClassName]()

for (clazz <- classes.reverseIterator) {
val info = getClassInfo(clazz.name.name)
val ifaces = getAllInterfaces(info)
if (ifaces.nonEmpty) {
val joins = joinsOf.getOrElse(clazz.name.name, new mutable.HashSet())

if (joins.nonEmpty) { // spine type
var found = false
val bs = buckets.iterator
// look for an existing bucket to add the spine type to
while (!found && bs.hasNext) {
val b = bs.next()
// two spine types can share a bucket only if they don't have any common join type descendants
if (!b.joins.exists(joins)) {
found = true
b.add(info)
b.joins ++= joins
}
}
if (!found) { // there's no bucket to add, create new bucket
val b = newBucket()
b.add(info)
buckets.append(b)
b.joins ++= joins
}
for (iface <- ifaces) {
joinsOf.getOrElseUpdate(iface, new mutable.HashSet()) ++= joins
}
spines.add(clazz.name.name)
} else if (ifaces.length > 1) { // join type, add to joins map, bucket assignment is done later
ifaces.foreach { iface =>
joinsOf.getOrElseUpdate(iface, new mutable.HashSet()) += clazz.name.name
}
}
// else: plain, do nothing
}

}

for (clazz <- classes) {
val info = getClassInfo(clazz.name.name)
val ifaces = getAllInterfaces(info)
if (ifaces.nonEmpty && !spines.contains(clazz.name.name)) {
val used = usedOf.getOrElse(clazz.name.name, new mutable.HashSet())
for {
iface <- ifaces
parentUsed <- usedOf.get(iface)
} { used ++= parentUsed }

var found = false
val bs = buckets.iterator
while (!found && bs.hasNext) {
val b = bs.next()
if (!used.contains(b)) {
found = true
b.add(info)
used.add(b)
}
}
if (!found) {
val b = newBucket()
buckets.append(b)
b.add(info)
used.add(b)
}
}
}
buckets.length
}
}

object WasmContext {
Expand All @@ -723,7 +855,8 @@ object WasmContext {
val isAbstract: Boolean,
val hasRuntimeTypeInfo: Boolean,
val jsNativeLoadSpec: Option[IRTrees.JSNativeLoadSpec],
val jsNativeMembers: Map[IRNames.MethodName, IRTrees.JSNativeLoadSpec]
val jsNativeMembers: Map[IRNames.MethodName, IRTrees.JSNativeLoadSpec],
private var _itableIdx: Int
) {
private val fieldIdxByName: Map[IRNames.FieldName, Int] =
allFieldDefs.map(_.name.name).zipWithIndex.map(p => p._1 -> (p._2 + classFieldOffset)).toMap
Expand Down Expand Up @@ -757,6 +890,12 @@ object WasmContext {

def hasInstances: Boolean = _hasInstances

def setItableIdx(idx: Int): Unit = _itableIdx = idx

/** Returns the index of this interface's itable in the classes' interface tables.
*/
def itableIdx: Int = _itableIdx

private var _specialInstanceTypes: Int = 0

def addSpecialInstanceType(jsValueType: Int): Unit =
Expand Down Expand Up @@ -890,4 +1029,11 @@ object WasmContext {
}

final class TableMethodInfo(val methodName: IRNames.MethodName, val tableIndex: Int)

private[WasmContext] class Bucket(idx: Int) {
def add(clazz: WasmClassInfo) = clazz.setItableIdx((idx))

/** A set of join types that are descendants of the types assigned to that bucket */
val joins = new mutable.HashSet[IRNames.ClassName]()
}
}