Skip to content

Commit 6e12d35

Browse files
authored
Treat local mutable vars as capabilities (#24815)
Under separation checking: A mutable var owned by a term that is not annotated with @untrackedCaptures gives rise to a Mutable capability. Since a mutable variable is not trackable, we do this by adding a varMirror symbol to the variable which represents the capability.
2 parents 88c2127 + 69c9e00 commit 6e12d35

31 files changed

+392
-77
lines changed

compiler/src/dotty/tools/dotc/cc/CCState.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,11 @@ class CCState:
7979
object Unrecorded extends VarState.Unrecorded
8080
object ClosedUnrecorded extends VarState.ClosedUnrecorded
8181

82+
// ----- Mirrors for local vars -------------------------
83+
84+
/** A cache for mirrors of local mutable vars */
85+
val varMirrors = util.EqHashMap[Symbol, Symbol]()
86+
8287
// ------ Context info accessed from companion object when isCaptureCheckingOrSetup is true
8388

8489
private var openExistentialScopes: List[MethodType] = Nil

compiler/src/dotty/tools/dotc/cc/Capability.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,18 @@ object Capabilities:
459459
case self: CoreCapability => self.isTrackableRef
460460
case _ => true
461461

462+
/** Under separation checking: Is this a mutable var owned by a term that is
463+
* not annotated with @untrackedCaptures? Such mutable variables need to be
464+
* tracked as capabilities. Since mutable variables are not trackable, we do
465+
* this by adding a varMirror symbol to such variables which represents the capability.
466+
*/
467+
final def isLocalMutable(using Context): Boolean = this match
468+
case tp @ TermRef(NoPrefix, _) =>
469+
ccConfig.newScheme && ccConfig.strictMutability
470+
&& tp.symbol.isMutableVar
471+
&& !tp.symbol.hasAnnotation(defn.UntrackedCapturesAnnot)
472+
case _ => false
473+
462474
/** The non-derived capability underlying this capability */
463475
final def core: CoreCapability | RootCapability = this match
464476
case self: (CoreCapability | RootCapability) => self

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ extension (tp: Type)
7474
GlobalCap
7575
case ref: Capability if ref.isTrackableRef =>
7676
ref
77+
case ref: TermRef if ref.isLocalMutable =>
78+
ref.mapLocalMutable
7779
case _ =>
7880
// if this was compiled from cc syntax, problem should have been reported at Typer
7981
throw IllegalCaptureRef(tp)
@@ -493,6 +495,12 @@ extension (tp: MethodType)
493495
def marksExistentialScope(using Context): Boolean =
494496
!tp.resType.isInstanceOf[MethodOrPoly]
495497

498+
extension (ref: TermRef | ThisType)
499+
/** Map a local mutable var to its mirror */
500+
def mapLocalMutable(using Context): TermRef | ThisType = ref match
501+
case ref: TermRef if ref.isLocalMutable => ref.symbol.varMirror.termRef
502+
case _ => ref
503+
496504
extension (cls: ClassSymbol)
497505

498506
def pureBaseClass(using Context): Option[Symbol] =
@@ -666,6 +674,16 @@ extension (sym: Symbol)
666674
def isArrayUnderStrictMut(using Context): Boolean =
667675
sym == defn.ArrayClass && ccConfig.strictMutability
668676

677+
def isDisallowedInCapset(using Context): Boolean =
678+
sym.isOneOf(if ccConfig.newScheme && ccConfig.strictMutability then Method else UnstableValueFlags)
679+
680+
def varMirror(using Context): Symbol =
681+
ccState.varMirrors.getOrElseUpdate(sym,
682+
sym.copy(
683+
flags = Flags.EmptyFlags,
684+
info = defn.Caps_Var.typeRef.appliedTo(sym.info)
685+
.capturing(FreshCap(sym, Origin.InDecl(sym)))))
686+
669687
extension (tp: AnnotatedType)
670688
/** Is this a boxed capturing type? */
671689
def isBoxed(using Context): Boolean = tp.annot match

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -106,14 +106,14 @@ object CheckCaptures:
106106
/** Check that a @retains annotation only mentions references that can be tracked.
107107
* This check is performed at Typer.
108108
*/
109-
def checkWellformed(parent: Tree, ann: Tree)(using Context): Unit =
109+
def checkWellformedRetains(parent: Tree, ann: Tree)(using Context): Unit =
110110
def check(elem: Type): Unit = elem match
111111
case ref: TypeRef =>
112112
val refSym = ref.symbol
113113
if refSym.isType && !refSym.info.derivesFrom(defn.Caps_CapSet) then
114114
report.error(em"$elem is not a legal element of a capture set", ann.srcPos)
115115
case ref: CoreCapability =>
116-
if !ref.isTrackableRef && !ref.isCapRef then
116+
if !ref.isTrackableRef && !ref.isCapRef && !ref.isLocalMutable then
117117
report.error(em"$elem cannot be tracked since it is not a parameter or local value", ann.srcPos)
118118
case ReachCapability(ref) =>
119119
check(ref)
@@ -306,7 +306,7 @@ class CheckCaptures extends Recheck, SymTransformer:
306306
*/
307307
private val useInfos = mutable.ArrayBuffer[(Tree, CaptureSet, Env)]()
308308

309-
private var usedSet = util.EqHashMap[Tree, CaptureSet]()
309+
private val usedSet = util.EqHashMap[Tree, CaptureSet]()
310310

311311
/** The set of symbols that were rechecked via a completer */
312312
private val completed = new mutable.HashSet[Symbol]
@@ -690,8 +690,15 @@ class CheckCaptures extends Recheck, SymTransformer:
690690
// Lazy vals are like parameterless methods: accessing them may trigger initialization
691691
// that uses captured references.
692692
includeCallCaptures(sym, sym.info, tree)
693-
else if sym.exists && !sym.isStatic then
694-
markPathFree(sym.termRef, pt, tree)
693+
else
694+
if sym.isMutableVar && sym.owner.isTerm && pt != LhsProto then
695+
// When we have `var x: A^{c} = ...` where `x` is a local variable then
696+
// when dereferencing `x` we also need to charge `c`.
697+
// For fields it's not a problem since `c` would already have been
698+
// charged for the prefix `p` in `p.x`.
699+
markFree(sym.info.captureSet, tree)
700+
if sym.exists && !sym.isStatic then
701+
markPathFree(sym.termRef, pt, tree)
695702
mapResultRoots(super.recheckIdent(tree, pt), tree.symbol)
696703

697704
override def recheckThis(tree: This, pt: Type)(using Context): Type =
@@ -713,7 +720,7 @@ class CheckCaptures extends Recheck, SymTransformer:
713720
val sel = ref.select(pt.selector).asInstanceOf[TermRef]
714721
markPathFree(sel, pt.pt, pt.select)
715722
case _ =>
716-
markFree(ref.adjustReadOnly(pt), tree)
723+
markFree(ref.mapLocalMutable.adjustReadOnly(pt), tree)
717724

718725
/** The expected type for the qualifier of a selection. If the selection
719726
* could be part of a capability path or is a a read-only method, we return
@@ -1062,7 +1069,7 @@ class CheckCaptures extends Recheck, SymTransformer:
10621069
recheck(tree.rhs, lhsType.widen)
10631070
lhsType match
10641071
case lhsType @ TermRef(qualType, _)
1065-
if (qualType ne NoPrefix) && !lhsType.symbol.hasAnnotation(defn.UntrackedCapturesAnnot) =>
1072+
if !lhsType.symbol.hasAnnotation(defn.UntrackedCapturesAnnot) =>
10661073
checkUpdate(qualType, tree.srcPos)(i"Cannot assign to field ${lhsType.name} of ${qualType.showRef}")
10671074
case _ =>
10681075
defn.UnitType
@@ -1135,6 +1142,13 @@ class CheckCaptures extends Recheck, SymTransformer:
11351142
openClosures = openClosures.tail
11361143
end recheckClosureBlock
11371144

1145+
/** Add var mirrors to the list of block-local symbols to avoid */
1146+
override def avoidLocals(tp: Type, symsToAvoid: => List[Symbol])(using Context): Type =
1147+
val locals = symsToAvoid
1148+
val varMirrors = locals.collect:
1149+
case local if local.termRef.isLocalMutable => local.varMirror
1150+
super.avoidLocals(tp, varMirrors ++ locals)
1151+
11381152
/** Elements of a SeqLiteral instantiate a Seq or Array parameter, so they
11391153
* should be boxed.
11401154
*/

compiler/src/dotty/tools/dotc/cc/Mutability.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ object Mutability:
121121
&& (!tp.isStatefulType || tp.captureSet.mutability == CaptureSet.Mutability.Reader)
122122

123123
extension (ref: TermRef | ThisType)
124-
/** Map `ref` to `ref.readOnly` if its type extends Mutble, and one of the
124+
/** Map `ref` to `ref.readOnly` if its type extends Mutable, and one of the
125125
* following is true:
126126
* - it appears in a non-exclusive context,
127127
* - the expected type is a value type that is not a stateful type,

compiler/src/dotty/tools/dotc/cc/SepCheck.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -744,7 +744,9 @@ class SepCheck(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
744744
val captured = genPart.deepCaptureSet.elems
745745
val hiddenSet = captured.transHiddenSet.pruned
746746
val clashSet = otherPart.deepCaptureSet.elems
747-
val deepClashSet = clashSet.completeFootprint.nonPeaks.pruned
747+
var deepClashSet = clashSet.completeFootprint.nonPeaks.pruned
748+
if deepClashSet.isEmpty then
749+
deepClashSet = clashSet.completeFootprint.pruned
748750
report.error(
749751
em"""Separation failure in ${role.description} $tpe.
750752
|One part, $genPart, hides capabilities ${CaptureSet(hiddenSet)}.

compiler/src/dotty/tools/dotc/cc/ccConfig.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ object ccConfig:
5454

5555
/** Not used currently. Handy for trying out new features */
5656
def newScheme(using ctx: Context): Boolean =
57-
Feature.sourceVersion.stable.isAtLeast(SourceVersion.`3.7`)
57+
Feature.sourceVersion.stable.isAtLeast(SourceVersion.`3.8`)
5858

5959
/** Allow @use annotations */
6060
def allowUse(using Context): Boolean =

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,9 +1027,11 @@ class Definitions {
10271027
@tu lazy val Caps_unsafeDiscardUses: Symbol = CapsUnsafeModule.requiredMethod("unsafeDiscardUses")
10281028
@tu lazy val Caps_unsafeErasedValue: Symbol = CapsUnsafeModule.requiredMethod("unsafeErasedValue")
10291029
@tu lazy val Caps_ContainsTrait: TypeSymbol = CapsModule.requiredType("Contains")
1030+
@tu lazy val Caps_Shared: TypeSymbol = CapsModule.requiredType("Shared")
10301031
@tu lazy val Caps_ContainsModule: Symbol = requiredModule("scala.caps.Contains")
10311032
@tu lazy val Caps_containsImpl: TermSymbol = Caps_ContainsModule.requiredMethod("containsImpl")
10321033
@tu lazy val Caps_freeze: TermSymbol = CapsModule.requiredMethod("freeze")
1034+
@tu lazy val Caps_Var: ClassSymbol = requiredClass("scala.caps.internal.Var")
10331035

10341036
@tu lazy val PureClass: ClassSymbol = requiredClass("scala.caps.Pure")
10351037

@@ -1988,10 +1990,10 @@ class Definitions {
19881990
CapsModule, CapsModule.moduleClass, PureClass,
19891991
/* Caps_Classifier, Caps_SharedCapability, Caps_Control, -- already stable */
19901992
Caps_ExclusiveCapability, Caps_Mutable, Caps_Read, Caps_Unscoped, Caps_Stateful, Caps_Separate,
1991-
RequiresCapabilityAnnot,
1993+
Caps_Shared, RequiresCapabilityAnnot,
19921994
captureRoot, Caps_CapSet, Caps_ContainsTrait, Caps_ContainsModule, Caps_ContainsModule.moduleClass,
19931995
ConsumeAnnot, UseAnnot, ReserveAnnot,
1994-
CapsUnsafeModule, CapsUnsafeModule.moduleClass, Caps_freeze,
1996+
CapsUnsafeModule, CapsUnsafeModule.moduleClass, Caps_freeze, Caps_Var,
19951997
CapsInternalModule, CapsInternalModule.moduleClass,
19961998
RetainsAnnot, RetainsCapAnnot, RetainsByNameAnnot)
19971999

compiler/src/dotty/tools/dotc/core/SymDenotations.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -799,13 +799,12 @@ object SymDenotations {
799799
*
800800
* Note, (f: => T) is treated as a stable TermRef only in Capture Sets.
801801
*/
802-
final def isStableMember(using Context): Boolean = {
802+
final def isStableMember(using Context): Boolean =
803803
def isUnstableValue =
804804
isOneOf(UnstableValueFlags)
805-
|| !ctx.mode.is(Mode.InCaptureSet) && info.isInstanceOf[ExprType]
805+
|| info.isInstanceOf[ExprType]
806806
|| isAllOf(InlineParam)
807807
isType || is(StableRealizable) || exists && !isUnstableValue
808-
}
809808

810809
/** Is this a denotation of a real class that does not have - either direct or inherited -
811810
* initialization code?

compiler/src/dotty/tools/dotc/core/TypeOps.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -566,9 +566,7 @@ object TypeOps:
566566
def avoid(tp: Type, symsToAvoid: => List[Symbol])(using Context): Type = {
567567
val widenMap = new AvoidMap {
568568
@threadUnsafe lazy val forbidden = symsToAvoid.toSet
569-
def toAvoid(tp: NamedType) =
570-
val sym = tp.symbol
571-
forbidden.contains(sym)
569+
def toAvoid(tp: NamedType) = forbidden.contains(tp.symbol)
572570

573571
override def apply(tp: Type): Type = tp match
574572
case tp: TypeVar if mapCtx.typerState.constraint.contains(tp) =>

0 commit comments

Comments
 (0)