Skip to content

Commit 5f33083

Browse files
committed
Find extension copy in companion
1 parent cf75c7f commit 5f33083

File tree

3 files changed

+138
-51
lines changed

3 files changed

+138
-51
lines changed

build.sbt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ excludeLintKeys in Global ++= Set(ideSkipProject)
1414
val commonSettings = commonSmlBuildSettings ++ ossPublishSettings ++ Seq(
1515
organization := "com.softwaremill.quicklens",
1616
updateDocs := UpdateVersionInDocs(sLog.value, organization.value, version.value, List(file("README.md"))),
17-
scalacOptions ++= Seq("-deprecation", "-feature", "-unchecked"), // useful for debugging macros: "-Ycheck:all"
17+
scalacOptions ++= Seq("-deprecation", "-feature", "-unchecked"), // useful for debugging macros: "-Ycheck:all", "-Xcheck-macros"
1818
ideSkipProject := (scalaVersion.value != scalaIdeaVersion)
1919
)
2020

quicklens/src/main/scala-3/com/softwaremill/quicklens/QuicklensMacros.scala

Lines changed: 86 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,8 @@ object QuicklensMacros {
133133
/** Method call with one type parameter and using clause */
134134
case a @ Apply(TypeApply(Apply(TypeApply(Ident(s), _), idents), typeTrees), List(givn)) if methodSupported(s) =>
135135
idents.flatMap(toPath(_, focus)) :+ PathSymbol.FunctionDelegate(s, givn, typeTrees.last, List.empty)
136+
case Apply(Ident(ident), Seq(deep)) => // this is an extension method, which is called e.g. as x(_$1)
137+
toPath(deep, focus) :+ PathSymbol.Field(ident)
136138
/** Field access */
137139
case Apply(deep, idents) =>
138140
toPath(deep, focus) ++ idents.flatMap(toPath(_, focus))
@@ -179,21 +181,73 @@ object QuicklensMacros {
179181
case lst => report.errorAndAbort(multipleMatchingMethods(sym.name, name, lst))
180182
}
181183

182-
def methodSymbolByNameAndArgsOrError(sym: Symbol, name: String, argsMap: Map[String, Term]): Symbol = {
184+
def filterMethodsByNameAndArgs(allMethods: List[Symbol], argsMap: Map[String, Term]): Option[Symbol] = {
183185
val argNames = argsMap.keys
184-
sym.methodMember(name).filter{ msym =>
186+
allMethods.filter { msym =>
185187
// for copy, we filter out the methods that don't have the desired parameter names
186188
val paramNames = msym.paramSymss.flatten.filter(_.isTerm).map(_.name)
187189
argNames.forall(paramNames.contains)
188190
} match
189-
case List(m) => m
190-
case Nil => report.errorAndAbort(noSuchMember(sym.name, name))
191-
case lst @ (m :: _) =>
191+
case List(m) => Some(m)
192+
case Nil => None
193+
case lst@(m :: _) =>
192194
// if we have multiple matching copy methods, pick the synthetic one, if it exists, otherwise, pick any method
193195
val syntheticCopies = lst.filter(_.flags.is(Flags.Synthetic))
194196
syntheticCopies match
195-
case List(mSynth) => mSynth
196-
case _ => m
197+
case List(mSynth) => Some(mSynth)
198+
case _ => Some(m)
199+
}
200+
201+
def methodSymbolByNameAndArgs(sym: Symbol, name: String, argsMap: Map[String, Term]): Option[Symbol] = {
202+
val memberMethods = sym.methodMember(name)
203+
filterMethodsByNameAndArgs(memberMethods, argsMap)
204+
}
205+
206+
def callMethod(obj: Term, copy: Symbol, argsMap: List[Map[String, Term]], extension: Boolean = false) = {
207+
val objTpe = obj.tpe.widenAll
208+
val objSymbol = objTpe.matchingTypeSymbol
209+
210+
val typeParams = objTpe match {
211+
case AppliedType(_, typeParams) => Some(typeParams)
212+
case _ => None
213+
}
214+
val copyTree: DefDef = copy.tree.asInstanceOf[DefDef]
215+
val copyParams: List[(String, Option[Term])] = copyTree.termParamss.zip(argsMap)
216+
.map((params, args) => params.params.map(_.name).map(name => name -> args.get(name)))
217+
.flatten.toList
218+
219+
val args = copyParams.zipWithIndex.map { case ((n, v), _i) =>
220+
val i = _i + 1
221+
def defaultMethod =
222+
val methodSymbol = methodSymbolByNameOrError(objSymbol, copy.name + "$default$" + i.toString)
223+
// default values in extensions are obtained by calling a method receiving the extension parameter
224+
val defaultMethodArgs = argsMap.dropRight(1).headOption.toList.flatMap(_.values)
225+
//println(s"defaultMethodArgs ${obj.show} ${methodSymbol.name} $defaultMethodArgs")
226+
if defaultMethodArgs.nonEmpty then
227+
Apply(Select(obj, methodSymbol), defaultMethodArgs)
228+
else
229+
// note: this is not always correct, -Xcheck-macros shows errors here
230+
// sometimes we should call a method with empry parameter list instead
231+
obj.select(methodSymbol)
232+
233+
// for extension methods, might need sth more like this: (or probably some weird implicit conversion)
234+
// val defaultGetter = obj.select(symbolMethodByNameOrError(objSymbol, n))
235+
n -> v.getOrElse(defaultMethod)
236+
}.toMap
237+
238+
val argLists = copyTree.termParamss.take(argsMap.size).map(list => list.params.map(p => args(p.name)))
239+
240+
if copyTree.termParamss.drop(argLists.size).exists(_.params.exists(!_.symbol.flags.is(Flags.Implicit))) then
241+
report.errorAndAbort(
242+
s"Implementation limitation: Only the first parameter list of the modified case classes can be non-implicit. ${copyTree.termParamss.drop(1)}"
243+
)
244+
245+
val applyOn = typeParams match {
246+
// if the object's type is parametrised, we need to call .copy with the same type parameters
247+
case Some(typeParams) => TypeApply(Select(obj, copy), typeParams.map(Inferred(_)))
248+
case _ => Select(obj, copy)
249+
}
250+
argLists.foldLeft(applyOn)((applied, list) => Apply(applied, list))
197251
}
198252

199253
def termMethodByNameUnsafe(term: Term, name: String): Symbol = {
@@ -210,8 +264,19 @@ object QuicklensMacros {
210264
(sym.flags.is(Flags.Sealed) && (sym.flags.is(Flags.Trait) || sym.flags.is(Flags.Abstract)))
211265
}
212266

267+
def findExtensionMethod(using Quotes)(sym: Symbol, methodName: String): List[Symbol] = {
268+
// TODO: can we check parameter types somehow?
269+
def isExtensionMethod(sym: Symbol): Boolean = sym.isDefDef && sym.paramSymss.headOption.exists(_.sizeIs == 1)
270+
271+
// TODO: try to search in symbol parent object as well
272+
val symbols = Seq(sym.companionModule).filter(_ != Symbol.noSymbol)
273+
274+
symbols.flatMap(_.declaredMethods).filter(sym => sym.name == methodName).filter(isExtensionMethod).toList
275+
}
276+
213277
def isProductLike(sym: Symbol): Boolean = {
214-
sym.methodMember("copy").size >= 1
278+
// just assume true - we can always fail if there is no copy
279+
sym.methodMember("copy").nonEmpty || findExtensionMethod(sym, "copy").nonEmpty
215280
}
216281

217282
def caseClassCopy(
@@ -248,6 +313,7 @@ object QuicklensMacros {
248313
}
249314

250315
val elseThrow = '{ throw new IllegalStateException() }.asTerm
316+
251317
ifThens.foldRight(elseThrow) { case ((ifCond, ifThen), ifElse) =>
252318
If(ifCond, ifThen, ifElse)
253319
}
@@ -260,36 +326,18 @@ object QuicklensMacros {
260326
val namedArg = NamedArg(field.name, resTerm)
261327
field.name -> namedArg
262328
}.toMap
263-
val copy = methodSymbolByNameAndArgsOrError(objSymbol, "copy", argsMap)
264-
265-
val typeParams = objTpe match {
266-
case AppliedType(_, typeParams) => Some(typeParams)
267-
case _ => None
268-
}
269-
val copyTree: DefDef = copy.tree.asInstanceOf[DefDef]
270-
val copyParamNames: List[String] = copyTree.termParamss.headOption.map(_.params).toList.flatten.map(_.name)
271-
272-
val args = copyParamNames.zipWithIndex.map { (n, _i) =>
273-
val i = _i + 1
274-
val defaultMethod = obj.select(methodSymbolByNameOrError(objSymbol, "copy$default$" + i.toString))
275-
// for extension methods, might need sth more like this: (or probably some weird implicit conversion)
276-
// val defaultGetter = obj.select(symbolMethodByNameOrError(objSymbol, n))
277-
argsMap.getOrElse(
278-
n,
279-
defaultMethod
280-
)
281-
}.toList
282-
283-
if copyTree.termParamss.drop(1).exists(_.params.exists(!_.symbol.flags.is(Flags.Implicit))) then
284-
report.errorAndAbort(
285-
s"Implementation limitation: Only the first parameter list of the modified case classes can be non-implicit."
286-
)
287-
288-
typeParams match {
289-
// if the object's type is parametrised, we need to call .copy with the same type parameters
290-
case Some(typeParams) => Apply(TypeApply(Select(obj, copy), typeParams.map(Inferred(_))), args)
291-
case _ => Apply(Select(obj, copy), args)
292-
}
329+
methodSymbolByNameAndArgs(objSymbol, "copy", argsMap) match
330+
case Some(copy) =>
331+
callMethod(obj, copy, List(argsMap))
332+
case None =>
333+
val objCompanion = objSymbol.companionModule
334+
methodSymbolByNameAndArgs(objCompanion, "copy", argsMap) match
335+
case Some(copy) =>
336+
// now try to call the extension as a method, assume the object is its first parameter
337+
val firstParam = copy.paramSymss.headOption.map(_.headOption).flatten
338+
val argsWithObj = List(firstParam.map(name => name.name -> obj).toMap, argsMap)
339+
callMethod(Ref(objCompanion), copy, argsWithObj, extension = true)
340+
case None => report.errorAndAbort(noSuchMember(objSymbol.name, "copy"))
293341
} else
294342
report.errorAndAbort(s"Unsupported source object: must be a case class or sealed trait, but got: $objSymbol of type ${objTpe.show} (${obj.show})")
295343
}

quicklens/src/test/scala-3/com/softwaremill/quicklens/test/ExtensionCopyTest.scala

Lines changed: 51 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,40 +11,79 @@ object ExtensionCopyTest {
1111

1212
object Vec {
1313
def apply(x: Double, y: Double): Vec = V(x, y)
14-
}
1514

16-
extension (v: Vec) {
17-
def x: Double = v.x
18-
def y: Double = v.y
19-
def copy(x: Double = v.x, y: Double = v.y): Vec = V(x, y)
15+
extension (v: Vec) {
16+
def x: Double = v.x
17+
def y: Double = v.y
18+
def copy(x: Double = v.x, y: Double = v.y): Vec = V(x, y)
19+
}
2020
}
2121
}
2222

2323
class ExtensionCopyTest extends AnyFlatSpec with Matchers {
24+
/*
25+
it should "modify a simple class with an extension copy method" in {
26+
class VecSimple(xp: Double, yp: Double) {
27+
val xMember = xp
28+
val yMember = yp
29+
}
30+
31+
object VecSimple {
32+
def apply(x: Double, y: Double): VecSimple = new VecSimple(x, y)
33+
}
34+
35+
extension (v: VecSimple) {
36+
def copy(x: Double = v.xMember, y: Double = v.yMember): VecSimple = new VecSimple(x, y)
37+
}
38+
val a = VecSimple(1, 2)
39+
val b = a.modify(_.xMember).using(_ + 1)
40+
println(b)
41+
}
42+
*/
43+
44+
it should "modify a simple class with an extension copy method in companion" in {
45+
class VecCompanion(xp: Double, yp: Double) {
46+
val x = xp
47+
val y = yp
48+
}
49+
50+
object VecCompanion {
51+
def apply(x: Double, y: Double): VecCompanion = new VecCompanion(x, y)
52+
extension (v: VecCompanion) {
53+
def copy(x: Double = v.x, y: Double = v.y): VecCompanion = new VecCompanion(x, y)
54+
}
55+
}
56+
57+
val a = VecCompanion(1, 2)
58+
val b = a.modify(_.x).using(_ + 1)
59+
println(b)
60+
}
61+
/*
62+
2463
it should "modify a class with an extension copy method" in {
2564
case class V(x: Double, y: Double)
2665
27-
class Vec(val v: V)
66+
class VecClass(val v: V)
2867
29-
object Vec {
30-
def apply(x: Double, y: Double): Vec = new Vec(V(x, y))
68+
object VecClass {
69+
def apply(x: Double, y: Double): VecClass = new VecClass(V(x, y))
3170
}
3271
33-
extension (v: Vec) {
72+
extension (v: VecClass) {
3473
def x: Double = v.v.x
3574
def y: Double = v.v.y
36-
def copy(x: Double = v.x, y: Double = v.y): Vec = new Vec(V(x, y))
75+
def copy(x: Double = v.x, y: Double = v.y): VecClass = new VecClass(V(x, y))
3776
}
38-
val a = Vec(1, 2)
77+
val a = VecClass(1, 2)
3978
val b = a.modify(_.x).using(_ + 1)
4079
println(b)
4180
}
42-
4381
it should "modify an opaque type with an extension copy method" in {
4482
import ExtensionCopyTest.*
4583
4684
val a = Vec(1, 2)
4785
val b = a.modify(_.x).using(_ + 1)
4886
println(b)
4987
}
88+
*/
5089
}

0 commit comments

Comments
 (0)