Skip to content

Commit 9ae20d9

Browse files
committed
Refactor type maps to split out capability handling
Note i15923 does not signal a leak anymore. I moved it and some variants to pending. Note: There seems to be something more fundamentally wrong with this test: We get an infinite recursion for variant i15923b.
1 parent e23f248 commit 9ae20d9

File tree

13 files changed

+216
-101
lines changed

13 files changed

+216
-101
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureAnnotation.scala

+4-2
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,11 @@ case class CaptureAnnotation(refs: CaptureSet, boxed: Boolean)(cls: Symbol) exte
6363

6464
override def mapWith(tm: TypeMap)(using Context) =
6565
val elems = refs.elems.toList
66-
val elems1 = elems.mapConserve(tm)
66+
val elems1 = elems.mapConserve(tm.mapCapability(_))
6767
if elems1 eq elems then this
68-
else if elems1.forall(_.isTrackableRef)
68+
else if elems1.forall:
69+
case elem1: CaptureRef => elem1.isTrackableRef
70+
case _ => false
6971
then derivedAnnotation(CaptureSet(elems1.asInstanceOf[List[CaptureRef]]*), boxed)
7072
else EmptyAnnotation
7173

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

+11-6
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ extension (tp: Type)
114114
!tp.underlying.exists // might happen during construction of lambdas
115115
|| tp.derivesFrom(defn.Caps_CapSet)
116116
case root.Result(_) => true
117+
case root.Fresh(_) => true
117118
case AnnotatedType(parent, annot) =>
118119
defn.capabilityWrapperAnnots.contains(annot.symbol) && parent.isTrackableRef
119120
case _ =>
@@ -143,9 +144,9 @@ extension (tp: Type)
143144
if dcs.isAlwaysEmpty then tp.captureSet
144145
else tp match
145146
case tp @ ReachCapability(_) =>
146-
tp.singletonCaptureSet
147+
assert(false); tp.singletonCaptureSet
147148
case ReadOnlyCapability(ref) =>
148-
ref.deepCaptureSet(includeTypevars).readOnly
149+
assert(false); ref.deepCaptureSet(includeTypevars).readOnly
149150
case tp: SingletonCaptureRef if tp.isTrackableRef =>
150151
tp.reach.singletonCaptureSet
151152
case _ =>
@@ -195,9 +196,12 @@ extension (tp: Type)
195196
* are of the form this.C but their pathroot is still this.C, not this.
196197
*/
197198
final def pathRoot(using Context): Type = tp.dealias match
198-
case tp1: NamedType
199-
if tp1.symbol.maybeOwner.isClass && tp1.symbol != defn.captureRoot && !tp1.symbol.is(TypeParam) =>
200-
tp1.prefix.pathRoot
199+
case tp1: NamedType =>
200+
if tp1.symbol.maybeOwner.isClass && tp1.symbol != defn.captureRoot && !tp1.symbol.is(TypeParam) then
201+
tp1.prefix match
202+
case pre: CaptureRef => pre.pathRoot
203+
case _ => tp1
204+
else tp1
201205
case tp1 => tp1
202206

203207
/** If this part starts with `C.this`, the class `C`.
@@ -214,7 +218,8 @@ extension (tp: Type)
214218
tp1.prefix match
215219
case _: ThisType | NoPrefix =>
216220
tp1.symbol.is(Param) || tp1.symbol.is(ParamAccessor)
217-
case prefix => prefix.isParamPath
221+
case prefix: CaptureRef => prefix.isParamPath
222+
case _ => false
218223
case _: ParamRef => true
219224
case _ => false
220225

compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

+16-18
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ sealed abstract class CaptureSet extends Showable:
319319
def map(tm: TypeMap)(using Context): CaptureSet =
320320
tm match
321321
case tm: BiTypeMap =>
322-
val mappedElems = elems.map(tm.forward)
322+
val mappedElems = elems.map(tm.mapCapability(_))
323323
if isConst then
324324
if mappedElems == elems then this
325325
else Const(mappedElems)
@@ -487,7 +487,7 @@ object CaptureSet:
487487
override def toString = elems.toString
488488
end Const
489489

490-
case class EmptyWithProvenance(ref: CaptureRef, mapped: Type) extends Const(SimpleIdentitySet.empty):
490+
case class EmptyWithProvenance(ref: CaptureRef, mapped: CaptureSet) extends Const(SimpleIdentitySet.empty):
491491
override def optionalInfo(using Context): String =
492492
if ctx.settings.YccDebug.value
493493
then i" under-approximating the result of mapping $ref to $mapped"
@@ -587,8 +587,7 @@ object CaptureSet:
587587
*/
588588
private def checkSkippedMaps(elem: CaptureRef)(using Context): Unit =
589589
for tm <- skippedMaps do
590-
val elem1 = tm(elem)
591-
for elem1 <- tm(elem).captureSet.elems do
590+
for elem1 <- extrapolateCaptureRef(elem, tm, variance = 1).elems do
592591
assert(elem.subsumes(elem1),
593592
i"Skipped map ${tm.getClass} maps newly added $elem to $elem1 in $this")
594593

@@ -817,14 +816,14 @@ object CaptureSet:
817816

818817
override def tryInclude(elem: CaptureRef, origin: CaptureSet)(using Context, VarState): CompareResult =
819818
if origin eq source then
820-
val mappedElem = bimap.forward(elem)
819+
val mappedElem = bimap.mapCapability(elem)
821820
if accountsFor(mappedElem) then CompareResult.OK
822821
else addNewElem(mappedElem)
823822
else if accountsFor(elem) then
824823
CompareResult.OK
825824
else
826825
try
827-
source.tryInclude(bimap.backward(elem), this)
826+
source.tryInclude(bimap.inverse.mapCapability(elem), this)
828827
.showing(i"propagating new elem $elem backward from $this to $source = $result", captDebug)
829828
.andAlso(addNewElem(elem))
830829
catch case ex: AssertionError =>
@@ -1031,15 +1030,12 @@ object CaptureSet:
10311030
* - Otherwise assertion failure
10321031
*/
10331032
def extrapolateCaptureRef(r: CaptureRef, tm: TypeMap, variance: Int)(using Context): CaptureSet =
1034-
val r1 = tm(r)
1035-
val upper = r1.captureSet
1036-
def isExact =
1037-
upper.isAlwaysEmpty
1038-
|| upper.isConst && upper.elems.size == 1 && upper.elems.contains(r1)
1039-
|| r.derivesFrom(defn.Caps_CapSet)
1040-
if variance > 0 || isExact then upper
1041-
else if variance < 0 then CaptureSet.EmptyWithProvenance(r, r1)
1042-
else upper.maybe
1033+
tm.mapCapability(r) match
1034+
case c: CaptureRef => c.captureSet
1035+
case (cs: CaptureSet, exact) =>
1036+
if cs.isAlwaysEmpty || exact || variance > 0 then cs
1037+
else if variance < 0 then CaptureSet.EmptyWithProvenance(r, cs)
1038+
else cs.maybe
10431039

10441040
/** Apply `f` to each element in `xs`, and join result sets with `++` */
10451041
def mapRefs(xs: Refs, f: CaptureRef => CaptureSet)(using Context): CaptureSet =
@@ -1289,9 +1285,11 @@ object CaptureSet:
12891285

12901286
def mapRef(ref: CaptureRef): CaptureRef
12911287

1292-
def apply(t: Type) = t match
1293-
case t: CaptureRef if t.isTrackableRef => mapRef(t)
1294-
case _ => mapOver(t)
1288+
def apply(t: Type) = mapOver(t)
1289+
1290+
override def mapCapability(c: CaptureRef, deep: Boolean): CaptureRef = c match
1291+
case c: CaptureRef if c.isTrackableRef => mapRef(c)
1292+
case _ => super.mapCapability(c, deep)
12951293

12961294
override def fuse(next: BiTypeMap)(using Context) = next match
12971295
case next: Inverse if next.inverse.getClass == getClass => assert(false); Some(IdentityTypeMap)

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ class CheckCaptures extends Recheck, SymTransformer:
441441
markFree(sym, sym.termRef, tree)
442442

443443
def markFree(sym: Symbol, ref: CaptureRef, tree: Tree)(using Context): Unit =
444-
if sym.exists && ref.isTracked then markFree(ref.captureSet, tree)
444+
if sym.exists && ref.isTracked then markFree(ref.singletonCaptureSet, tree)
445445

446446
/** Make sure the (projected) `cs` is a subset of the capture sets of all enclosing
447447
* environments. At each stage, only include references from `cs` that are outside

compiler/src/dotty/tools/dotc/cc/Setup.scala

+19-9
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,9 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
209209
innerApply(tp)
210210
finally isTopLevel = saved
211211

212+
override def mapArg(arg: Type, tparam: ParamInfo): Type =
213+
super.mapArg(Recheck.mapExprType(arg), tparam)
214+
212215
/** Map parametric functions with results that have a capture set somewhere
213216
* to dependent functions.
214217
*/
@@ -504,12 +507,15 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
504507
def add = new TypeTraverser:
505508
var reach = false
506509
def traverse(t: Type): Unit = t match
507-
case root.Fresh(hidden) =>
508-
if reach then hidden.elems += ref.reach
509-
else if ref.isTracked then hidden.elems += ref
510-
case t @ CapturingType(_, _) if t.isBoxed && !reach =>
511-
reach = true
512-
try traverseChildren(t) finally reach = false
510+
case t @ CapturingType(parent, refs) =>
511+
val saved = reach
512+
reach |= t.isBoxed
513+
try
514+
traverse(parent)
515+
for case root.Fresh(hidden) <- refs.elems.iterator do
516+
if reach then hidden.elems += ref.reach
517+
else if ref.isTracked then hidden.elems += ref
518+
finally reach = saved
513519
case _ =>
514520
traverseChildren(t)
515521
if ref.isTrackableRef then add.traverse(tp)
@@ -660,9 +666,13 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
660666

661667
def paramsToCap(mt: Type)(using Context): Type = mt match
662668
case mt: MethodType =>
663-
mt.derivedLambdaType(
664-
paramInfos = mt.paramInfos.map(root.freshToCap),
665-
resType = paramsToCap(mt.resType))
669+
try
670+
mt.derivedLambdaType(
671+
paramInfos = mt.paramInfos.map(root.freshToCap),
672+
resType = paramsToCap(mt.resType))
673+
catch case ex: AssertionError =>
674+
println(i"error while mapping params ${mt.paramInfos} of $sym")
675+
throw ex
666676
case mt: PolyType =>
667677
mt.derivedLambdaType(resType = paramsToCap(mt.resType))
668678
case _ => mt

compiler/src/dotty/tools/dotc/cc/root.scala

+49-39
Original file line numberDiff line numberDiff line change
@@ -154,21 +154,6 @@ object root:
154154
override def eql(that: Annotation) = that match
155155
case Annot(kind) => this.kind eq kind
156156
case _ => false
157-
158-
/** Special treatment of `SubstBindingMaps` which can change the binder of
159-
* Result instances
160-
*/
161-
override def mapWith(tm: TypeMap)(using Context) = kind match
162-
case Kind.Result(binder) => tm match
163-
case tm: Substituters.SubstBindingMap[MethodType] @unchecked if tm.from eq binder =>
164-
derivedAnnotation(tm.to)
165-
case tm: Substituters.SubstBindingsMap =>
166-
var i = 0
167-
while i < tm.from.length && (tm.from(i) ne binder) do i += 1
168-
if i < tm.from.length then derivedAnnotation(tm.to(i).asInstanceOf[MethodType])
169-
else this
170-
case _ => this
171-
case _ => this
172157
end Annot
173158

174159
def cap(using Context): TermRef = defn.captureRoot.termRef
@@ -222,8 +207,7 @@ object root:
222207
override def apply(t: Type) =
223208
if variance <= 0 then t
224209
else t match
225-
case t: CaptureRef if t.isCap =>
226-
Fresh(origin)
210+
case root(_) => assert(false)
227211
case t @ CapturingType(parent: TypeRef, _) if parent.symbol == defn.Caps_CapSet =>
228212
t
229213
case t @ CapturingType(_, _) =>
@@ -237,6 +221,11 @@ object root:
237221
case _ =>
238222
mapFollowingAliases(t)
239223

224+
override def mapCapability(c: CaptureRef, deep: Boolean): CaptureRef = c match
225+
case c: CaptureRef if c.isCap => Fresh(origin)
226+
case root(_) => c
227+
case _ => super.mapCapability(c, deep)
228+
240229
override def fuse(next: BiTypeMap)(using Context) = next match
241230
case next: Inverse => assert(false); Some(IdentityTypeMap)
242231
case _ => None
@@ -245,13 +234,14 @@ object root:
245234

246235
class Inverse extends BiTypeMap, FollowAliasesMap:
247236
def apply(t: Type): Type = t match
248-
case t @ Fresh(_) => cap
237+
case root(_) => assert(false)
249238
case t @ CapturingType(_, refs) => mapOver(t)
250239
case _ => mapFollowingAliases(t)
251240

252-
override def fuse(next: BiTypeMap)(using Context) = next match
253-
case next: CapToFresh => assert(false); Some(IdentityTypeMap)
254-
case _ => None
241+
override def mapCapability(c: CaptureRef, deep: Boolean): CaptureRef = c match
242+
case c @ Fresh(_) => cap
243+
case root(_) => c
244+
case _ => super.mapCapability(c, deep)
255245

256246
def inverse = thisMap
257247
override def toString = thisMap.toString + ".inverse"
@@ -283,9 +273,7 @@ object root:
283273
var localBinders: SimpleIdentitySet[MethodType] = SimpleIdentitySet.empty
284274

285275
def apply(t: Type): Type = t match
286-
case t @ Result(binder) =>
287-
if localBinders.contains(binder) then t // keep bound references
288-
else seen.getOrElseUpdate(t.annot, Fresh(origin)) // map free references to Fresh()
276+
case root(_) => assert(false)
289277
case t: MethodType =>
290278
// skip parameters
291279
val saved = localBinders
@@ -298,6 +286,14 @@ object root:
298286
case _ =>
299287
mapOver(t)
300288

289+
override def mapCapability(c: CaptureRef, deep: Boolean) = c match
290+
case t @ Result(binder) =>
291+
if localBinders.contains(binder) then t // keep bound references
292+
else seen.getOrElseUpdate(t.annot, Fresh(origin)) // map free references to Fresh()
293+
case root(_) => c
294+
case _ => super.mapCapability(c, deep)
295+
end subst
296+
301297
subst(tp)
302298
end resultToFresh
303299

@@ -320,15 +316,7 @@ object root:
320316
private val seen = EqHashMap[CaptureRef, Result]()
321317

322318
def apply(t: Type) = t match
323-
case t: CaptureRef if t.isCapOrFresh =>
324-
if variance > 0 then
325-
seen.getOrElseUpdate(t, Result(mt))
326-
else
327-
if variance == 0 then
328-
fail(em"""$tp captures the root capability `cap` in invariant position.
329-
|This capability cannot be converted to an existential in the result type of a function.""")
330-
// we accept variance < 0, and leave the cap as it is
331-
super.mapOver(t)
319+
case root(_) => assert(false)
332320
case defn.FunctionNOf(args, res, contextual) if t.typeSymbol.name.isImpureFunction =>
333321
if variance > 0 then
334322
super.mapOver:
@@ -337,26 +325,48 @@ object root:
337325
else mapOver(t)
338326
case _ =>
339327
mapOver(t)
328+
329+
override def mapCapability(c: CaptureRef, deep: Boolean) = c match
330+
case c: CaptureRef if c.isCapOrFresh =>
331+
if variance > 0 then
332+
seen.getOrElseUpdate(c, Result(mt))
333+
else
334+
if variance == 0 then
335+
fail(em"""$tp captures the root capability `cap` in invariant position.
336+
|This capability cannot be converted to an existential in the result type of a function.""")
337+
// we accept variance < 0, and leave the cap as it is
338+
c
339+
case root(_) => c
340+
case _ =>
341+
super.mapCapability(c, deep)
342+
340343
//.showing(i"mapcap $t = $result")
341344
override def toString = "toVar"
342345

343346
object inverse extends BiTypeMap:
344347
def apply(t: Type) = t match
345-
case t @ Result(`mt`) =>
348+
case root(_) => assert(false)
349+
case _ => mapOver(t)
350+
def inverse = toVar.this
351+
override def toString = "toVar.inverse"
352+
353+
override def mapCapability(c: CaptureRef, deep: Boolean) = c match
354+
case c @ Result(`mt`) =>
346355
// do a reverse getOrElseUpdate on `seen` to produce the
347356
// `Fresh` assosicated with `t`
348357
val it = seen.iterator
349358
var ref: CaptureRef | Null = null
350359
while it.hasNext && ref == null do
351360
val (k, v) = it.next
352-
if v.annot eq t.annot then ref = k
361+
if v eq c then ref = k
353362
if ref == null then
354363
ref = Fresh(Origin.Unknown)
355-
seen(ref) = t
364+
seen(ref) = c
356365
ref
357-
case _ => mapOver(t)
358-
def inverse = toVar.this
359-
override def toString = "toVar.inverse"
366+
case root(_) => c
367+
case _ =>
368+
super.mapCapability(c, deep)
369+
end inverse
360370
end toVar
361371

362372
toVar(tp)

0 commit comments

Comments
 (0)