Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 41 additions & 20 deletions compiler/src/dotty/tools/dotc/transform/HoistSuperArgs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ import collection.mutable
import ast.Trees.*
import core.NameKinds.SuperArgName

import core.Decorators.*

object HoistSuperArgs {
val name: String = "hoistSuperArgs"
val description: String = "hoist complex arguments of supercalls to enclosing scope"
Expand Down Expand Up @@ -59,14 +57,19 @@ class HoistSuperArgs extends MiniPhase with IdentityDenotTransformer { thisPhase
class Hoister(cls: Symbol)(using Context) {
val superArgDefs: mutable.ListBuffer[DefDef] = new mutable.ListBuffer

/** Check if symbol is a lifted by-name method (created by LiftToDefs). */
private def isLiftedByNameMethod(sym: Symbol): Boolean =
sym.is(Synthetic) && sym.is(Method) && sym.info.isInstanceOf[ExprType]

/** If argument is complex, hoist it out into its own method and refer to the
* method instead.
* @param arg The argument that might be hoisted
* @param cdef The definition of the constructor from which the call is made
* @param lifted Argument definitions that were lifted out in a call prefix
* @param arg The argument that might be hoisted
* @param cdef The definition of the constructor from which the call is made
* @param lifted Argument definitions that were lifted out in a call prefix
* @param inlinableMethods Map from lifted by-name method symbols to their bodies, for inlining
* @return The argument after possible hoisting
*/
private def hoistSuperArg(arg: Tree, cdef: DefDef, lifted: List[Symbol]): Tree = {
private def hoistSuperArg(arg: Tree, cdef: DefDef, lifted: List[Symbol], inlinableMethods: Map[Symbol, Tree] = Map.empty): Tree = {
val constr = cdef.symbol
lazy val origParams = // The parameters that can be accessed in the supercall
if (constr == cls.primaryConstructor)
Expand Down Expand Up @@ -151,6 +154,8 @@ class HoistSuperArgs extends MiniPhase with IdentityDenotTransformer { thisPhase
}
},
treeMap = {
case tree: Ident if inlinableMethods.contains(tree.symbol) =>
inlinableMethods(tree.symbol) // Inline references to lifted by-name methods
case tree: RefTree if needsRewire(tree.tpe) =>
cpy.Ident(tree)(tree.name).withType(tree.tpe)
case tree =>
Expand Down Expand Up @@ -181,26 +186,42 @@ class HoistSuperArgs extends MiniPhase with IdentityDenotTransformer { thisPhase
}

/** Hoist complex arguments in super call out of the class. */
def hoistSuperArgsFromCall(superCall: Tree, cdef: DefDef, lifted: mutable.ListBuffer[Symbol]): Tree = superCall match
def hoistSuperArgsFromCall(
superCall: Tree,
cdef: DefDef,
lifted: mutable.ListBuffer[Symbol],
inlinableMethods: mutable.Map[Symbol, Tree] = mutable.Map.empty
): Tree = superCall match
case Block(defs, expr) if !expr.symbol.owner.is(Scala2x) =>
// MO: The guard avoids the crash for #16351.
// It would be good to dig deeper, but I won't have the time myself to do it.
cpy.Block(superCall)(
stats = defs.mapconserve {
case vdef: ValDef =>
try cpy.ValDef(vdef)(rhs = hoistSuperArg(vdef.rhs, cdef, lifted.toList))
finally lifted += vdef.symbol
case ddef: DefDef =>
try cpy.DefDef(ddef)(rhs = hoistSuperArg(ddef.rhs, cdef, lifted.toList))
val processedStats = defs.mapconserve {
case vdef: ValDef =>
try cpy.ValDef(vdef)(rhs = hoistSuperArg(vdef.rhs, cdef, lifted.toList, inlinableMethods.toMap))
finally lifted += vdef.symbol
case ddef: DefDef =>
if isLiftedByNameMethod(ddef.symbol) then
// Store body for inlining, don't add to lifted buffer
inlinableMethods(ddef.symbol) = ddef.rhs
ddef // Keep DefDef temporarily, will be filtered out below
else
try cpy.DefDef(ddef)(rhs = hoistSuperArg(ddef.rhs, cdef, lifted.toList, inlinableMethods.toMap))
finally lifted += ddef.symbol
case stat =>
stat
},
expr = hoistSuperArgsFromCall(expr, cdef, lifted))
case stat =>
stat
}
// Filter out DefDefs that were inlined
val filteredStats = processedStats.filterNot {
case ddef: DefDef => inlinableMethods.contains(ddef.symbol)
case _ => false
}
cpy.Block(superCall)(
stats = filteredStats,
expr = hoistSuperArgsFromCall(expr, cdef, lifted, inlinableMethods))
case Apply(fn, args) =>
cpy.Apply(superCall)(
hoistSuperArgsFromCall(fn, cdef, lifted),
args.mapconserve(hoistSuperArg(_, cdef, lifted.toList)))
hoistSuperArgsFromCall(fn, cdef, lifted, inlinableMethods),
args.mapconserve(hoistSuperArg(_, cdef, lifted.toList, inlinableMethods.toMap)))
case _ =>
superCall

Expand Down
3 changes: 3 additions & 0 deletions tests/run/i24201.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
E1
1
3
14 changes: 14 additions & 0 deletions tests/run/i24201.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
abstract class Foo[T](defaultValue: => T, arg1: Int = 1, arg2: Int = 2):
def getDefault: T = defaultValue
def getArg1: Int = arg1
def getArg2: Int = arg2

enum Baz:
case E1, E2

object Baz extends Foo[Baz](Baz.E1, arg2 = 3)

@main def Test =
println(Baz.getDefault)
println(Baz.getArg1)
println(Baz.getArg2)
3 changes: 3 additions & 0 deletions tests/run/i24201b.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
MyValue
1
3
15 changes: 15 additions & 0 deletions tests/run/i24201b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
abstract class Foo[T](defaultValue: => T, arg1: Int = 1, arg2: Int = 2):
def getDefault: T = defaultValue
def getArg1: Int = arg1
def getArg2: Int = arg2

class MyValue:
override def toString = "MyValue"

object TheObject extends Foo[MyValue](TheObject.value, arg2 = 3):
val value = new MyValue

@main def Test =
println(TheObject.getDefault)
println(TheObject.getArg1)
println(TheObject.getArg2)
Loading