diff --git a/compiler/src/dotty/tools/dotc/transform/init/Objects.scala b/compiler/src/dotty/tools/dotc/transform/init/Objects.scala index 226e8dbf3fb4..e5baeffee70f 100644 --- a/compiler/src/dotty/tools/dotc/transform/init/Objects.scala +++ b/compiler/src/dotty/tools/dotc/transform/init/Objects.scala @@ -18,6 +18,7 @@ import util.{ SourcePosition, NoSourcePosition } import config.Printers.init as printer import reporting.StoreReporter import reporting.trace as log +import reporting.trace.force as forcelog import typer.Applications.* import Errors.* @@ -118,61 +119,83 @@ class Objects(using Context @constructorOnly): * regions ::= List(sourcePosition) */ - sealed abstract class Value: + sealed trait Value: def show(using Context): String - /** ValueElement are elements that can be contained in a RefSet */ - sealed abstract class ValueElement extends Value + /** ValueElement are elements that can be contained in a ValueSet */ + sealed trait ValueElement extends Value /** * A reference caches the values for outers and immutable fields. */ - sealed abstract class Ref( - valsMap: mutable.Map[Symbol, Value], - varsMap: mutable.Map[Symbol, Heap.Addr], - outersMap: mutable.Map[ClassSymbol, Value]) - extends ValueElement: - protected val vals: mutable.Map[Symbol, Value] = valsMap - protected val vars: mutable.Map[Symbol, Heap.Addr] = varsMap - protected val outers: mutable.Map[ClassSymbol, Value] = outersMap + sealed abstract class Scope(using trace: Trace): + // protected val vals: mutable.Map[Symbol, Value] = valsMap + // protected val vars: mutable.Map[Symbol, Heap.Addr] = varsMap + // protected val outers: mutable.Map[ClassSymbol, Value] = outersMap def isObjectRef: Boolean = this.isInstanceOf[ObjectRef] - def klass: ClassSymbol + def getTrace: Trace = trace + + def isRef = this.isInstanceOf[Ref] + + def isEnv = this.isInstanceOf[Env.Data] + + def meth: Symbol + + def owner: ClassSymbol + + def level: Int + + def show(using Context): String + + def valValue(sym: Symbol)(using Heap.MutableData): Value = Heap.readVal(this, sym) - def valValue(sym: Symbol): Value = vals(sym) + def varValue(sym: Symbol)(using Heap.MutableData): Value = Heap.readVal(this, sym) - def varAddr(sym: Symbol): Heap.Addr = vars(sym) + // def varAddr(sym: Symbol): Heap.Addr = vars(sym) - def outerValue(cls: ClassSymbol): Value = outers(cls) + def outerValue(sym: Symbol)(using Heap.MutableData): ScopeSet = Heap.readOuter(this, sym) - def hasVal(sym: Symbol): Boolean = vals.contains(sym) + def outer(using Heap.MutableData): ScopeSet = this.outerValue(meth) - def hasVar(sym: Symbol): Boolean = vars.contains(sym) + def hasVal(sym: Symbol)(using Heap.MutableData): Boolean = Heap.containsVal(this, sym) - def hasOuter(cls: ClassSymbol): Boolean = outers.contains(cls) + def hasVar(sym: Symbol)(using Heap.MutableData): Boolean = Heap.containsVal(this, sym) - def initVal(field: Symbol, value: Value)(using Context) = log("Initialize " + field.show + " = " + value + " for " + this, printer) { + def hasOuter(cls: ClassSymbol)(using Heap.MutableData): Boolean = Heap.containsOuter(this, cls) + + def initVal(field: Symbol, value: Value)(using Context, Heap.MutableData) = log("Initialize " + field.show + " = " + value + " for " + this, printer) { assert(!field.is(Flags.Mutable), "Field is mutable: " + field.show) - assert(!vals.contains(field), "Field already set: " + field.show) - vals(field) = value + Heap.writeJoinVal(this, field, value) } - def initVar(field: Symbol, addr: Heap.Addr)(using Context) = log("Initialize " + field.show + " = " + addr + " for " + this, printer) { + def initVar(field: Symbol, value: Value)(using Context, Heap.MutableData) = log("Initialize " + field.show + " = " + value + " for " + this, printer) { assert(field.is(Flags.Mutable), "Field is not mutable: " + field.show) - assert(!vars.contains(field), "Field already set: " + field.show) - vars(field) = addr + Heap.writeJoinVal(this, field, value) } - def initOuter(cls: ClassSymbol, value: Value)(using Context) = log("Initialize outer " + cls.show + " = " + value + " for " + this, printer) { - assert(!outers.contains(cls), "Outer already set: " + cls) - outers(cls) = value + def initOuter(sym: Symbol, outerScope: Scope)(using Context, Heap.MutableData) = log("Initialize outer " + sym.show + " = " + outerScope + " for " + this, printer) { + Heap.writeJoinOuter(this, sym, ScopeSet(Set(outerScope))) } + def initOuterSet(sym: Symbol, outerScopeSet: ScopeSet)(using Context, Heap.MutableData) = + Heap.writeJoinOuter(this, sym, outerScopeSet) + + sealed abstract class Ref(using Trace) extends Scope with ValueElement: + def klass: ClassSymbol + /** A reference to a static object */ - case class ObjectRef(klass: ClassSymbol) - extends Ref(valsMap = mutable.Map.empty, varsMap = mutable.Map.empty, outersMap = mutable.Map.empty): - val owner = klass + case class ObjectRef( + klass: ClassSymbol + )(using @constructorOnly context: Context, @constructorOnly heap: Heap.MutableData, trace: Trace) extends Ref: + initOuter(klass, Env.NoEnv) + + def meth = klass + + def owner = klass + + def level = 1 def show(using Context) = "ObjectRef(" + klass.show + ")" @@ -182,25 +205,21 @@ class Objects(using Context @constructorOnly): * Note that the 2nd parameter block does not take part in the definition of equality. */ case class OfClass private ( - klass: ClassSymbol, outer: Value, ctor: Symbol, args: List[Value], env: Env.Data)( - valsMap: mutable.Map[Symbol, Value], varsMap: mutable.Map[Symbol, Heap.Addr], outersMap: mutable.Map[ClassSymbol, Value]) - extends Ref(valsMap, varsMap, outersMap): - def widenedCopy(outer: Value, args: List[Value], env: Env.Data): OfClass = - new OfClass(klass, outer, ctor, args, env)(this.valsMap, this.varsMap, this.outersMap) + klass: ClassSymbol, owner: ClassSymbol, ctor: Symbol, level: Int)(using Regions.Data, Trace) + extends Ref: + def meth = ctor def show(using Context) = - val valFields = vals.map(_.show + " -> " + _.show) - "OfClass(" + klass.show + ", outer = " + outer + ", args = " + args.map(_.show) + " env = " + env.show + ", vals = " + valFields + ")" + "OfClass(" + klass.show + ", owner = " + owner + ")" object OfClass: def apply( - klass: ClassSymbol, outer: Value, ctor: Symbol, args: List[Value], env: Env.Data)( - using Context + klass: ClassSymbol, outerScope: Env.Data, ctor: Symbol)( + using Context, Heap.MutableData, State.Data, Regions.Data, Trace ): OfClass = - val instance = new OfClass(klass, outer, ctor, args, env)( - valsMap = mutable.Map.empty, varsMap = mutable.Map.empty, outersMap = mutable.Map.empty - ) - instance.initOuter(klass, outer) + val owner = State.currentObject + val instance = new OfClass(klass, owner, ctor, outerScope.level + 1) + instance.initOuter(ctor, outerScope) instance /** @@ -215,9 +234,14 @@ class Objects(using Context @constructorOnly): * * @param owner The static object whose initialization creates the array. */ - case class OfArray(owner: ClassSymbol, regions: Regions.Data)(using @constructorOnly ctx: Context, @constructorOnly trace: Trace) extends ValueElement: - val klass: ClassSymbol = defn.ArrayClass - val addr: Heap.Addr = Heap.arrayAddr(regions, owner) + case class OfArray(owner: ClassSymbol, regions: Regions.Data)(using Context, Trace, Heap.MutableData) extends Ref: + def meth = defn.ArrayClass + def klass: ClassSymbol = defn.ArrayClass + + initOuter(klass, Env.NoEnv) + + def level = 1 + def outerScope = Env.NoEnv def show(using Context) = "OfArray(owner = " + owner.show + ")" /** @@ -225,7 +249,7 @@ class Objects(using Context @constructorOnly): * @param klass The enclosing class of the anonymous function's creation site */ case class Fun(code: Tree, thisV: ThisValue, klass: ClassSymbol, env: Env.Data) extends ValueElement: - def show(using Context) = "Fun(" + code.show + ", " + thisV.show + ", " + klass.show + ")" + def show(using Context) = "Fun(" + code.show + ", " + env.show + ", " + klass.show + ")" /** * Represents common base values like Int, String, etc. @@ -272,6 +296,11 @@ class Objects(using Context @constructorOnly): case class ValueSet(values: Set[ValueElement]) extends Value: def show(using Context) = values.map(_.show).mkString("[", ",", "]") + case class ScopeSet(scopes: Set[Scope]): + assert(!scopes.isEmpty, "Empty scope?") + assert(scopes.forall(_.isRef) || scopes.forall(_.isEnv), "All scopes should have the same type!") + def show(using Context) = scopes.map(_.show).mkString("[", ",", "]") + case class Package(packageModuleClass: ClassSymbol) extends Value: def show(using Context): String = "Package(" + packageModuleClass.show + ")" @@ -302,7 +331,7 @@ class Objects(using Context @constructorOnly): val Bottom = ValueSet(ListSet.empty) /** Possible types for 'this' */ - type ThisValue = Ref | Top.type + type ThisValue = Ref | Top.type | ValueSet /** Checking state */ object State: @@ -357,7 +386,7 @@ class Objects(using Context @constructorOnly): end doCheckObject def checkObjectAccess(clazz: ClassSymbol)(using data: Data, ctx: Context, pendingTrace: Trace): ObjectRef = - val index = data.checkingObjects.indexOf(ObjectRef(clazz)) + val index = data.checkingObjects.indexWhere(_.klass == clazz) if index != -1 then if data.checkingObjects.size - 1 > index then @@ -384,63 +413,25 @@ class Objects(using Context @constructorOnly): /** Environment for parameters */ object Env: - abstract class Data: - private[Env] def getVal(x: Symbol)(using Context): Option[Value] - private[Env] def getVar(x: Symbol)(using Context): Option[Heap.Addr] - - def widen(height: Int)(using Context): Data - + abstract class Data(using Trace) extends Scope: def level: Int - def show(using Context): String - /** Local environments can be deeply nested, therefore we need `outer`. * * For local variables in rhs of class field definitions, the `meth` is the primary constructor. */ - private case class LocalEnv - (private[Env] val params: Map[Symbol, Value], meth: Symbol, outer: Data) - (valsMap: mutable.Map[Symbol, Value], varsMap: mutable.Map[Symbol, Heap.Addr]) - (using Context) - extends Data: - val level = outer.level + 1 - + private case class LocalEnv(meth: Symbol, owner: ClassSymbol, level: Int)(using Trace) extends Data: if (level > 3) report.warning("[Internal error] Deeply nested environment, level = " + level + ", " + meth.show + " in " + meth.enclosingClass.show, meth.defTree) - private[Env] val vals: mutable.Map[Symbol, Value] = valsMap - private[Env] val vars: mutable.Map[Symbol, Heap.Addr] = varsMap - - private[Env] def getVal(x: Symbol)(using Context): Option[Value] = - if x.is(Flags.Param) then params.get(x) - else vals.get(x) - - private[Env] def getVar(x: Symbol)(using Context): Option[Heap.Addr] = - vars.get(x) - - def widen(height: Int)(using Context): Data = - new LocalEnv(params.map(_ -> _.widen(height)), meth, outer.widen(height))(this.vals, this.vars) - def show(using Context) = - "owner: " + meth.show + "\n" + - "params: " + params.map(_.show + " ->" + _.show).mkString("{", ", ", "}") + "\n" + - "vals: " + vals.map(_.show + " ->" + _.show).mkString("{", ", ", "}") + "\n" + - "vars: " + vars.map(_.show + " ->" + _).mkString("{", ", ", "}") + "\n" + - "outer = {\n" + outer.show + "\n}" - + "meth: " + meth.show + "\n" + + "owner: " + owner.show end LocalEnv - object NoEnv extends Data: + object NoEnv extends Data(using Trace.empty): val level = 0 - private[Env] def getVal(x: Symbol)(using Context): Option[Value] = - throw new RuntimeException("Invalid usage of non-existent env") - - private[Env] def getVar(x: Symbol)(using Context): Option[Heap.Addr] = - throw new RuntimeException("Invalid usage of non-existent env") - - def widen(height: Int)(using Context): Data = this - def show(using Context): String = "NoEnv" end NoEnv @@ -449,56 +440,66 @@ class Objects(using Context @constructorOnly): * The owner for the local environment for field initializers is the primary constructor of the * enclosing class. */ - def emptyEnv(meth: Symbol)(using Context): Data = - new LocalEnv(Map.empty, meth, NoEnv)(valsMap = mutable.Map.empty, varsMap = mutable.Map.empty) + def emptyEnv(meth: Symbol)(using Context, State.Data, Heap.MutableData, Trace): Data = + _of(Map.empty, meth, NoEnv) - def valValue(x: Symbol)(using data: Data, ctx: Context, trace: Trace): Value = - data.getVal(x) match - case Some(theValue) => - theValue - case _ => + def valValue(x: Symbol)(using data: Data, ctx: Context, trace: Trace, heap: Heap.MutableData): Value = + if data.hasVal(x) then + data.valValue(x) + else report.warning("[Internal error] Value not found " + x.show + "\nenv = " + data.show + ". " + Trace.show, Trace.position) Bottom - def getVal(x: Symbol)(using data: Data, ctx: Context): Option[Value] = data.getVal(x) + def getVal(x: Symbol)(using data: Data, ctx: Context, heap: Heap.MutableData): Option[Value] = + if data.hasVal(x) then + Some(data.valValue(x)) + else + None - def getVar(x: Symbol)(using data: Data, ctx: Context): Option[Heap.Addr] = data.getVar(x) + def getVar(x: Symbol)(using data: Data, ctx: Context, heap: Heap.MutableData): Option[Value] = + if data.hasVar(x) then + Some(data.varValue(x)) + else + None - private[Env] def _of(argMap: Map[Symbol, Value], meth: Symbol, outer: Data): Data = - new LocalEnv(argMap, meth, outer)(valsMap = mutable.Map.empty, varsMap = mutable.Map.empty) + private[Env] def _of(argMap: Map[Symbol, Value], meth: Symbol, outer: Data) + (using State.Data, Heap.MutableData, Trace): Data = + val env = LocalEnv(meth, State.currentObject, outer.level + 1) + argMap.foreach(env.initVal(_, _)) + env.initOuter(meth, outer) + env - def ofDefDef(ddef: DefDef, args: List[Value], outer: Data)(using Context): Data = + def ofDefDef(ddef: DefDef, args: List[Value], outer: Data) + (using State.Data, Heap.MutableData, Trace): Data = val params = ddef.termParamss.flatten.map(_.symbol) assert(args.size == params.size, "arguments = " + args.size + ", params = " + params.size) assert(ddef.symbol.owner.isClass ^ (outer != NoEnv), "ddef.owner = " + ddef.symbol.owner.show + ", outer = " + outer + ", " + ddef.source) _of(params.zip(args).toMap, ddef.symbol, outer) - def ofByName(byNameParam: Symbol, outer: Data): Data = + def ofByName(byNameParam: Symbol, outer: Data)(using State.Data, Heap.MutableData, Trace): Data = assert(byNameParam.is(Flags.Param) && byNameParam.info.isInstanceOf[ExprType]); _of(Map.empty, byNameParam, outer) - def setLocalVal(x: Symbol, value: Value)(using data: Data, ctx: Context): Unit = + def setLocalVal(x: Symbol, value: Value)(using data: Data, ctx: Context, heap: Heap.MutableData): Unit = assert(!x.isOneOf(Flags.Param | Flags.Mutable), "Only local immutable variable allowed") data match case localEnv: LocalEnv => - assert(!localEnv.vals.contains(x), "Already initialized local " + x.show) - localEnv.vals(x) = value + localEnv.initVal(x, value) case _ => throw new RuntimeException("Incorrect local environment for initializing " + x.show) - def setLocalVar(x: Symbol, addr: Heap.Addr)(using data: Data, ctx: Context): Unit = + def setLocalVar(x: Symbol, value: Value)(using data: Data, ctx: Context, heap: Heap.MutableData): Unit = assert(x.is(Flags.Mutable, butNot = Flags.Param), "Only local mutable variable allowed") data match case localEnv: LocalEnv => - assert(!localEnv.vars.contains(x), "Already initialized local " + x.show) - localEnv.vars(x) = addr + localEnv.initVar(x, value) case _ => throw new RuntimeException("Incorrect local environment for initializing " + x.show) /** * Resolve the environment by searching for a given symbol. * - * Searches for the environment that owns `target`, starting from `env` as the innermost. + * Searches for the environment that defines `target`, starting from `env` as the innermost. * * Due to widening, the corresponding environment might not exist. As a result reading the local * variable will return `Cold` and it's forbidden to write to the local variable. @@ -509,27 +510,33 @@ class Objects(using Context @constructorOnly): * * @return the environment that owns the `target` and value for `this` owned by the given method. */ - def resolveEnvByValue(target: Symbol, thisV: ThisValue, env: Data)(using Context): Option[(ThisValue, Data)] = log("Resolving env by value for " + target.show + ", this = " + thisV.show + ", env = " + env.show, printer) { - env match - case localEnv: LocalEnv => - if localEnv.getVal(target).isDefined then Some(thisV -> localEnv) - else if localEnv.getVar(target).isDefined then Some(thisV -> localEnv) - else resolveEnvByValue(target, thisV, localEnv.outer) - case NoEnv => - thisV match - case ref: OfClass => - ref.outer match - case outer : ThisValue => - resolveEnvByValue(target, outer, ref.env) - case _ => - // TODO: properly handle the case where ref.outer is ValueSet - None - case _ => + def resolveEnvByValue(target: Symbol, thisV: ThisValue, env: Data) + (using Context, Heap.MutableData): Contextual[Option[(ThisValue, ScopeSet)]] = log("Resolving env by value for " + target.show + ", this = " + thisV.show + ", env = " + env.show, printer) { + def recur(thisV: ThisValue, scopeSet: ScopeSet): Option[(ThisValue, ScopeSet)] = + val head = scopeSet.scopes.head + if head.level == 0 then // all scopes are NoEnv None + else + val filter = scopeSet.scopes.filter(_.hasVal(target)) + assert(filter.isEmpty || filter.size == scopeSet.scopes.size, "Either all scopes or no scopes contain " + target) + if (!filter.isEmpty) then + Some(thisV, ScopeSet(filter)) + else if head.isRef then + val currentClass = head.asInstanceOf[Ref].klass + val outerClass = currentClass.owner.lexicallyEnclosingClass.asClass + val outerThis = resolveThis(outerClass, thisV, currentClass) + val outerScopes = scopeSet.scopes.map(_.outer).join + recur(outerThis, outerScopes) + else + val outerScopes = scopeSet.scopes.map(_.outer).join + recur(thisV, outerScopes) + end recur + + recur(thisV, ScopeSet(Set(env))) } /** - * Resolve the environment owned by the given method `enclosing`. + * Resolve the environment associated by the given method `enclosing`, starting from `env` as the innermost. * * The method could be located in outer scope with intermixed classes between its definition * site and usage site. @@ -544,7 +551,7 @@ class Objects(using Context @constructorOnly): * * @return the environment and value for `this` owned by the given method. */ - def resolveEnvByOwner(enclosing: Symbol, thisV: ThisValue, env: Data)(using Context): Option[(ThisValue, Data)] = log("Resolving env by owner for " + enclosing.show + ", this = " + thisV.show + ", env = " + env.show, printer) { + def resolveEnvByOwner(enclosing: Symbol, thisV: ThisValue, env: Data)(using Context, Heap.MutableData): Option[(ThisValue, Data)] = log("Resolving env by owner for " + enclosing.show + ", this = " + thisV.show + ", env = " + env.show, printer) { assert(enclosing.is(Flags.Method), "Only method symbols allows, got " + enclosing.show) env match case localEnv: LocalEnv => @@ -581,37 +588,111 @@ class Objects(using Context @constructorOnly): /** The address for mutable local variables . */ private case class LocalVarAddr(regions: Regions.Data, sym: Symbol, owner: ClassSymbol) extends Addr + private case class ScopeBody( + paramsMap: Map[Symbol, Value], + valsMap: Map[Symbol, Value], + outersMap: Map[Symbol, ScopeSet] + ) + + private def emptyScopeBody(): ScopeBody = ScopeBody( + paramsMap = Map.empty, + valsMap = Map.empty, + outersMap = Map.empty + ) + /** Immutable heap data used in the cache. * * We need to use structural equivalence so that in different iterations the cache can be effective. * * TODO: speed up equality check for heap. */ - opaque type Data = Map[Addr, Value] + opaque type Data = Map[Scope, ScopeBody] /** Store the heap as a mutable field to avoid threading it through the program. */ class MutableData(private[Heap] var heap: Data): - private[Heap] def writeJoin(addr: Addr, value: Value): Unit = - heap.get(addr) match + private[Heap] def writeJoinParam(scope: Scope, param: Symbol, value: Value): Unit = + heap.get(scope) match + case None => + heap = heap.updated(scope, Heap.emptyScopeBody()) + writeJoinParam(scope, param, value) + + case Some(current) => + val newParamsMap = current.paramsMap.join(param, value) + heap = heap.updated(scope, new ScopeBody( + paramsMap = newParamsMap, + valsMap = current.valsMap, + outersMap = current.outersMap + )) + + private[Heap] def writeJoinVal(scope: Scope, valSymbol: Symbol, value: Value): Unit = + heap.get(scope) match case None => - heap = heap.updated(addr, value) + heap = heap.updated(scope, Heap.emptyScopeBody()) + writeJoinVal(scope, valSymbol, value) case Some(current) => - val value2 = value.join(current) - if value2 != current then - heap = heap.updated(addr, value2) + val newValsMap = current.valsMap.join(valSymbol, value) + heap = heap.updated(scope, new ScopeBody( + paramsMap = current.paramsMap, + valsMap = newValsMap, + outersMap = current.outersMap + )) + + private[Heap] def writeJoinOuter(scope: Scope, outerSymbol: Symbol, outerScope: ScopeSet): Unit = + heap.get(scope) match + case None => + heap = heap.updated(scope, Heap.emptyScopeBody()) + writeJoinOuter(scope, outerSymbol, outerScope) + + case Some(current) => + val newOutersMap = current.outersMap.join(outerSymbol, outerScope) + heap = heap.updated(scope, new ScopeBody( + paramsMap = current.paramsMap, + valsMap = current.valsMap, + outersMap = newOutersMap + )) end MutableData def empty(): MutableData = new MutableData(Map.empty) - def contains(addr: Addr)(using mutable: MutableData): Boolean = - mutable.heap.contains(addr) + def contains(scope: Scope)(using mutable: MutableData): Boolean = + mutable.heap.contains(scope) + + def containsParam(scope: Scope, param: Symbol)(using mutable: MutableData): Boolean = + if mutable.heap.contains(scope) then + mutable.heap(scope).paramsMap.contains(param) + else + false + + def containsVal(scope: Scope, value: Symbol)(using mutable: MutableData): Boolean = + if mutable.heap.contains(scope) then + mutable.heap(scope).valsMap.contains(value) + else + false + + def containsOuter(scope: Scope, outer: Symbol)(using mutable: MutableData): Boolean = + if mutable.heap.contains(scope) then + mutable.heap(scope).outersMap.contains(outer) + else + false + + def readParam(scope: Scope, param: Symbol)(using mutable: MutableData): Value = + mutable.heap(scope).paramsMap(param) + + def readVal(scope: Scope, value: Symbol)(using mutable: MutableData): Value = + mutable.heap(scope).valsMap(value) + + def readOuter(scope: Scope, outer: Symbol)(using mutable: MutableData): ScopeSet = + mutable.heap(scope).outersMap(outer) - def read(addr: Addr)(using mutable: MutableData): Value = - mutable.heap(addr) + def writeJoinParam(scope: Scope, param: Symbol, value: Value)(using mutable: MutableData): Unit = + mutable.writeJoinParam(scope, param, value) - def writeJoin(addr: Addr, value: Value)(using mutable: MutableData): Unit = - mutable.writeJoin(addr, value) + def writeJoinVal(scope: Scope, valSymbol: Symbol, value: Value)(using mutable: MutableData): Unit = + mutable.writeJoinParam(scope, valSymbol, value) + + def writeJoinOuter(scope: Scope, outer: Symbol, outerScope: ScopeSet)(using mutable: MutableData): Unit = + mutable.writeJoinOuter(scope, outer, outerScope) def localVarAddr(regions: Regions.Data, sym: Symbol, owner: ClassSymbol): Addr = LocalVarAddr(regions, sym, owner) @@ -691,22 +772,28 @@ class Objects(using Context @constructorOnly): case class ArgInfo(value: Value, trace: Trace, tree: Tree) - extension (a: Value) - def join(b: Value): Value = - assert(!a.isInstanceOf[Package] && !b.isInstanceOf[Package], "Unexpected join between " + a + " and " + b) - (a, b) match - case (Top, _) => Top - case (_, Top) => Top - case (UnknownValue, _) => UnknownValue - case (_, UnknownValue) => UnknownValue - case (Bottom, b) => b - case (a, Bottom) => a - case (ValueSet(values1), ValueSet(values2)) => ValueSet(values1 ++ values2) - case (a : ValueElement, ValueSet(values)) => ValueSet(values + a) - case (ValueSet(values), b : ValueElement) => ValueSet(values + b) - case (a : ValueElement, b : ValueElement) => ValueSet(Set(a, b)) - case _ => Bottom + trait Join[V]: + extension (v1: V) + def join(v2: V): V + + given Join[Value] with + extension (a: Value) + def join(b: Value): Value = + assert(!a.isInstanceOf[Package] && !b.isInstanceOf[Package], "Unexpected join between " + a + " and " + b) + (a, b) match + case (Top, _) => Top + case (_, Top) => Top + case (UnknownValue, _) => UnknownValue + case (_, UnknownValue) => UnknownValue + case (Bottom, b) => b + case (a, Bottom) => a + case (ValueSet(values1), ValueSet(values2)) => ValueSet(values1 ++ values2) + case (a : ValueElement, ValueSet(values)) => ValueSet(values + a) + case (ValueSet(values), b : ValueElement) => ValueSet(values + b) + case (a : ValueElement, b : ValueElement) => ValueSet(Set(a, b)) + case _ => Bottom + extension (a: Value) def remove(b: Value): Value = (a, b) match case (ValueSet(values1), b: ValueElement) => ValueSet(values1 - b) case (ValueSet(values1), ValueSet(values2)) => ValueSet(values1.removedAll(values2)) @@ -715,27 +802,6 @@ class Objects(using Context @constructorOnly): case (a: Package, b: Package) if a == b => Bottom case _ => a - def widen(height: Int)(using Context): Value = log("widening value " + a.show + " down to height " + height, printer, (_: Value).show) { - if height == 0 then Top - else - a match - case Bottom => Bottom - - case ValueSet(values) => - values.map(ref => ref.widen(height)).join - - case Fun(code, thisV, klass, env) => - Fun(code, thisV.widenThisValue(height), klass, env.widen(height - 1)) - - case ref @ OfClass(klass, outer, _, args, env) => - val outer2 = outer.widen(height - 1) - val args2 = args.map(_.widen(height - 1)) - val env2 = env.widen(height - 1) - ref.widenedCopy(outer2, args2, env2) - - case _ => a - } - def filterType(tpe: Type)(using Context): Value = tpe match case t @ SAMType(_, _) if a.isInstanceOf[Fun] => a // if tpe is SAMType and a is Fun, allow it @@ -763,15 +829,25 @@ class Objects(using Context @constructorOnly): case fun: Fun => if klass.isOneOf(AbstractOrTrait) && klass.baseClasses.exists(defn.isFunctionClass) then fun else Bottom - extension (value: ThisValue) - def widenThisValue(height : Int)(using Context) : ThisValue = - assert(height > 0, "Cannot call widenThisValue with height 0!") - value.widen(height).asInstanceOf[ThisValue] + given Join[ScopeSet] with + extension (a: ScopeSet) + def join(b: ScopeSet): ScopeSet = + assert(!a.scopes.isEmpty && !b.scopes.isEmpty && a.scopes.head.level == b.scopes.head.level, + "Invalid join on scopes!") + ScopeSet(a.scopes ++ b.scopes) extension (values: Iterable[Value]) def join: Value = if values.isEmpty then Bottom else values.reduce { (v1, v2) => v1.join(v2) } - def widen(height: Int): Contextual[List[Value]] = values.map(_.widen(height)).toList + extension (scopes: Iterable[ScopeSet]) + def join: ScopeSet = scopes.reduce { (s1, s2) => s1.join(s2) } + + // def widen(height: Int): Contextual[List[V]] = values.map(_.widen(height)).toList + + extension [V : Join](map: Map[Symbol, V]) + def join(sym: Symbol, value: V): Map[Symbol, V] = + if !map.contains(sym) then map.updated(sym, value) + else map.updated(sym, map(sym).join(value)) /** Check if the checker option reports warnings about unknown code */ @@ -793,7 +869,7 @@ class Objects(using Context @constructorOnly): * @param superType The type of the super in a super call. NoType for non-super calls. * @param needResolve Whether the target of the call needs resolution? */ - def call(value: Value, meth: Symbol, args: List[ArgInfo], receiver: Type, superType: Type, needResolve: Boolean = true): Contextual[Value] = log("call " + meth.show + ", this = " + value.show + ", args = " + args.map(_.value.show), printer, (_: Value).show) { + def call(value: Value, meth: Symbol, args: List[ArgInfo], receiver: Type, superType: Type, needResolve: Boolean = true): Contextual[Value] = forcelog("call " + meth.show + ", this = " + value.show + ", args = " + args.map(_.value.show), printer, (_: Value).show) { value.filterClass(meth.owner) match case Top => report.warning("Value is unknown to the checker due to widening. " + Trace.show, Trace.position) @@ -842,7 +918,7 @@ class Objects(using Context @constructorOnly): val ddef = target.defTree.asInstanceOf[DefDef] val cls = target.owner.enclosingClass.asClass // convert SafeType to an OfClass before analyzing method body - val ref = OfClass(cls, Bottom, NoSymbol, Nil, Env.NoEnv) + val ref = OfClass(cls, Env.NoEnv, NoSymbol) call(ref, meth, args, receiver, superType, needResolve) case Bottom => @@ -856,15 +932,15 @@ class Objects(using Context @constructorOnly): val target = resolve(defn.ArrayClass, meth) if target == defn.Array_apply || target == defn.Array_clone then - if arr.addr.owner == State.currentObject then - Heap.read(arr.addr) + if arr.owner == State.currentObject then + arr.valValue() else - errorReadOtherStaticObject(State.currentObject, arr.addr) + errorReadOtherStaticObject(State.currentObject, arr) Bottom else if target == defn.Array_update then assert(args.size == 2, "Incorrect number of arguments for Array update, found = " + args.size) - if arr.addr.owner != State.currentObject then - errorMutateOtherStaticObject(State.currentObject, arr.addr) + if arr.owner != State.currentObject then + errorMutateOtherStaticObject(State.currentObject, arr) else Heap.writeJoin(arr.addr, args.tail.head.value) Bottom @@ -1034,11 +1110,10 @@ class Objects(using Context @constructorOnly): def isNextFieldOfColonColon: Boolean = ref.klass == defn.ConsClass && target.name.toString == "next" if target.isMutableVarOrAccessor && !isNextFieldOfColonColon then if ref.hasVar(target) then - val addr = ref.varAddr(target) - if addr.owner == State.currentObject then - Heap.read(addr) + if ref.owner == State.currentObject then + ref.varValue(target) else - errorReadOtherStaticObject(State.currentObject, addr) + errorReadOtherStaticObject(State.currentObject, ref) Bottom else if ref.isObjectRef && ref.klass.hasSource then report.warning("Access uninitialized field " + field.show + ". " + Trace.show, Trace.position) @@ -1106,11 +1181,10 @@ class Objects(using Context @constructorOnly): case ref: Ref => if ref.hasVar(field) then - val addr = ref.varAddr(field) - if addr.owner != State.currentObject then - errorMutateOtherStaticObject(State.currentObject, addr) + if ref.owner != State.currentObject then + errorMutateOtherStaticObject(State.currentObject, ref) else - Heap.writeJoin(addr, rhs) + Heap.writeJoinVal(ref, field, rhs) else report.warning("Mutating a field before its initialization: " + field.show + ". " + Trace.show, Trace.position) end match @@ -1159,12 +1233,12 @@ class Objects(using Context @constructorOnly): report.warning("[Internal error] top-level class should have `Package` as outer, class = " + klass.show + ", outer = " + outer.show + ", " + Trace.show, Trace.position) (Bottom, Env.NoEnv) else - (thisV.widenThisValue(1), Env.NoEnv) + (thisV, Env.NoEnv) else // klass.enclosingMethod returns its primary constructor Env.resolveEnvByOwner(klass.owner.enclosingMethod, thisV, summon[Env.Data]).getOrElse(UnknownValue -> Env.NoEnv) - val instance = OfClass(klass, outerWidened, ctor, args.map(_.value), envWidened) + val instance = OfClass(klass, envWidened, ctor) callConstructor(instance, ctor, args) case ValueSet(values) => @@ -1178,9 +1252,7 @@ class Objects(using Context @constructorOnly): */ def initLocal(sym: Symbol, value: Value): Contextual[Unit] = log("initialize local " + sym.show + " with " + value.show, printer) { if sym.is(Flags.Mutable) then - val addr = Heap.localVarAddr(summon[Regions.Data], sym, State.currentObject) - Env.setLocalVar(sym, addr) - Heap.writeJoin(addr, value) + Env.setLocalVar(sym, value) else Env.setLocalVal(sym, value) } @@ -1195,30 +1267,23 @@ class Objects(using Context @constructorOnly): // Can't use enclosingMethod here because values defined in a by-name closure will have the wrong enclosingMethod, // since our phase is before elimByName. Env.resolveEnvByValue(sym, thisV, summon[Env.Data]) match - case Some(thisV -> env) => + case Some(thisV -> scopeSet) => if sym.is(Flags.Mutable) then // Assume forward reference check is doing a good job - given Env.Data = env - Env.getVar(sym) match - case Some(addr) => - if addr.owner == State.currentObject then - Heap.read(addr) - else - errorReadOtherStaticObject(State.currentObject, addr) - Bottom - end if - case _ => - // Only vals can be lazy - report.warning("[Internal error] Variable not found " + sym.show + "\nenv = " + env.show + ". " + Trace.show, Trace.position) + val scopesOwnedByOthers = scopeSet.scopes.filter(_.owner != State.currentObject) + if scopesOwnedByOthers.isEmpty then + scopeSet.scopes.map(_.varValue(sym)).join + else + errorReadOtherStaticObject(State.currentObject, scopesOwnedByOthers.head) Bottom + end if else - given Env.Data = env if sym.is(Flags.Lazy) then val rhs = sym.defTree.asInstanceOf[ValDef].rhs eval(rhs, thisV, sym.enclosingClass.asClass, cacheResult = true) else // Assume forward reference check is doing a good job - val value = Env.valValue(sym) + val value = scopeSet.scopes.map(_.varValue(sym)).join if isByNameParam(sym) then value match case fun: Fun => @@ -1254,16 +1319,12 @@ class Objects(using Context @constructorOnly): // Can't use enclosingMethod here because values defined in a by-name closure will have the wrong enclosingMethod, // since our phase is before elimByName. Env.resolveEnvByValue(sym, thisV, summon[Env.Data]) match - case Some(thisV -> env) => - given Env.Data = env - Env.getVar(sym) match - case Some(addr) => - if addr.owner != State.currentObject then - errorMutateOtherStaticObject(State.currentObject, addr) - else - Heap.writeJoin(addr, value) - case _ => - report.warning("[Internal error] Variable not found " + sym.show + "\nenv = " + env.show + ". " + Trace.show, Trace.position) + case Some(thisV -> scopeSet) => + val scopesOwnedByOthers = scopeSet.scopes.filter(_.owner != State.currentObject) + if !scopesOwnedByOthers.isEmpty then + errorMutateOtherStaticObject(State.currentObject, scopesOwnedByOthers.head) + else + scopeSet.scopes.foreach(Heap.writeJoinVal(_, sym, value)) case _ => report.warning("Assigning to variables in outer scope. " + Trace.show, Trace.position) @@ -1274,7 +1335,7 @@ class Objects(using Context @constructorOnly): // -------------------------------- algorithm -------------------------------- /** Check an individual object */ - private def accessObject(classSym: ClassSymbol)(using Context, State.Data, Trace): ObjectRef = log("accessing " + classSym.show, printer, (_: Value).show) { + private def accessObject(classSym: ClassSymbol)(using Context, State.Data, Trace, Heap.MutableData): ObjectRef = log("accessing " + classSym.show, printer, (_: Value).show) { if classSym.hasSource then State.checkObjectAccess(classSym) else @@ -1282,7 +1343,7 @@ class Objects(using Context @constructorOnly): } - def checkClasses(classes: List[ClassSymbol])(using Context): Unit = + def checkClasses(classes: List[ClassSymbol])(using Context, Heap.MutableData) = given State.Data = new State.Data given Trace = Trace.empty @@ -1444,12 +1505,11 @@ class Objects(using Context @constructorOnly): extendTrace(id) { evalType(prefix, thisV, klass) } val value = eval(rhs, thisV, klass) - val widened = widenEscapedValue(value, rhs) if isLocal then - writeLocal(thisV, lhs.symbol, widened) + writeLocal(thisV, lhs.symbol, value) else - withTrace(trace2) { assign(receiver, lhs.symbol, widened, rhs.tpe) } + withTrace(trace2) { assign(receiver, lhs.symbol, value, rhs.tpe) } case closureDef(ddef) => Fun(ddef, thisV, klass, summon[Env.Data]) @@ -1809,36 +1869,6 @@ class Objects(using Context @constructorOnly): throw new Exception("unexpected type: " + tp + ", Trace:\n" + Trace.show) } - /** Widen the escaped value (a method argument or rhs of an assignment) - * - * The default widening is 1 for most values, 2 for function values. - * User-specified widening annotations are repected. - */ - def widenEscapedValue(value: Value, annotatedTree: Tree): Contextual[Value] = - def parseAnnotation: Option[Int] = - annotatedTree.tpe.getAnnotation(defn.InitWidenAnnot).flatMap: annot => - annot.argument(0).get match - case arg @ Literal(c: Constants.Constant) => - val height = c.intValue - if height < 0 then - report.warning("The argument should be positive", arg) - None - else - Some(height) - case arg => - report.warning("The argument should be a constant integer value", arg) - None - end parseAnnotation - - parseAnnotation match - case Some(i) => - value.widen(i) - - case None => - if value.isInstanceOf[Fun] - then value.widen(2) - else value.widen(1) - /** Evaluate arguments of methods and constructors */ def evalArgs(args: List[Arg], thisV: ThisValue, klass: ClassSymbol): Contextual[List[ArgInfo]] = val argInfos = new mutable.ArrayBuffer[ArgInfo] @@ -1849,8 +1879,7 @@ class Objects(using Context @constructorOnly): else eval(arg.tree, thisV, klass) - val widened = widenEscapedValue(res, arg.tree) - argInfos += ArgInfo(widened, trace.add(arg.tree), arg.tree) + argInfos += ArgInfo(res, trace.add(arg.tree), arg.tree) } argInfos.toList @@ -1869,9 +1898,7 @@ class Objects(using Context @constructorOnly): klass.paramGetters.foreach { acc => val value = paramsMap(acc.name.toTermName) if acc.is(Flags.Mutable) then - val addr = Heap.fieldVarAddr(summon[Regions.Data], acc, State.currentObject) - thisV.initVar(acc, addr) - Heap.writeJoin(addr, value) + thisV.initVar(acc, value) else thisV.initVal(acc, value) printer.println(acc.show + " initialized with " + value) @@ -1884,7 +1911,15 @@ class Objects(using Context @constructorOnly): val cls = tref.classSymbol.asClass // update outer for super class val res = outerValue(tref, thisV, klass) - thisV.initOuter(cls, res) + res match { + case ref: Ref => thisV.initOuter(cls, ref) + case ValueSet(values) if values.forall(_.isInstanceOf[Ref]) => + thisV.initOuterSet(cls, ScopeSet(values.map(_.asInstanceOf[Ref]))) + case _ => + val error = "[Internal error] Invalid outer value, cls = " + cls + ", value = " + res + Trace.show + report.warning(error, Trace.position) + return + } // follow constructor if cls.hasSource then @@ -1964,9 +1999,7 @@ class Objects(using Context @constructorOnly): val sym = vdef.symbol val res = if (allowList.contains(sym)) Bottom else eval(vdef.rhs, thisV, klass) if sym.is(Flags.Mutable) then - val addr = Heap.fieldVarAddr(summon[Regions.Data], sym, State.currentObject) - thisV.initVar(sym, addr) - Heap.writeJoin(addr, res) + thisV.initVar(sym, res) else thisV.initVal(sym, res) @@ -1990,31 +2023,37 @@ class Objects(using Context @constructorOnly): * Object access elision happens when the object access is used as a prefix * in `new o.C` and `C` does not need an outer. */ - def resolveThis(target: ClassSymbol, thisV: Value, klass: ClassSymbol, elideObjectAccess: Boolean = false): Contextual[Value] = log("resolveThis target = " + target.show + ", this = " + thisV.show, printer, (_: Value).show) { - if target == klass then - thisV - else if target.is(Flags.Package) then - Package(target) // TODO: What is the semantics for package.this? + def resolveThis(target: ClassSymbol, thisV: ThisValue, klass: ClassSymbol, elideObjectAccess: Boolean = false): Contextual[ThisValue] = log("resolveThis target = " + target.show + ", this = " + thisV.show, printer, (_: Value).show) { + def recur(scopeSet: ScopeSet): ThisValue = + val head = scopeSet.scopes.head + if head.isInstanceOf[Ref] then + val klass = head.asInstanceOf[Ref].klass + assert(scopeSet.scopes.forall(_.asInstanceOf[Ref].klass == klass), "Multiple possible outer class?") + if klass == target then + ValueSet(scopeSet.scopes.map(_.asInstanceOf[Ref])) + else + recur(scopeSet.scopes.map(_.outer).join) + else + recur(scopeSet.scopes.map(_.outer).join) + end recur + + if target.is(Flags.Package) then + val error = "[Internal error] target cannot be packages, target = " + target + ", klass = " + klass + Trace.show + report.warning(error, Trace.position) + Bottom else if target.isStaticObject then val res = ObjectRef(target.moduleClass.asClass) if elideObjectAccess then res else accessObject(target) else thisV match - case Bottom => Bottom - case UnknownValue => UnknownValue case Top => Top + case Bottom => Bottom case ref: Ref => - val outerCls = klass.owner.lexicallyEnclosingClass.asClass - if !ref.hasOuter(klass) then - val error = "[Internal error] outer not yet initialized, target = " + target + ", klass = " + klass + Trace.show - report.warning(error, Trace.position) - Bottom - else - resolveThis(target, ref.outerValue(klass), outerCls) - case ValueSet(values) => - values.map(ref => resolveThis(target, ref, klass)).join - case _: Fun | _ : OfArray | _: Package | SafeValue(_) => + recur(ScopeSet(Set(ref))) + case ValueSet(values) if values.forall(_.isInstanceOf[Ref]) => + recur(ScopeSet(values.map(_.asInstanceOf[Ref]))) + case _ => report.warning("[Internal error] unexpected thisV = " + thisV + ", target = " + target.show + ", klass = " + klass.show + Trace.show, Trace.position) Bottom } @@ -2025,14 +2064,17 @@ class Objects(using Context @constructorOnly): * @param thisV The value for `C.this` where `C` is represented by the parameter `klass`. * @param klass The enclosing class where the type `tref` is located. */ - def outerValue(tref: TypeRef, thisV: ThisValue, klass: ClassSymbol): Contextual[Value] = + def outerValue(tref: TypeRef, thisV: ThisValue, klass: ClassSymbol): Contextual[ThisValue] = val cls = tref.classSymbol.asClass if tref.prefix == NoPrefix then val enclosing = cls.owner.lexicallyEnclosingClass.asClass resolveThis(enclosing, thisV, klass, elideObjectAccess = cls.isStatic) else if cls.isAllOf(Flags.JavaInterface) then Bottom - else evalType(tref.prefix, thisV, klass, elideObjectAccess = cls.isStatic) + else + val res = evalType(tref.prefix, thisV, klass, elideObjectAccess = cls.isStatic) + assert(res.isInstanceOf[ThisValue], "Not a ref?") + res.asInstanceOf[ThisValue] def printTraceWhenMultiple(trace: Trace)(using Context): String = if trace.toVector.size > 1 then @@ -2040,25 +2082,25 @@ class Objects(using Context @constructorOnly): else "" val mutateErrorSet: mutable.Set[(ClassSymbol, ClassSymbol)] = mutable.Set.empty - def errorMutateOtherStaticObject(currentObj: ClassSymbol, addr: Heap.Addr)(using Trace, Context) = - val otherObj = addr.owner - val addr_trace = addr.getTrace + def errorMutateOtherStaticObject(currentObj: ClassSymbol, scope: Scope)(using Trace, Context) = + val otherObj = scope.owner + val scope_trace = scope.getTrace if mutateErrorSet.add((currentObj, otherObj)) then val msg = s"Mutating ${otherObj.show} during initialization of ${currentObj.show}.\n" + "Mutating other static objects during the initialization of one static object is forbidden. " + Trace.show + - printTraceWhenMultiple(addr_trace) + printTraceWhenMultiple(scope_trace) report.warning(msg, Trace.position) val readErrorSet: mutable.Set[(ClassSymbol, ClassSymbol)] = mutable.Set.empty - def errorReadOtherStaticObject(currentObj: ClassSymbol, addr: Heap.Addr)(using Trace, Context) = - val otherObj = addr.owner - val addr_trace = addr.getTrace + def errorReadOtherStaticObject(currentObj: ClassSymbol, scope: Scope)(using Trace, Context) = + val otherObj = scope.owner + val scope_trace = scope.getTrace if readErrorSet.add((currentObj, otherObj)) then val msg = "Reading mutable state of " + otherObj.show + " during initialization of " + currentObj.show + ".\n" + "Reading mutable state of other static objects is forbidden as it breaks initialization-time irrelevance. " + Trace.show + - printTraceWhenMultiple(addr_trace) + printTraceWhenMultiple(scope_trace) report.warning(msg, Trace.position)