Skip to content
This repository was archived by the owner on Jul 12, 2024. It is now read-only.

Commit 9a0d361

Browse files
committed
Implement GetClass.
When the value is one of our objects, we look at its `vtable` field, which is also the appropriate `typeData` reference. When it is a JS value, we implement the dispatch in a dedicated helper function.
1 parent 02494ac commit 9a0d361

File tree

11 files changed

+347
-19
lines changed

11 files changed

+347
-19
lines changed

Diff for: cli/src/main/scala/TestSuites.scala

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ object TestSuites {
99
TestSuite("testsuite.core.InterfaceCall"),
1010
TestSuite("testsuite.core.AsInstanceOfTest"),
1111
TestSuite("testsuite.core.ClassOfTest"),
12+
TestSuite("testsuite.core.GetClassTest"),
1213
TestSuite("testsuite.core.JSInteropTest"),
1314
TestSuite("testsuite.core.HijackedClassesDispatchTest"),
1415
TestSuite("testsuite.core.HijackedClassesMonoTest"),

Diff for: loader.mjs

+26
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,32 @@ const scalaJSHelpers = {
7777
stringConcat: (x, y) => ("" + x) + y, // the added "" is for the case where x === y === null
7878
isString: (x) => typeof x === 'string',
7979

80+
/* Get the type of JS value of `x` in a single JS helper call, for the purpose of dispatch.
81+
*
82+
* 0: false
83+
* 1: true
84+
* 2: string
85+
* 3: number
86+
* 4: undefined
87+
* 5: everything else
88+
*
89+
* This encoding has the following properties:
90+
*
91+
* - false and true also return their value as the appropriate i32.
92+
* - the types implementing `Comparable` are consecutive from 0 to 3.
93+
*/
94+
jsValueType: (x) => {
95+
if (typeof x === 'number')
96+
return 3;
97+
if (typeof x === 'string')
98+
return 2;
99+
if (typeof x === 'boolean')
100+
return x | 0;
101+
if (typeof x === 'undefined')
102+
return 4;
103+
return 5;
104+
},
105+
80106
// Hash code, because it is overridden in all hijacked classes
81107
// Specified by the hashCode() method of the corresponding hijacked classes
82108
jsValueHashCode: (x) => {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package testsuite.core
2+
3+
import testsuite.Assert.assertSame
4+
5+
object GetClassTest {
6+
def main(): Unit = {
7+
testNoHijackedDispatch()
8+
testWithHijackedDispath()
9+
}
10+
11+
class Foo
12+
13+
class Bar extends Foo
14+
15+
def testNoHijackedDispatch(): Unit = {
16+
def getClassOfFoo(x: Foo): Class[_] = x.getClass()
17+
18+
assertSame(classOf[Foo], getClassOfFoo(new Foo))
19+
assertSame(classOf[Bar], getClassOfFoo(new Bar))
20+
}
21+
22+
def testWithHijackedDispath(): Unit = {
23+
def getClassOf(x: Any): Class[_] = x.getClass()
24+
25+
assertSame(classOf[Foo], getClassOf(new Foo))
26+
assertSame(classOf[Bar], getClassOf(new Bar))
27+
28+
assertSame(classOf[java.lang.Boolean], getClassOf(true))
29+
assertSame(classOf[java.lang.Boolean], getClassOf(false))
30+
assertSame(classOf[java.lang.Void], getClassOf(()))
31+
assertSame(classOf[java.lang.String], getClassOf("foo"))
32+
33+
assertSame(classOf[java.lang.Byte], getClassOf(0.0))
34+
assertSame(classOf[java.lang.Byte], getClassOf(56))
35+
assertSame(classOf[java.lang.Byte], getClassOf(-128))
36+
assertSame(classOf[java.lang.Short], getClassOf(200))
37+
assertSame(classOf[java.lang.Short], getClassOf(-32000))
38+
assertSame(classOf[java.lang.Integer], getClassOf(500000))
39+
assertSame(classOf[java.lang.Integer], getClassOf(Int.MinValue))
40+
41+
assertSame(classOf[java.lang.Float], getClassOf(1.5))
42+
assertSame(classOf[java.lang.Double], getClassOf(1.4))
43+
assertSame(classOf[java.lang.Double], getClassOf(Float.MaxValue.toDouble * 8.0))
44+
45+
assertSame(classOf[java.lang.Float], getClassOf(-0.0))
46+
assertSame(classOf[java.lang.Float], getClassOf(Double.PositiveInfinity))
47+
assertSame(classOf[java.lang.Float], getClassOf(Double.NegativeInfinity))
48+
assertSame(classOf[java.lang.Float], getClassOf(Double.NaN))
49+
50+
assertSame(null, getClassOf(scala.scalajs.js.Math))
51+
}
52+
}

Diff for: wasm/src/main/scala/converters/WasmBinaryWriter.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ final class WasmBinaryWriter(module: WasmModule) {
278278

279279
case FuncIdx(value) => writeFuncIdx(buf, value)
280280
case labelIdx: LabelIdx => writeLabelIdx(buf, labelIdx)
281-
case LabelIdxVector(value) => ???
281+
case LabelIdxVector(value) => buf.vec(value)(writeLabelIdx(buf, _))
282282
case TypeIdx(value) => writeTypeIdx(buf, value)
283283
case TableIdx(value) => ???
284284
case TagIdx(value) => ???

Diff for: wasm/src/main/scala/converters/WasmTextWriter.scala

+12-2
Original file line numberDiff line numberDiff line change
@@ -223,11 +223,19 @@ class WasmTextWriter {
223223
}
224224

225225
private def writeImmediate(i: WasmImmediate, instr: WasmInstr)(implicit b: WatBuilder): Unit = {
226+
def floatString(v: Double): String = {
227+
if (v.isNaN()) "nan"
228+
else if (v == Double.PositiveInfinity) "inf"
229+
else if (v == Double.NegativeInfinity) "-inf"
230+
else if (v.equals(-0.0)) "-0.0"
231+
else v.toString()
232+
}
233+
226234
val str = i match {
227235
case WasmImmediate.I64(v) => v.toString()
228236
case WasmImmediate.I32(v) => v.toString()
229-
case WasmImmediate.F64(v) => v.toString()
230-
case WasmImmediate.F32(v) => v.toString()
237+
case WasmImmediate.F64(v) => floatString(v)
238+
case WasmImmediate.F32(v) => floatString(v.toDouble)
231239
case WasmImmediate.LocalIdx(name) => name.show
232240
case WasmImmediate.GlobalIdx(name) => name.show
233241
case WasmImmediate.HeapType(ht) =>
@@ -245,6 +253,8 @@ class WasmTextWriter {
245253
case WasmImmediate.BlockType.ValueType(optTy) =>
246254
optTy.fold("") { ty => s"(result ${ty.show})" }
247255
case WasmImmediate.LabelIdx(i) => s"$$${i.toString}" // `loop 0` seems to be invalid
256+
case WasmImmediate.LabelIdxVector(indices) =>
257+
indices.map(i => "$" + i.value).mkString(" ")
248258
case i: WasmImmediate.CastFlags =>
249259
throw new UnsupportedOperationException(s"CastFlags $i must be handled directly in the instruction $instr")
250260
case _ =>

Diff for: wasm/src/main/scala/ir2wasm/HelperFunctions.scala

+214-13
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@ object HelperFunctions {
1818
genCreateStringFromData()
1919
genTypeDataName()
2020
genCreateClassOf()
21+
genGetClassOf()
2122
genArrayTypeData()
2223
genGetComponentType()
24+
genAnyGetClass()
2325
}
2426

2527
/** `createStringFromData: (ref array u16) -> (ref any)` (representing a `string`). */
@@ -338,6 +340,43 @@ object HelperFunctions {
338340
fctx.buildAndAddToContext()
339341
}
340342

343+
/** `getClassOf: (ref typeData) -> (ref jlClass)`.
344+
*
345+
* Initializes the `java.lang.Class` instance associated with the given
346+
* `typeData` if not already done, and returns it.
347+
*
348+
* This includes the fast-path and the slow-path to `createClassOf`, for
349+
* call sites that are not performance-sensitive.
350+
*/
351+
private def genGetClassOf()(implicit ctx: WasmContext): Unit = {
352+
import WasmImmediate._
353+
import WasmTypeName.WasmStructTypeName
354+
355+
val typeDataType = WasmRefType(WasmHeapType.Type(WasmStructType.typeData.name))
356+
357+
val fctx = WasmFunctionContext(
358+
WasmFunctionName.getClassOf,
359+
List("typeData" -> typeDataType),
360+
List(WasmRefType(WasmHeapType.ClassType))
361+
)
362+
363+
val List(typeDataParam) = fctx.paramIndices
364+
365+
import fctx.instrs
366+
367+
fctx.block(WasmRefType(WasmHeapType.ClassType)) { alreadyInitializedLabel =>
368+
// fast path
369+
instrs += LOCAL_GET(typeDataParam)
370+
instrs += STRUCT_GET(TypeIdx(WasmStructTypeName.typeData), WasmFieldName.typeData.classOfIdx)
371+
instrs += BR_ON_NON_NULL(alreadyInitializedLabel)
372+
// slow path
373+
instrs += LOCAL_GET(typeDataParam)
374+
instrs += CALL(FuncIdx(WasmFunctionName.createClassOf))
375+
} // end bock alreadyInitializedLabel
376+
377+
fctx.buildAndAddToContext()
378+
}
379+
341380
/** `arrayTypeData: (ref typeData), i32 -> (ref typeData)`.
342381
*
343382
* Returns the typeData of an array with `dims` dimensions over the given
@@ -434,24 +473,186 @@ object HelperFunctions {
434473
val componentTypeDataLocal = fctx.addLocal("componentTypeData", typeDataType)
435474

436475
fctx.block() { nullResultLabel =>
437-
fctx.block(Types.WasmRefType(Types.WasmHeapType.ClassType)) { nonNullClassOfLabel =>
438-
// Try and extract non-null component type data
439-
instrs += LOCAL_GET(typeDataParam)
440-
instrs += STRUCT_GET(TypeIdx(WasmStructTypeName.typeData), WasmFieldName.typeData.componentTypeIdx)
441-
instrs += BR_ON_NULL(nullResultLabel)
442-
// fast path
443-
instrs += LOCAL_TEE(componentTypeDataLocal)
444-
instrs += STRUCT_GET(TypeIdx(WasmStructTypeName.typeData), WasmFieldName.typeData.classOfIdx)
445-
instrs += BR_ON_NON_NULL(nonNullClassOfLabel)
446-
// slow path
447-
instrs += LOCAL_GET(componentTypeDataLocal)
448-
instrs += CALL(FuncIdx(WasmFunctionName.createClassOf))
449-
} // end bock nonNullClassOfLabel
476+
// Try and extract non-null component type data
477+
instrs += LOCAL_GET(typeDataParam)
478+
instrs += STRUCT_GET(TypeIdx(WasmStructTypeName.typeData), WasmFieldName.typeData.componentTypeIdx)
479+
instrs += BR_ON_NULL(nullResultLabel)
480+
// Get the corresponding classOf
481+
instrs += CALL(FuncIdx(WasmFunctionName.getClassOf))
450482
instrs += RETURN
451483
} // end block nullResultLabel
452484
instrs += REF_NULL(HeapType(WasmHeapType.ClassType))
453485

454486
fctx.buildAndAddToContext()
455487
}
456488

489+
/** `anyGetClass: (ref any) -> (ref null jlClass)`.
490+
*
491+
* This is the implementation of `value.getClass()` when `value` can be an
492+
* instance of a hijacked class, i.e., a primitive.
493+
*
494+
* For `number`s, the result is based on the actual value, as specified by
495+
* [[https://www.scala-js.org/doc/semantics.html#getclass]].
496+
*/
497+
private def genAnyGetClass()(implicit ctx: WasmContext): Unit = {
498+
import WasmImmediate._
499+
import WasmTypeName.WasmStructTypeName
500+
501+
val typeDataType = WasmRefType(WasmHeapType.Type(WasmStructType.typeData.name))
502+
503+
val fctx = WasmFunctionContext(
504+
WasmFunctionName.anyGetClass,
505+
List("value" -> WasmRefType.any),
506+
List(WasmRefNullType(WasmHeapType.ClassType))
507+
)
508+
509+
val List(valueParam) = fctx.paramIndices
510+
511+
import fctx.instrs
512+
513+
val objectTypeIdx = TypeIdx(WasmStructTypeName(IRNames.ObjectClass))
514+
val typeDataLocal = fctx.addLocal("typeData", typeDataType)
515+
val doubleValueLocal = fctx.addLocal("doubleValue", WasmFloat64)
516+
val intValueLocal = fctx.addLocal("intValue", WasmInt32)
517+
518+
def getHijackedClassTypeDataInstr(className: IRNames.ClassName): WasmInstr =
519+
GLOBAL_GET(GlobalIdx(WasmGlobalName.WasmGlobalVTableName(IRTypes.ClassRef(className))))
520+
521+
fctx.block(WasmRefNullType(WasmHeapType.ClassType)) { nonNullClassOfLabel =>
522+
fctx.block(typeDataType) { gotTypeDataLabel =>
523+
fctx.block(WasmRefType(WasmHeapType.ObjectType)) { ourObjectLabel =>
524+
// if value is our object, jump to $ourObject
525+
instrs += LOCAL_GET(valueParam)
526+
instrs += BR_ON_CAST(
527+
CastFlags(false, false),
528+
ourObjectLabel,
529+
WasmImmediate.HeapType(WasmHeapType.Simple.Any),
530+
WasmImmediate.HeapType(WasmHeapType.ObjectType)
531+
)
532+
533+
// switch(jsValueType(value)) { ... }
534+
fctx.block() { typeOtherLabel =>
535+
fctx.block() { typeUndefinedLabel =>
536+
fctx.block() { typeNumberLabel =>
537+
fctx.block() { typeStringLabel =>
538+
fctx.block() { typeBooleanLabel =>
539+
instrs += LOCAL_GET(valueParam)
540+
instrs += CALL(FuncIdx(WasmFunctionName.jsValueType))
541+
instrs += BR_TABLE(LabelIdxVector(List(
542+
typeBooleanLabel, // 0
543+
typeBooleanLabel, // 1
544+
typeStringLabel, // 2
545+
typeNumberLabel, // 3
546+
typeUndefinedLabel, // 4
547+
)), typeOtherLabel)
548+
}
549+
550+
// typeBoolean:
551+
instrs += getHijackedClassTypeDataInstr(IRNames.BoxedBooleanClass)
552+
instrs += BR(gotTypeDataLabel)
553+
}
554+
555+
// typeString:
556+
instrs += getHijackedClassTypeDataInstr(IRNames.BoxedStringClass)
557+
instrs += BR(gotTypeDataLabel)
558+
}
559+
560+
/* typeNumber:
561+
* For `number`s, the result is based on the actual value, as specified by
562+
* [[https://www.scala-js.org/doc/semantics.html#getclass]].
563+
*/
564+
565+
// doubleValue := unboxDouble(value)
566+
instrs += LOCAL_GET(valueParam)
567+
instrs += CALL(FuncIdx(WasmFunctionName.unbox(IRTypes.DoubleRef)))
568+
instrs += LOCAL_TEE(doubleValueLocal)
569+
570+
// intValue := doubleValue.toInt
571+
instrs += I32_TRUNC_SAT_F64_S
572+
instrs += LOCAL_TEE(intValueLocal)
573+
574+
// if same(intValue.toDouble, doubleValue) -- same bit pattern to avoid +0.0 == -0.0
575+
instrs += F64_CONVERT_I32_S
576+
instrs += I64_REINTERPRET_F64
577+
instrs += LOCAL_GET(doubleValueLocal)
578+
instrs += I64_REINTERPRET_F64
579+
instrs += I64_EQ
580+
fctx.ifThenElse() {
581+
// then it is a Byte, a Short, or an Integer
582+
583+
// if intValue.toByte.toInt == intValue
584+
instrs += LOCAL_GET(intValueLocal)
585+
instrs += I32_EXTEND8_S
586+
instrs += LOCAL_GET(intValueLocal)
587+
instrs += I32_EQ
588+
fctx.ifThenElse() {
589+
// then it is a Byte
590+
instrs += getHijackedClassTypeDataInstr(IRNames.BoxedByteClass)
591+
instrs += BR(gotTypeDataLabel)
592+
} {
593+
// else, if intValue.toShort.toInt == intValue
594+
instrs += LOCAL_GET(intValueLocal)
595+
instrs += I32_EXTEND16_S
596+
instrs += LOCAL_GET(intValueLocal)
597+
instrs += I32_EQ
598+
fctx.ifThenElse() {
599+
// then it is a Short
600+
instrs += getHijackedClassTypeDataInstr(IRNames.BoxedShortClass)
601+
instrs += BR(gotTypeDataLabel)
602+
} {
603+
// else, it is an Integer
604+
instrs += getHijackedClassTypeDataInstr(IRNames.BoxedIntegerClass)
605+
instrs += BR(gotTypeDataLabel)
606+
}
607+
}
608+
} {
609+
// else, it is a Float or a Double
610+
611+
// if doubleValue.toFloat.toDouble == doubleValue
612+
instrs += LOCAL_GET(doubleValueLocal)
613+
instrs += F32_DEMOTE_F64
614+
instrs += F64_PROMOTE_F32
615+
instrs += LOCAL_GET(doubleValueLocal)
616+
instrs += F64_EQ
617+
fctx.ifThenElse() {
618+
// then it is a Float
619+
instrs += getHijackedClassTypeDataInstr(IRNames.BoxedFloatClass)
620+
instrs += BR(gotTypeDataLabel)
621+
} {
622+
// else, if it is NaN
623+
instrs += LOCAL_GET(doubleValueLocal)
624+
instrs += LOCAL_GET(doubleValueLocal)
625+
instrs += F64_NE
626+
fctx.ifThenElse() {
627+
// then it is a Float
628+
instrs += getHijackedClassTypeDataInstr(IRNames.BoxedFloatClass)
629+
instrs += BR(gotTypeDataLabel)
630+
} {
631+
// else, it is a Double
632+
instrs += getHijackedClassTypeDataInstr(IRNames.BoxedDoubleClass)
633+
instrs += BR(gotTypeDataLabel)
634+
}
635+
}
636+
}
637+
}
638+
639+
// typeUndefined:
640+
instrs += getHijackedClassTypeDataInstr(IRNames.BoxedUnitClass)
641+
instrs += BR(gotTypeDataLabel)
642+
}
643+
644+
// typeOther:
645+
instrs += REF_NULL(HeapType(WasmHeapType.ClassType))
646+
instrs += RETURN
647+
}
648+
649+
instrs += STRUCT_GET(objectTypeIdx, StructFieldIdx(0))
650+
}
651+
652+
instrs += CALL(FuncIdx(WasmFunctionName.getClassOf))
653+
}
654+
655+
fctx.buildAndAddToContext()
656+
}
657+
457658
}

Diff for: wasm/src/main/scala/ir2wasm/WasmBuilder.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ class WasmBuilder {
192192
// Declare the struct type for the class
193193
val vtableField = WasmStructField(
194194
Names.WasmFieldName.vtable,
195-
WasmRefNullType(WasmHeapType.Type(vtableType.name)),
195+
WasmRefType(WasmHeapType.Type(vtableType.name)),
196196
isMutable = false
197197
)
198198
val fields = clazz.fields.map(transformField)

0 commit comments

Comments
 (0)