diff --git a/presentation-compiler/src/main/dotty/tools/pc/SelectionRangeProvider.scala b/presentation-compiler/src/main/dotty/tools/pc/SelectionRangeProvider.scala index 7973f4103ff6..42760f38d9af 100644 --- a/presentation-compiler/src/main/dotty/tools/pc/SelectionRangeProvider.scala +++ b/presentation-compiler/src/main/dotty/tools/pc/SelectionRangeProvider.scala @@ -6,7 +6,8 @@ import java.util as ju import scala.jdk.CollectionConverters._ import scala.meta.pc.OffsetParams -import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.ast.untpd.* +import dotty.tools.dotc.ast.NavigateAST import dotty.tools.dotc.core.Contexts.Context import dotty.tools.dotc.interactive.Interactive import dotty.tools.dotc.interactive.InteractiveDriver @@ -23,10 +24,7 @@ import org.eclipse.lsp4j.SelectionRange * @param compiler Metals Global presentation compiler wrapper. * @param params offset params converted from the selectionRange params. */ -class SelectionRangeProvider( - driver: InteractiveDriver, - params: ju.List[OffsetParams] -): +class SelectionRangeProvider(driver: InteractiveDriver, params: ju.List[OffsetParams]): /** * Get the seletion ranges for the provider params @@ -44,10 +42,13 @@ class SelectionRangeProvider( val source = SourceFile.virtual(filePath.toString, text) driver.run(uri, source) val pos = driver.sourcePosition(param) - val path = - Interactive.pathTo(driver.openedTrees(uri), pos)(using ctx) + val unit = driver.compilationUnits(uri) - val bareRanges = path + val untpdPath: List[Tree] = NavigateAST + .pathTo(pos.span, List(unit.untpdTree), true).collect: + case untpdTree: Tree => untpdTree + + val bareRanges = untpdPath .flatMap(selectionRangesFromTree(pos)) val comments = @@ -78,31 +79,31 @@ class SelectionRangeProvider( end selectionRange /** Given a tree, create a seq of [[SelectionRange]]s corresponding to that tree. */ - private def selectionRangesFromTree(pos: SourcePosition)(tree: tpd.Tree)(using Context) = + private def selectionRangesFromTree(pos: SourcePosition)(tree: Tree)(using Context) = def toSelectionRange(srcPos: SourcePosition) = val selectionRange = new SelectionRange() selectionRange.setRange(srcPos.toLsp) selectionRange - val treeSelectionRange = toSelectionRange(tree.sourcePos) + val treeSelectionRange = Seq(toSelectionRange(tree.sourcePos)) + + def allArgsSelectionRange(args: List[Tree]): Option[SelectionRange] = + args match + case Nil => None + case list => + val srcPos = list.head.sourcePos + val lastSpan = list.last.span + val allArgsSrcPos = SourcePosition(srcPos.source, srcPos.span union lastSpan, srcPos.outer) + if allArgsSrcPos.contains(pos) then Some(toSelectionRange(allArgsSrcPos)) + else None tree match - case tpd.DefDef(name, paramss, tpt, rhs) => - // If source position is within a parameter list, add a selection range covering that whole list. - val selectedParams = - paramss - .iterator - .flatMap: // parameter list to a sourcePosition covering the whole list - case Seq(param) => Some(param.sourcePos) - case params @ Seq(head, tail*) => - val srcPos = head.sourcePos - val lastSpan = tail.last.span - Some(SourcePosition(srcPos.source, srcPos.span union lastSpan, srcPos.outer)) - case Seq() => None - .find(_.contains(pos)) - .map(toSelectionRange) - selectedParams ++ Seq(treeSelectionRange) - case _ => Seq(treeSelectionRange) + case DefDef(_, paramss, _, _) => paramss.flatMap(allArgsSelectionRange) ++ treeSelectionRange + case Apply(_, args) => allArgsSelectionRange(args) ++ treeSelectionRange + case TypeApply(_, args) => allArgsSelectionRange(args) ++ treeSelectionRange + case UnApply(_, _, pattern) => allArgsSelectionRange(pattern) ++ treeSelectionRange + case Function(args, body) => allArgsSelectionRange(args) ++ treeSelectionRange + case _ => treeSelectionRange private def setParent( child: SelectionRange, diff --git a/presentation-compiler/test/dotty/tools/pc/tests/SelectionRangeSuite.scala b/presentation-compiler/test/dotty/tools/pc/tests/SelectionRangeSuite.scala index 143d998a0ec1..578c8e8bede4 100644 --- a/presentation-compiler/test/dotty/tools/pc/tests/SelectionRangeSuite.scala +++ b/presentation-compiler/test/dotty/tools/pc/tests/SelectionRangeSuite.scala @@ -75,6 +75,12 @@ class SelectionRangeSuite extends BaseSelectionRangeSuite: | b <- Some(2) | } yield a + b |}""".stripMargin, + """|object Main extends App { + | val total = for { + | >>region>>a <- Some(1)<>region>>for { | a <- Some(1) @@ -102,7 +108,7 @@ class SelectionRangeSuite extends BaseSelectionRangeSuite: ) ) - @Test def `function params` = + @Test def `function-params-1` = check( """|object Main extends App { | def func(a@@: Int, b: Int) = @@ -124,6 +130,32 @@ class SelectionRangeSuite extends BaseSelectionRangeSuite: ) ) + @Test def `function-params-2` = + check( + """|object Main extends App { + | val func = (a@@: Int, b: Int) => + | a + b + |}""".stripMargin, + List[String]( + """|object Main extends App { + | val func = (>>region>>a: Int< + | a + b + |}""".stripMargin, + """|object Main extends App { + | val func = (>>region>>a: Int, b: Int< + | a + b + |}""".stripMargin, + """|object Main extends App { + | val func = >>region>>(a: Int, b: Int) => + | a + b<>region>>val func = (a: Int, b: Int) => + | a + b<>region>>def foo[Type <: T1, B](hi: Int, b: Int, c:Int) = ???<>region>>56<>region>>34 + 56<>region>>(34 + 56)<>region>>12 * (34 + 56)< ???", + List( + "val hello = (aaa: Int, >>region>>bbb: Int< ???", + "val hello = (>>region>>aaa: Int, bbb: Int, ccc: Int< ???", + "val hello = >>region>>(aaa: Int, bbb: Int, ccc: Int) => ???<>region>>val hello = (aaa: Int, bbb: Int, ccc: Int) => ???<>region>>bbb: Int<>region>>aaa: Int, bbb: Int, ccc: Int<>region>>def hello(aaa: Int, bbb: Int, ccc: Int) = ???<>region>>222<>region>>111, 222, 333<>region>>List(111, 222, 333)<>region>>def hello = List(111, 222, 333)<>region>>Int<>region>>String, Int<>region>>Map[String, Int]<>region>>Map[String, Int]()<>region>>def hello = Map[String, Int]()<>region>>bbb<>region>>aaa, bbb, ccc<>region>>List(aaa, bbb, ccc)<>region>>val List(aaa, bbb, ccc) = List(111, 222, 333)<>region>>222<>region>>List(222)<>region>>def hello = List(222)<