diff --git a/presentation-compiler/src/main/dotty/tools/pc/PcConvertToNamedLambdaParameters.scala b/presentation-compiler/src/main/dotty/tools/pc/PcConvertToNamedLambdaParameters.scala new file mode 100644 index 000000000000..2ca50107c36b --- /dev/null +++ b/presentation-compiler/src/main/dotty/tools/pc/PcConvertToNamedLambdaParameters.scala @@ -0,0 +1,153 @@ +package dotty.tools.pc + +import java.nio.file.Paths +import java.util as ju + +import scala.jdk.CollectionConverters.* +import scala.meta.pc.OffsetParams + +import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Flags +import dotty.tools.dotc.interactive.Interactive +import dotty.tools.dotc.interactive.InteractiveDriver +import dotty.tools.dotc.util.SourceFile +import dotty.tools.dotc.util.SourcePosition +import org.eclipse.lsp4j as l +import dotty.tools.pc.utils.InteractiveEnrichments.* +import dotty.tools.pc.utils.TermNameInference.* + +/** + * Facilitates the code action that converts a wildcard lambda to a lambda with named parameters + * e.g. + * + * List(1, 2).map(<<_>> + 1) => List(1, 2).map(i => i + 1) + */ +final class PcConvertToNamedLambdaParameters( + driver: InteractiveDriver, + params: OffsetParams +): + import PcConvertToNamedLambdaParameters._ + + def convertToNamedLambdaParameters: ju.List[l.TextEdit] = { + val uri = params.uri + val filePath = Paths.get(uri) + driver.run( + uri, + SourceFile.virtual(filePath.toString, params.text), + ) + given newctx: Context = driver.localContext(params) + val pos = driver.sourcePosition(params) + val trees = driver.openedTrees(uri) + val treeList = Interactive.pathTo(trees, pos) + // Extractor for a lambda function (needs context, so has to be defined here) + val LambdaExtractor = Lambda(using newctx) + // select the most inner wildcard lambda + val firstLambda = treeList.collectFirst { + case LambdaExtractor(params, rhsFn) if params.forall(isWildcardParam) => + params -> rhsFn + } + + firstLambda match { + case Some((params, lambda)) => + // avoid names that are either defined or referenced in the lambda + val namesToAvoid = allDefAndRefNamesInTree(lambda) + // compute parameter names based on the type of the parameter + val computedParamNames: List[String] = + params.foldLeft(List.empty[String]) { (acc, param) => + val name = singleLetterNameStream(param.tpe.typeSymbol.name.toString()) + .find(n => !namesToAvoid.contains(n) && !acc.contains(n)) + acc ++ name.toList + } + if computedParamNames.size == params.size then + val paramReferenceEdits = params.zip(computedParamNames).flatMap { (param, paramName) => + val paramReferencePosition = findParamReferencePosition(param, lambda) + paramReferencePosition.toList.map { pos => + val position = pos.toLsp + val range = new l.Range( + position.getStart(), + position.getEnd() + ) + new l.TextEdit(range, paramName) + } + } + val paramNamesStr = computedParamNames.mkString(", ") + val paramDefsStr = + if params.size == 1 then paramNamesStr + else s"($paramNamesStr)" + val defRange = new l.Range( + lambda.sourcePos.toLsp.getStart(), + lambda.sourcePos.toLsp.getStart() + ) + val paramDefinitionEdits = List( + new l.TextEdit(defRange, s"$paramDefsStr => ") + ) + (paramDefinitionEdits ++ paramReferenceEdits).asJava + else + List.empty.asJava + case _ => + List.empty.asJava + } + } + +end PcConvertToNamedLambdaParameters + +object PcConvertToNamedLambdaParameters: + val codeActionId = "ConvertToNamedLambdaParameters" + + class Lambda(using Context): + def unapply(tree: tpd.Block): Option[(List[tpd.ValDef], tpd.Tree)] = tree match { + case tpd.Block((ddef @ tpd.DefDef(_, tpd.ValDefs(params) :: Nil, _, body: tpd.Tree)) :: Nil, tpd.Closure(_, meth, _)) + if ddef.symbol == meth.symbol => + params match { + case List(param) => + // lambdas with multiple wildcard parameters are represented as a single parameter function and a block with wildcard valdefs + Some(multipleUnderscoresFromBody(param, body)) + case _ => Some(params -> body) + } + case _ => None + } + end Lambda + + private def multipleUnderscoresFromBody(param: tpd.ValDef, body: tpd.Tree)(using Context): (List[tpd.ValDef], tpd.Tree) = body match { + case tpd.Block(defs, expr) if param.symbol.is(Flags.Synthetic) => + val wildcardParamDefs = defs.collect { + case valdef: tpd.ValDef if isWildcardParam(valdef) => valdef + } + if wildcardParamDefs.size == defs.size then wildcardParamDefs -> expr + else List(param) -> body + case _ => List(param) -> body + } + + def isWildcardParam(param: tpd.ValDef)(using Context): Boolean = + param.name.toString.startsWith("_$") && param.symbol.is(Flags.Synthetic) + + def findParamReferencePosition(param: tpd.ValDef, lambda: tpd.Tree)(using Context): Option[SourcePosition] = + var pos: Option[SourcePosition] = None + object FindParamReference extends tpd.TreeTraverser: + override def traverse(tree: tpd.Tree)(using Context): Unit = + tree match + case ident @ tpd.Ident(_) if ident.symbol == param.symbol => + pos = Some(tree.sourcePos) + case _ => + traverseChildren(tree) + FindParamReference.traverse(lambda) + pos + end findParamReferencePosition + + def allDefAndRefNamesInTree(tree: tpd.Tree)(using Context): List[String] = + object FindDefinitionsAndRefs extends tpd.TreeAccumulator[List[String]]: + override def apply(x: List[String], tree: tpd.Tree)(using Context): List[String] = + tree match + case tpd.DefDef(name, _, _, _) => + super.foldOver(x :+ name.toString, tree) + case tpd.ValDef(name, _, _) => + super.foldOver(x :+ name.toString, tree) + case tpd.Ident(name) => + super.foldOver(x :+ name.toString, tree) + case _ => + super.foldOver(x, tree) + FindDefinitionsAndRefs.foldOver(Nil, tree) + end allDefAndRefNamesInTree + +end PcConvertToNamedLambdaParameters diff --git a/presentation-compiler/src/main/dotty/tools/pc/ScalaPresentationCompiler.scala b/presentation-compiler/src/main/dotty/tools/pc/ScalaPresentationCompiler.scala index dc53525480c3..dde95b848135 100644 --- a/presentation-compiler/src/main/dotty/tools/pc/ScalaPresentationCompiler.scala +++ b/presentation-compiler/src/main/dotty/tools/pc/ScalaPresentationCompiler.scala @@ -61,7 +61,8 @@ case class ScalaPresentationCompiler( CodeActionId.ImplementAbstractMembers, CodeActionId.ExtractMethod, CodeActionId.InlineValue, - CodeActionId.InsertInferredType + CodeActionId.InsertInferredType, + PcConvertToNamedLambdaParameters.codeActionId ).asJava def this() = this("", None, Nil, Nil) @@ -82,26 +83,30 @@ case class ScalaPresentationCompiler( codeActionPayload: Optional[T] ): CompletableFuture[ju.List[TextEdit]] = (codeActionId, codeActionPayload.asScala) match - case ( - CodeActionId.ConvertToNamedArguments, - Some(argIndices: ju.List[_]) - ) => - val payload = - argIndices.asScala.collect { case i: Integer => i.toInt }.toSet - convertToNamedArguments(params, payload) - case (CodeActionId.ImplementAbstractMembers, _) => - implementAbstractMembers(params) - case (CodeActionId.InsertInferredType, _) => - insertInferredType(params) - case (CodeActionId.InlineValue, _) => - inlineValue(params) - case (CodeActionId.ExtractMethod, Some(extractionPos: OffsetParams)) => - params match { - case range: RangeParams => - extractMethod(range, extractionPos) - case _ => failedFuture(new IllegalArgumentException(s"Expected range parameters")) - } - case (id, _) => failedFuture(new IllegalArgumentException(s"Unsupported action id $id")) + case ( + CodeActionId.ConvertToNamedArguments, + Some(argIndices: ju.List[_]) + ) => + val payload = + argIndices.asScala.collect { case i: Integer => i.toInt }.toSet + convertToNamedArguments(params, payload) + case (CodeActionId.ImplementAbstractMembers, _) => + implementAbstractMembers(params) + case (CodeActionId.InsertInferredType, _) => + insertInferredType(params) + case (CodeActionId.InlineValue, _) => + inlineValue(params) + case (CodeActionId.ExtractMethod, Some(extractionPos: OffsetParams)) => + params match { + case range: RangeParams => + extractMethod(range, extractionPos) + case _ => failedFuture(new IllegalArgumentException(s"Expected range parameters")) + } + case (PcConvertToNamedLambdaParameters.codeActionId, _) => + compilerAccess.withNonInterruptableCompiler(List.empty[l.TextEdit].asJava, params.token) { + access => PcConvertToNamedLambdaParameters(access.compiler(), params).convertToNamedLambdaParameters + }(params.toQueryContext) + case (id, _) => failedFuture(new IllegalArgumentException(s"Unsupported action id $id")) private def failedFuture[T](e: Throwable): CompletableFuture[T] = val f = new CompletableFuture[T]() diff --git a/presentation-compiler/src/main/dotty/tools/pc/utils/TermNameInference.scala b/presentation-compiler/src/main/dotty/tools/pc/utils/TermNameInference.scala new file mode 100644 index 000000000000..a1a09968325c --- /dev/null +++ b/presentation-compiler/src/main/dotty/tools/pc/utils/TermNameInference.scala @@ -0,0 +1,56 @@ +package dotty.tools.pc.utils + +/** + * Helpers for generating variable names based on the desired types. + */ +object TermNameInference { + + /** Single character names for types. (`Int` => `i`, `i1`, `i2`, ...) */ + def singleLetterNameStream(typeName: String): LazyList[String] = { + sanitizeInput(typeName).fold(saneNamesStream) { typeName1 => + val firstCharStr = typeName1.headOption.getOrElse('x').toLower.toString + numberedStreamFromName(firstCharStr) + } + } + + /** Names only from upper case letters (`OnDemandSymbolIndex` => `odsi`, `odsi1`, `odsi2`, ...) */ + def shortNameStream(typeName: String): LazyList[String] = { + sanitizeInput(typeName).fold(saneNamesStream) { typeName1 => + val upperCases = typeName1.filter(_.isUpper).map(_.toLower) + val name = if (upperCases.isEmpty) typeName1 else upperCases + numberedStreamFromName(name) + } + } + + /** Names from lower case letters (`OnDemandSymbolIndex` => `onDemandSymbolIndex`, `onDemandSymbolIndex1`, ...) */ + def fullNameStream(typeName: String): LazyList[String] = { + sanitizeInput(typeName).fold(saneNamesStream) { typeName1 => + val withFirstLower = + typeName1.headOption.map(_.toLower).getOrElse('x').toString + typeName1.drop(1) + numberedStreamFromName(withFirstLower) + } + } + + /** A lazy list of names: a, b, ..., z, aa, ab, ..., az, ba, bb, ... */ + def saneNamesStream: LazyList[String] = { + val letters = ('a' to 'z').map(_.toString) + def computeNext(acc: String): String = { + if (acc.last == 'z') + computeNext(acc.init) + letters.head + else + acc.init + letters(letters.indexOf(acc.last) + 1) + } + def loop(acc: String): LazyList[String] = + acc #:: loop(computeNext(acc)) + loop("a") + } + + private def sanitizeInput(typeName: String): Option[String] = + val typeName1 = typeName.filter(_.isLetterOrDigit) + Option.when(typeName1.nonEmpty)(typeName1) + + private def numberedStreamFromName(name: String): LazyList[String] = { + val rest = LazyList.from(1).map(name + _) + name #:: rest + } +} diff --git a/presentation-compiler/test/dotty/tools/pc/tests/edit/ConvertToNamedLambdaParametersSuite.scala b/presentation-compiler/test/dotty/tools/pc/tests/edit/ConvertToNamedLambdaParametersSuite.scala new file mode 100644 index 000000000000..c1ab229fc216 --- /dev/null +++ b/presentation-compiler/test/dotty/tools/pc/tests/edit/ConvertToNamedLambdaParametersSuite.scala @@ -0,0 +1,187 @@ +package dotty.tools.pc.tests.edit + +import java.net.URI +import java.util.Optional + +import scala.meta.internal.jdk.CollectionConverters.* +import scala.meta.internal.metals.CompilerOffsetParams +import scala.meta.pc.CodeActionId +import scala.meta.pc.DisplayableException +import scala.language.unsafeNulls + +import dotty.tools.pc.base.BaseCodeActionSuite +import dotty.tools.pc.utils.TextEdits +import dotty.tools.pc.PcConvertToNamedLambdaParameters + +import org.eclipse.lsp4j as l +import org.junit.{Test, Ignore} + +class ConvertToNamedLambdaParametersSuite extends BaseCodeActionSuite: + + @Test def `Int => Int function in map` = + checkEdit( + """|object A{ + | val a = List(1, 2).map(<<_>> + 1) + |}""".stripMargin, + """|object A{ + | val a = List(1, 2).map(i => i + 1) + |}""".stripMargin + ) + + @Test def `Int => Int function in map with another wildcard lambda` = + checkEdit( + """|object A{ + | val a = List(1, 2).map(<<_>> + 1).map(_ + 1) + |}""".stripMargin, + """|object A{ + | val a = List(1, 2).map(i => i + 1).map(_ + 1) + |}""".stripMargin + ) + + @Test def `String => String function in map` = + checkEdit( + """|object A{ + | val a = List("a", "b").map(<<_>> + "c") + |}""".stripMargin, + """|object A{ + | val a = List("a", "b").map(s => s + "c") + |}""".stripMargin + ) + + @Test def `Person => Person function to custom method` = + checkEdit( + """|object A{ + | case class Person(name: String, age: Int) + | val bob = Person("Bob", 30) + | def m[A](f: Person => A): A = f(bob) + | m(_<<.>>name) + |} + |""".stripMargin, + """|object A{ + | case class Person(name: String, age: Int) + | val bob = Person("Bob", 30) + | def m[A](f: Person => A): A = f(bob) + | m(p => p.name) + |} + |""".stripMargin + ) + + @Test def `(String, Int) => Int function in map with multiple underscores` = + checkEdit( + """|object A{ + | val a = List(("a", 1), ("b", 2)).map(<<_>> + _) + |}""".stripMargin, + """|object A{ + | val a = List(("a", 1), ("b", 2)).map((s, i) => s + i) + |}""".stripMargin + ) + + @Test def `Int => Int function in map with multiple underscores` = + checkEdit( + """|object A{ + | val a = List(1, 2).map(x => x -> (x + 1)).map(<<_>> + _) + |}""".stripMargin, + """|object A{ + | val a = List(1, 2).map(x => x -> (x + 1)).map((i, i1) => i + i1) + |}""".stripMargin + ) + + @Test def `Int => Float function in nested lambda 1` = + checkEdit( + """|object A{ + | val a = List(1, 2).flatMap(List(_).flatMap(v => List(v, v + 1).map(<<_>>.toFloat))) + |}""".stripMargin, + """|object A{ + | val a = List(1, 2).flatMap(List(_).flatMap(v => List(v, v + 1).map(i => i.toFloat))) + |}""".stripMargin + ) + + @Test def `Int => Float function in nested lambda 2` = + checkEdit( + """|object A{ + | val a = List(1, 2).flatMap(List(<<_>>).flatMap(v => List(v, v + 1).map(_.toFloat))) + |}""".stripMargin, + """|object A{ + | val a = List(1, 2).flatMap(i => List(i).flatMap(v => List(v, v + 1).map(_.toFloat))) + |}""".stripMargin + ) + + @Test def `Int => Float function in nested lambda with shadowing` = + checkEdit( + """|object A{ + | val a = List(1, 2).flatMap(List(<<_>>).flatMap(i => List(i, i + 1).map(_.toFloat))) + |}""".stripMargin, + """|object A{ + | val a = List(1, 2).flatMap(i1 => List(i1).flatMap(i => List(i, i + 1).map(_.toFloat))) + |}""".stripMargin + ) + + @Test def `(String, String, String, String, String, String, String) => String function in map` = + checkEdit( + """|object A{ + | val a = List( + | ("a", "b", "c", "d", "e", "f", "g"), + | ("h", "i", "j", "k", "l", "m", "n") + | ).map(_<< >>+ _ + _ + _ + _ + _ + _) + |}""".stripMargin, + """|object A{ + | val a = List( + | ("a", "b", "c", "d", "e", "f", "g"), + | ("h", "i", "j", "k", "l", "m", "n") + | ).map((s, s1, s2, s3, s4, s5, s6) => s + s1 + s2 + s3 + s4 + s5 + s6) + |}""".stripMargin + ) + + @Test def `Long => Long with match and wildcard pattern` = + checkEdit( + """|object A{ + | val a = List(1L, 2L).map(_ match { + | case 1L => 1L + | case _ => <<2L>> + | }) + |}""".stripMargin, + """|object A{ + | val a = List(1L, 2L).map(l => l match { + | case 1L => 1L + | case _ => 2L + | }) + |}""".stripMargin + ) + + @Ignore + @Test def `Int => Int eta-expansion in map` = + checkEdit( + """|object A{ + | def f(x: Int): Int = x + 1 + | val a = List(1, 2).map(<>) + |}""".stripMargin, + """|object A{ + | def f(x: Int): Int = x + 1 + | val a = List(1, 2).map(i => f(i)) + |}""".stripMargin + ) + + def checkEdit( + original: String, + expected: String, + compat: Map[String, String] = Map.empty + ): Unit = + val edits = convertToNamedLambdaParameters(original) + val (code, _, _) = params(original) + val obtained = TextEdits.applyEdits(code, edits) + assertNoDiff(expected, obtained) + + def convertToNamedLambdaParameters( + original: String, + filename: String = "file:/A.scala" + ): List[l.TextEdit] = { + val (code, _, offset) = params(original) + val result = presentationCompiler + .codeAction( + CompilerOffsetParams(URI.create(filename), code, offset, cancelToken), + PcConvertToNamedLambdaParameters.codeActionId, + Optional.empty() + ) + .get() + result.asScala.toList + }