diff --git a/cli/src/main/scala/TestSuites.scala b/cli/src/main/scala/TestSuites.scala index 60f8d456..98f59e44 100644 --- a/cli/src/main/scala/TestSuites.scala +++ b/cli/src/main/scala/TestSuites.scala @@ -5,6 +5,7 @@ object TestSuites { val suites = List( TestSuite("testsuite.core.Simple"), TestSuite("testsuite.core.Add"), + TestSuite("testsuite.core.ArrayTest"), TestSuite("testsuite.core.VirtualDispatch"), TestSuite("testsuite.core.InterfaceCall"), TestSuite("testsuite.core.AsInstanceOfTest"), diff --git a/sample/src/main/scala/Sample.scala b/sample/src/main/scala/Sample.scala index c673e127..ed270d56 100644 --- a/sample/src/main/scala/Sample.scala +++ b/sample/src/main/scala/Sample.scala @@ -5,18 +5,14 @@ import scala.annotation.tailrec import scala.scalajs.js import scala.scalajs.js.annotation._ -// -// class Base { -// def sqrt(x: Int) = x * x -// } -// object Main { @JSExportTopLevel("test") def test(i: Int): Boolean = { val loopFib = fib(new LoopFib {}, i) val recFib = fib(new RecFib {}, i) val tailrecFib = fib(new TailRecFib {}, i) - js.Dynamic.global.console.log(s"loopFib: $loopFib -- recFib: $recFib -- tailrecFib: $tailrecFib") + js.Dynamic.global.console + .log(s"loopFib: $loopFib -- recFib: $recFib -- tailrecFib: $tailrecFib") val date = new js.Date(0) js.Dynamic.global.console.log(date) loopFib == recFib && loopFib == tailrecFib @@ -24,7 +20,6 @@ object Main { def fib(fib: Fib, n: Int): Int = fib.fib(n) } - trait LoopFib extends Fib { def fib(n: Int): Int = { var a = 0 @@ -61,41 +56,4 @@ trait TailRecFib extends Fib { trait Fib { def fib(n: Int): Int - // = { - // if (n <= 1) { - // n - // } else { - // fib(n - 1) + fib(n - 2) - // } - // } - } - -// -// -// object Bar { -// def bar(b: Base) = b.base -// } - -// class Base extends Incr { -// override def incr(x: Int) = foo(x) + 1 -// } -// -// trait Incr extends BaseTrait { -// // val one = 1 -// def incr(x: Int): Int -// } -// -// trait BaseTrait { -// def foo(x: Int) = x -// } - -// object Foo { -// def foo = -// Main.ident(1) -// } -// -// class Derived(override val i: Int) extends Base(i) { -// def derived(x: Int) = x * i -// override def base(x: Int): Int = x * i -// } diff --git a/test-suite/src/main/scala/testsuite/core/ArrayTest.scala b/test-suite/src/main/scala/testsuite/core/ArrayTest.scala new file mode 100644 index 00000000..fee0fac8 --- /dev/null +++ b/test-suite/src/main/scala/testsuite/core/ArrayTest.scala @@ -0,0 +1,36 @@ +package testsuite.core + +import testsuite.Assert + +object ArrayTest { + def main(): Unit = { + Assert.ok( + testLength() && testSelect() && testNew() + ) + } + + def testLength(): Boolean = { + Array(1, 2, 3).length == 3 && + (Array(Array(1, 2), Array(2), Array(3))).length == 3 + } + + def testSelect(): Boolean = { + val a = Array(Array(1), Array(2), Array(3)) + a(0)(0) == 1 && { + a(0)(0) = 100 // Assign(ArraySelect(...), ...) + a(0)(0) == 100 // ArraySelect(...) + } && { + a(1) = Array(1, 2, 3) + a(1).length == 3 && a(1)(0) == 1 + } + } + + def testNew(): Boolean = { + (Array.emptyBooleanArray.length == 0) && + (new Array[Int](10)).length == 10 && + (new Array[Int](1))(0) == 0 && + (new Array[Array[Array[Int]]](5))(0) == null + } + + // TODO: Array.ofDim[T](...) +} diff --git a/wasm/src/main/scala/ir2wasm/TypeTransformer.scala b/wasm/src/main/scala/ir2wasm/TypeTransformer.scala index ff5e19c9..e92671c5 100644 --- a/wasm/src/main/scala/ir2wasm/TypeTransformer.scala +++ b/wasm/src/main/scala/ir2wasm/TypeTransformer.scala @@ -15,7 +15,7 @@ object TypeTransformer { def transformFunctionType( // clazz: WasmContext.WasmClassInfo, method: WasmContext.WasmFunctionInfo - )(implicit ctx: FunctionTypeWriterWasmContext): WasmFunctionType = { + )(implicit ctx: TypeDefinableWasmContext): WasmFunctionType = { // val className = clazz.name val name = method.name val receiverType = makeReceiverType @@ -43,41 +43,35 @@ object TypeTransformer { t match { case IRTypes.AnyType => Types.WasmAnyRef - case tpe @ IRTypes.ArrayType(IRTypes.ArrayTypeRef(elemType, size)) => - // TODO - // val wasmElemTy = - // elemType match { - // case IRTypes.ClassRef(className) => - // // val gcTypeSym = context.gcTypes.reference(Ident(className.nameString)) - // Types.WasmRefType(Types.WasmHeapType.Type(Names.WasmGCTypeName.fromIR(className))) - // case IRTypes.PrimRef(tpe) => - // transform(tpe) - // } - // val field = WasmStructField("TODO", wasmElemTy, isMutable = false) - // val arrayTySym = - // context.gcTypes.define(WasmArrayType(Names.WasmGCTypeName.fromIR(tpe), field)) - // Types.WasmRefType(Types.WasmHeapType.Type(arrayTySym)) - ??? - case clazz @ IRTypes.ClassType(className) => - className match { - case _ => - val info = ctx.getClassInfo(clazz.className) - if (info.isAncestorOfHijackedClass) - Types.WasmAnyRef - else if (info.isInterface) - Types.WasmRefNullType(Types.WasmHeapType.ObjectType) - else - Types.WasmRefNullType( - Types.WasmHeapType.Type(Names.WasmTypeName.WasmStructTypeName(className)) - ) - } - case IRTypes.RecordType(fields) => ??? + case tpe: IRTypes.ArrayType => + Types.WasmRefNullType( + Types.WasmHeapType.Type(Names.WasmTypeName.WasmArrayTypeName(tpe.arrayTypeRef)) + ) + case IRTypes.ClassType(className) => transformClassByName(className) + case IRTypes.RecordType(fields) => ??? case IRTypes.StringType | IRTypes.UndefType => Types.WasmRefType.any case p: IRTypes.PrimTypeWithRef => transformPrimType(p) } - def transformPrimType( + private def transformClassByName( + className: IRNames.ClassName + )(implicit ctx: ReadOnlyWasmContext): Types.WasmType = { + className match { + case _ => + val info = ctx.getClassInfo(className) + if (info.isAncestorOfHijackedClass) + Types.WasmAnyRef + else if (info.isInterface) + Types.WasmRefNullType(Types.WasmHeapType.ObjectType) + else + Types.WasmRefNullType( + Types.WasmHeapType.Type(Names.WasmTypeName.WasmStructTypeName(className)) + ) + } + } + + private def transformPrimType( t: IRTypes.PrimTypeWithRef ): Types.WasmType = t match { diff --git a/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala b/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala index 47a3f4e9..2fd1963b 100644 --- a/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala +++ b/wasm/src/main/scala/ir2wasm/WasmExpressionBuilder.scala @@ -17,10 +17,11 @@ import wasm4s.WasmImmediate._ import org.scalajs.ir.Types.ClassType import org.scalajs.ir.ClassKind import org.scalajs.ir.Position +import _root_.wasm4s.Defaults object WasmExpressionBuilder { def generateIRBody(tree: IRTrees.Tree, resultType: IRTypes.Type)(implicit - ctx: FunctionTypeWriterWasmContext, + ctx: TypeDefinableWasmContext, fctx: WasmFunctionContext ): Unit = { val builder = new WasmExpressionBuilder(ctx, fctx) @@ -40,8 +41,8 @@ object WasmExpressionBuilder { private object PrimTypeWithBoxUnbox { def unapply(primType: IRTypes.PrimTypeWithRef): Option[IRTypes.PrimTypeWithRef] = { primType match { - case IRTypes.BooleanType | IRTypes.ByteType | IRTypes.ShortType | - IRTypes.IntType | IRTypes.FloatType | IRTypes.DoubleType => + case IRTypes.BooleanType | IRTypes.ByteType | IRTypes.ShortType | IRTypes.IntType | + IRTypes.FloatType | IRTypes.DoubleType => Some(primType) case _ => None @@ -51,7 +52,7 @@ object WasmExpressionBuilder { } private class WasmExpressionBuilder private ( - ctx: FunctionTypeWriterWasmContext, + ctx: TypeDefinableWasmContext, fctx: WasmFunctionContext ) { import WasmExpressionBuilder._ @@ -142,6 +143,11 @@ private class WasmExpressionBuilder private ( case t: IRTrees.JSTypeOfGlobalRef => genJSTypeOfGlobalRef(t) case t: IRTrees.JSLinkingInfo => genJSLinkingInfo(t) + // array + case t: IRTrees.ArrayLength => genArrayLength(t) + case t: IRTrees.NewArray => genNewArray(t) + case t: IRTrees.ArraySelect => genArraySelect(t) + case t: IRTrees.ArrayValue => genArrayValue(t) case _ => println(tree) ??? @@ -151,23 +157,18 @@ private class WasmExpressionBuilder private ( // case IRTrees.RecordValue(pos) => // case IRTrees.JSNewTarget(pos) => // case IRTrees.SelectStatic(tpe) => - // case IRTrees.ArrayLength(pos) => // case IRTrees.JSSuperMethodCall(pos) => - // case IRTrees.NewArray(pos) => // case IRTrees.Match(tpe) => - // case IRTrees.Throw(pos) => // case IRTrees.Closure(pos) => // case IRTrees.RecordSelect(tpe) => // case IRTrees.TryFinally(pos) => // case IRTrees.JSImportMeta(pos) => // case IRTrees.JSSuperSelect(pos) => - // case IRTrees.ArraySelect(tpe) => // case IRTrees.WrapAsThrowable(pos) => // case IRTrees.JSSuperConstructorCall(pos) => // case IRTrees.Clone(pos) => // case IRTrees.CreateJSClass(pos) => // case IRTrees.Transient(pos) => - // case IRTrees.ArrayValue(pos) => // case IRTrees.ForIn(pos) => // case tc: IRTrees.TryCatch => ??? // case IRTrees.JSImportCall(pos) => @@ -238,7 +239,18 @@ private class WasmExpressionBuilder private ( genTree(t.rhs, t.lhs.tpe) instrs += STRUCT_SET(TypeIdx(WasmStructTypeName(className)), idx) - case assign: IRTrees.ArraySelect => ??? // array.set + case sel: IRTrees.ArraySelect => + val typeName = sel.array.tpe match { + case arrTy: IRTypes.ArrayType => WasmArrayTypeName(arrTy.arrayTypeRef) + case _ => + throw new IllegalArgumentException( + s"ArraySelect.array must be an array type, but has type ${sel.array.tpe}" + ) + } + genTreeAuto(sel.array) + genTree(sel.index, IRTypes.IntType) + genTree(t.rhs, t.lhs.tpe) + instrs += ARRAY_SET(TypeIdx(typeName)) case assign: IRTrees.RecordSelect => ??? // struct.set case assign: IRTrees.JSPrivateSelect => ??? @@ -279,7 +291,9 @@ private class WasmExpressionBuilder private ( // statically resolved call with non-null argument val receiverClassName = IRTypes.PrimTypeToBoxedClass(prim) genApplyStatically( - IRTrees.ApplyStatically(t.flags, t.receiver, receiverClassName, t.method, t.args)(t.tpe)(t.pos) + IRTrees.ApplyStatically(t.flags, t.receiver, receiverClassName, t.method, t.args)(t.tpe)( + t.pos + ) ) case IRTypes.ClassType(className) if IRNames.HijackedClasses.contains(className) => @@ -297,9 +311,9 @@ private class WasmExpressionBuilder private ( implicit val pos: Position = t.pos val receiverClassName = t.receiver.tpe match { - case ClassType(className) => className - case IRTypes.AnyType => IRNames.ObjectClass - case _ => throw new Error(s"Invalid receiver type ${t.receiver.tpe}") + case ClassType(className) => className + case IRTypes.AnyType => IRNames.ObjectClass + case _ => throw new Error(s"Invalid receiver type ${t.receiver.tpe}") } val receiverClassInfo = ctx.getClassInfo(receiverClassName) @@ -461,21 +475,19 @@ private class WasmExpressionBuilder private ( } /** Generates a vtable- or itable-based dispatch. - * - * Before this code gen, the stack must contain the receiver and the args of - * the target method. In addition, the receiver must be available in the - * local `receiverLocalForDispatch`. The two occurrences of the receiver - * must have the type for dispatch. - * - * After this code gen, the stack contains the result. If the result type is - * `NothingType`, `genTableDispatch` leaves the stack in an arbitrary state. - * It is up to the caller to insert an `unreachable` instruction when - * appropriate. - */ + * + * Before this code gen, the stack must contain the receiver and the args of the target method. + * In addition, the receiver must be available in the local `receiverLocalForDispatch`. The two + * occurrences of the receiver must have the type for dispatch. + * + * After this code gen, the stack contains the result. If the result type is `NothingType`, + * `genTableDispatch` leaves the stack in an arbitrary state. It is up to the caller to insert an + * `unreachable` instruction when appropriate. + */ def genTableDispatch( - receiverClassInfo: WasmContext.WasmClassInfo, - methodName: IRNames.MethodName, - receiverLocalForDispatch: WasmLocalName + receiverClassInfo: WasmContext.WasmClassInfo, + methodName: IRNames.MethodName, + receiverLocalForDispatch: WasmLocalName ): Unit = { // Generates an itable-based dispatch. def genITableDispatch(): Unit = { @@ -992,7 +1004,9 @@ private class WasmExpressionBuilder private ( case IRTypes.NothingType => () // unreachable case IRTypes.NoType => - throw new AssertionError(s"Found expression of type void in String_+ at ${tree.pos}: $tree") + throw new AssertionError( + s"Found expression of type void in String_+ at ${tree.pos}: $tree" + ) } case IRTypes.ClassType(IRNames.BoxedStringClass) => @@ -1127,7 +1141,9 @@ private class WasmExpressionBuilder private ( case IRTypes.UndefType | IRTypes.StringType => () case PrimTypeWithBoxUnbox(primType) => - instrs += CALL(WasmImmediate.FuncIdx(WasmFunctionName.unboxOrNull(primType.primRef))) + instrs += CALL( + WasmImmediate.FuncIdx(WasmFunctionName.unboxOrNull(primType.primRef)) + ) case IRTypes.CharType => val structTypeName = WasmStructTypeName(SpecialNames.CharBoxClass) instrs += REF_CAST_NULL(HeapType(Types.WasmHeapType.Type(structTypeName))) @@ -1153,11 +1169,11 @@ private class WasmExpressionBuilder private ( } /** Unbox the `anyref` on the stack to the target `PrimType`. - * - * `targetTpe` must not be `NothingType`, `NullType` nor `NoType`. - * - * The type left on the stack is non-nullable. - */ + * + * `targetTpe` must not be `NothingType`, `NullType` nor `NoType`. + * + * The type left on the stack is non-nullable. + */ private def genUnbox(targetTpe: IRTypes.PrimType)(implicit pos: Position): Unit = { targetTpe match { case IRTypes.UndefType => @@ -1358,7 +1374,10 @@ private class WasmExpressionBuilder private ( } /** Codegen to box a primitive `char`/`long` into a `CharacterBox`/`LongBox`. */ - private def genBox(primType: IRTypes.PrimTypeWithRef, boxClassName: IRNames.ClassName): IRTypes.Type = { + private def genBox( + primType: IRTypes.PrimTypeWithRef, + boxClassName: IRNames.ClassName + ): IRTypes.Type = { // `primTyp` is `i32` for `char` (containing a `u16` value) or `i64` for `long`. val primTyp = TypeTransformer.transformType(primType)(ctx) val primLocal = WasmLocal(fctx.genSyntheticLocalName(), primTyp, isParameter = false) @@ -1455,14 +1474,16 @@ private class WasmExpressionBuilder private ( private def genSelectJSNativeMember(tree: IRTrees.SelectJSNativeMember): IRTypes.Type = { val info = ctx.getClassInfo(tree.className) - val jsNativeLoadSpec = info.jsNativeMembers.getOrElse(tree.member.name, { - throw new AssertionError(s"Found $tree for non-existing JS native member at ${tree.pos}") - }) + val jsNativeLoadSpec = info.jsNativeMembers.getOrElse( + tree.member.name, { + throw new AssertionError(s"Found $tree for non-existing JS native member at ${tree.pos}") + } + ) genLoadJSNativeLoadSpec(jsNativeLoadSpec)(tree.pos) } - private def genLoadJSNativeLoadSpec(loadSpec: IRTrees.JSNativeLoadSpec)( - implicit pos: Position + private def genLoadJSNativeLoadSpec(loadSpec: IRTrees.JSNativeLoadSpec)(implicit + pos: Position ): IRTypes.Type = { import IRTrees.JSNativeLoadSpec._ @@ -1578,4 +1599,57 @@ private class WasmExpressionBuilder private ( instrs += CALL(FuncIdx(WasmFunctionName.jsLinkingInfo)) IRTypes.AnyType } + + // =============================================================================== + // array + // =============================================================================== + private def genArrayLength(t: IRTrees.ArrayLength): IRTypes.Type = { + genTreeAuto(t.array) + instrs += ARRAY_LEN + IRTypes.IntType + } + + private def genNewArray(t: IRTrees.NewArray): IRTypes.Type = { + if (t.lengths.isEmpty || t.lengths.sizeIs > t.typeRef.dimensions) + throw new AssertionError( + s"invalid lengths ${t.lengths} for array type ${t.typeRef.displayName}" + ) + if (t.lengths.size == 1) { + val arrTy = ctx.getArrayType(t.typeRef) + val zero = Defaults.defaultValue(arrTy.field.typ) + ctx.inferTypeFromTypeRef(t.typeRef.base) + instrs += zero + genTree(t.lengths.head, IRTypes.IntType) + instrs += ARRAY_NEW(TypeIdx(arrTy.name)) + } else ??? // TODO support multi dimensional arrays + t.tpe + } + + /** For getting element from an array, array.set should be generated by transformation of + * `Assign(ArraySelect(...), ...)` + */ + private def genArraySelect(t: IRTrees.ArraySelect): IRTypes.Type = { + val irArrType = t.array.tpe match { + case t: IRTypes.ArrayType => t + case _ => + throw new IllegalArgumentException( + s"ArraySelect.array must be an array type, but has type ${t.array.tpe}" + ) + } + genTreeAuto(t.array) + genTree(t.index, IRTypes.IntType) + instrs += ARRAY_GET(TypeIdx(ctx.getArrayType(irArrType.arrayTypeRef).name)) + + val typeRef = irArrType.arrayTypeRef + if (typeRef.dimensions > 1) + IRTypes.ArrayType(typeRef.copy(dimensions = typeRef.dimensions - 1)) + else ctx.inferTypeFromTypeRef(typeRef.base) + } + + private def genArrayValue(t: IRTrees.ArrayValue): IRTypes.Type = { + val arrTy = ctx.getArrayType(t.typeRef) + t.elems.foreach(genTreeAuto) + instrs += ARRAY_NEW_FIXED(TypeIdx(arrTy.name), I32(t.elems.size)) + t.tpe + } } diff --git a/wasm/src/main/scala/wasm4s/Defaults.scala b/wasm/src/main/scala/wasm4s/Defaults.scala index 7104ab42..eecc7a77 100644 --- a/wasm/src/main/scala/wasm4s/Defaults.scala +++ b/wasm/src/main/scala/wasm4s/Defaults.scala @@ -6,8 +6,8 @@ import wasm.wasm4s.WasmInstr import wasm.wasm4s.WasmInstr._ object Defaults { - private def nonDefaultable(t: WasmType) = throw new Error(s"Non defaultable type: $t") - def defaultValue(t: WasmType): WasmInstr = t match { + private def nonDefaultable(t: WasmStorageType) = throw new Error(s"Non defaultable type: $t") + def defaultValue(t: WasmStorageType): WasmInstr = t match { case WasmUnreachableType => UNREACHABLE case WasmInt32 => I32_CONST(I32(0)) case WasmAnyRef => REF_NULL(HeapType(WasmHeapType.Simple.Any)) @@ -21,6 +21,8 @@ object Defaults { case WasmRefNullType(heapType) => REF_NULL(HeapType(heapType)) case WasmInt64 => I64_CONST(I64(0)) case WasmFloat64 => F64_CONST(F64(0)) - case WasmNoType => nonDefaultable(t) + case WasmInt16 => nonDefaultable(t) + case WasmInt8 => nonDefaultable(t) + case WasmNoType => nonDefaultable(t) } } diff --git a/wasm/src/main/scala/wasm4s/Names.scala b/wasm/src/main/scala/wasm4s/Names.scala index 42ce0585..06b845c2 100644 --- a/wasm/src/main/scala/wasm4s/Names.scala +++ b/wasm/src/main/scala/wasm4s/Names.scala @@ -250,6 +250,7 @@ object Names { val itable = new WasmFieldName("itable") val itables = new WasmFieldName("itables") val u16Array = new WasmFieldName("u16Array") + val arrayField = new WasmFieldName("array_field") // Fields of the typeData structs object typeData { @@ -336,11 +337,8 @@ object Names { final case class WasmArrayTypeName private (override private[wasm4s] val name: String) extends WasmTypeName(name) object WasmArrayTypeName { - def apply(ty: IRTypes.ArrayType) = { - val ref = ty.arrayTypeRef - // TODO: better naming? - new WasmArrayTypeName(s"${ref.base.displayName}_${ref.dimensions}") - } + def apply(typeRef: IRTypes.ArrayTypeRef) = + new WasmArrayTypeName(s"arrayOf_${typeRef.base.displayName}_${typeRef.dimensions}") val itables = new WasmArrayTypeName("itable") val u16Array = new WasmArrayTypeName("u16Array") } diff --git a/wasm/src/main/scala/wasm4s/Wasm.scala b/wasm/src/main/scala/wasm4s/Wasm.scala index cd20436f..3f869838 100644 --- a/wasm/src/main/scala/wasm4s/Wasm.scala +++ b/wasm/src/main/scala/wasm4s/Wasm.scala @@ -128,7 +128,7 @@ object WasmStructType { } case class WasmArrayType( - name: WasmTypeName, + name: WasmTypeName.WasmArrayTypeName, field: WasmStructField ) extends WasmGCTypeDefinition object WasmArrayType { @@ -176,6 +176,7 @@ object WasmElement { */ class WasmModule( private val _functionTypes: mutable.ListBuffer[WasmFunctionType] = new mutable.ListBuffer(), + private val _arrayTypes: mutable.Set[WasmArrayType] = new mutable.HashSet(), private val _recGroupTypes: mutable.ListBuffer[WasmStructType] = new mutable.ListBuffer(), // val importsInOrder: List[WasmNamedModuleField] = Nil, private val _imports: mutable.ListBuffer[WasmImport] = new mutable.ListBuffer(), @@ -197,6 +198,7 @@ class WasmModule( ) { def addImport(imprt: WasmImport): Unit = _imports.addOne(imprt) def addFunction(function: WasmFunction): Unit = _definedFunctions.addOne(function) + def addArrayType(typ: WasmArrayType): Unit = _arrayTypes.addOne(typ) def addFunctionType(typ: WasmFunctionType): Unit = _functionTypes.addOne(typ) def addRecGroupType(typ: WasmStructType): Unit = _recGroupTypes.addOne(typ) def addGlobal(typ: WasmGlobal): Unit = _globals.addOne(typ) @@ -206,7 +208,7 @@ class WasmModule( def functionTypes = _functionTypes.toList def recGroupTypes = WasmModule.tsort(_recGroupTypes.toList) - def arrayTypes = List(WasmArrayType.itables, WasmArrayType.u16Array) + def arrayTypes = List(WasmArrayType.itables, WasmArrayType.u16Array) ++ _arrayTypes.toList def imports = _imports.toList def definedFunctions = _definedFunctions.toList def globals = _globals.toList diff --git a/wasm/src/main/scala/wasm4s/WasmContext.scala b/wasm/src/main/scala/wasm4s/WasmContext.scala index 586e966a..4f47a150 100644 --- a/wasm/src/main/scala/wasm4s/WasmContext.scala +++ b/wasm/src/main/scala/wasm4s/WasmContext.scala @@ -108,11 +108,12 @@ trait ReadOnlyWasmContext { } } -trait FunctionTypeWriterWasmContext extends ReadOnlyWasmContext { this: WasmContext => +trait TypeDefinableWasmContext extends ReadOnlyWasmContext { this: WasmContext => protected val functionSignatures = LinkedHashMap.empty[WasmFunctionSignature, Int] protected val constantStringGlobals = LinkedHashMap.empty[String, WasmGlobalName] private var nextConstantStringIndex: Int = 1 + private var nextArrayTypeIndex: Int = 1 protected def addGlobal(g: WasmGlobal): Unit protected def addFuncDeclaration(name: WasmFunctionName): Unit @@ -163,9 +164,25 @@ trait FunctionTypeWriterWasmContext extends ReadOnlyWasmContext { this: WasmCont addFuncDeclaration(name) WasmInstr.REF_FUNC(WasmImmediate.FuncIdx(name)) } + + def getArrayType(typeRef: IRTypes.ArrayTypeRef): WasmArrayType = { + val elemTy = TypeTransformer.transformType(extractArrayElemType(typeRef))(this) + val arrTyName = Names.WasmTypeName.WasmArrayTypeName(typeRef) + val arrTy = WasmArrayType( + arrTyName, + WasmStructField(Names.WasmFieldName.arrayField, elemTy, isMutable = true) + ) + module.addArrayType(arrTy) + arrTy + } + + private def extractArrayElemType(typeRef: IRTypes.ArrayTypeRef): IRTypes.Type = { + if (typeRef.dimensions > 1) IRTypes.ArrayType(typeRef.copy(dimensions = typeRef.dimensions - 1)) + else inferTypeFromTypeRef(typeRef.base) + } } -class WasmContext(val module: WasmModule) extends FunctionTypeWriterWasmContext { +class WasmContext(val module: WasmModule) extends TypeDefinableWasmContext { import WasmContext._ private val _startInstructions: mutable.ListBuffer[WasmInstr] = new mutable.ListBuffer() @@ -419,7 +436,7 @@ object WasmContext { // flags: IRTrees.MemberFlags, isAbstract: Boolean ) { - def toWasmFunctionType()(implicit ctx: FunctionTypeWriterWasmContext): WasmFunctionType = + def toWasmFunctionType()(implicit ctx: TypeDefinableWasmContext): WasmFunctionType = TypeTransformer.transformFunctionType(this) }