diff --git a/metals/src/main/scala/scala/meta/internal/metals/MetalsEnrichments.scala b/metals/src/main/scala/scala/meta/internal/metals/MetalsEnrichments.scala index 0fdcd19e00e..c5c9c13018f 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/MetalsEnrichments.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/MetalsEnrichments.scala @@ -1387,6 +1387,9 @@ object MetalsEnrichments new l.TextDocumentIdentifier(location.getUri()), location.getRange().getStart(), ) + + def toReferenceParams(includeDeclaration: Boolean): l.ReferenceParams = + toTextDocumentPositionParams.toReferenceParams(includeDeclaration) } implicit class XtensionTextDocumentPositionParams( @@ -1407,6 +1410,16 @@ object MetalsEnrichments ).toOption.flatten .getOrElse(input.text) } + + def toReferenceParams(includeDeclaration: Boolean): l.ReferenceParams = { + val referenceParams = new l.ReferenceParams() + referenceParams.setPosition(params.getPosition()) + referenceParams.setTextDocument(params.getTextDocument()) + val context = new l.ReferenceContext() + context.setIncludeDeclaration(includeDeclaration) + referenceParams.setContext(context) + referenceParams + } } implicit class XtensionDebugSessionParams(params: b.DebugSessionParams) { diff --git a/metals/src/main/scala/scala/meta/internal/metals/MetalsLspService.scala b/metals/src/main/scala/scala/meta/internal/metals/MetalsLspService.scala index d9dbb60f7dd..b55730aa224 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/MetalsLspService.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/MetalsLspService.scala @@ -459,6 +459,7 @@ abstract class MetalsLspService( buildTargets, compilers, scalaVersionSelector, + () => implementationProvider, ) protected val packageProvider: PackageProvider = diff --git a/metals/src/main/scala/scala/meta/internal/metals/ReferenceProvider.scala b/metals/src/main/scala/scala/meta/internal/metals/ReferenceProvider.scala index 97227d36f37..72f6fd78945 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/ReferenceProvider.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/ReferenceProvider.scala @@ -14,6 +14,8 @@ import scala.util.control.NonFatal import scala.util.matching.Regex import scala.meta.Importee +import scala.meta.internal.implementation.ImplementationProvider +import scala.meta.internal.implementation.SuperMethodProvider import scala.meta.internal.metals.MetalsEnrichments._ import scala.meta.internal.metals.ResolvedSymbolOccurrence import scala.meta.internal.mtags.DefinitionAlternatives.GlobalSymbol @@ -37,6 +39,7 @@ import com.google.common.hash.BloomFilter import com.google.common.hash.Funnels import org.eclipse.lsp4j.Location import org.eclipse.lsp4j.ReferenceParams +import org.eclipse.lsp4j.TextDocumentPositionParams final class ReferenceProvider( workspace: AbsolutePath, @@ -47,6 +50,7 @@ final class ReferenceProvider( buildTargets: BuildTargets, compilers: Compilers, scalaVersionSelector: ScalaVersionSelector, + implementationProvider: () => ImplementationProvider, )(implicit ec: ExecutionContext) extends SemanticdbFeatureProvider with CompletionItemPriority { @@ -176,6 +180,7 @@ final class ReferenceProvider( params: ReferenceParams, findRealRange: AdjustRange = noAdjustRange, includeSynthetics: Synthetic => Boolean = _ => true, + includeImplementations: Boolean = true, )(implicit report: ReportContext): Future[List[ReferencesResult]] = { val source = params.getTextDocument.getUri.toAbsolutePath val textDoc = semanticdbs().textDocument(source) @@ -250,21 +255,53 @@ final class ReferenceProvider( } } } + val includeDeclaration = params.getContext().isIncludeDeclaration() val pcResult = pcReferences( source, results.flatMap(_.occurrence).map(_.symbol), - params.getContext().isIncludeDeclaration(), + includeDeclaration, findRealRange, ) + val implementationRefs: Future[List[ReferencesResult]] = + if (includeImplementations) { + val hasMethodSymbol = results.exists( + _.occurrence.exists(_.symbol.desc.isMethod) + ) + if (hasMethodSymbol) { + val textParams = new TextDocumentPositionParams( + params.getTextDocument(), + params.getPosition(), + ) + implementationProvider().implementations(textParams).flatMap { + implLocs => + val implRefs = implLocs.map { implLoc => + val implParams = + implLoc.toReferenceParams(includeDeclaration) + references( + implParams, + findRealRange, + includeSynthetics, + includeImplementations = false, + ) + } + Future.sequence(implRefs).map(_.flatten) + } + } else { + Future.successful(Nil) + } + } else { + Future.successful(Nil) + } + Future - .sequence(List(semanticdbResult, pcResult)) + .sequence(List(semanticdbResult, pcResult, implementationRefs)) .map( _.flatten .groupBy(_.symbol) .collect { case (symbol, refs) => - ReferencesResult(symbol, refs.flatMap(_.locations)) + ReferencesResult(symbol, refs.flatMap(_.locations).distinct) } .toList ) @@ -379,6 +416,12 @@ final class ReferenceProvider( alternatives.isContructorParam(info) }.toSet + val overriddenSymbols = definitionDoc.symbols + .find(_.symbol == symbol) + .map(SuperMethodProvider.getSuperMethodHierarchy) + .getOrElse(Nil) + .toSet + val nonSyntheticSymbols = for { occ <- definitionDoc.occurrences if isCandidate(occ.symbol) || occ.symbol == symbol @@ -397,11 +440,11 @@ final class ReferenceProvider( } yield info.symbol if (defPath.isJava) - isCandidate + isCandidate ++ overriddenSymbols else if (isSyntheticSymbol) - isCandidate ++ additionalAlternativesForSynthetic + isCandidate ++ additionalAlternativesForSynthetic ++ overriddenSymbols else - isCandidate + isCandidate ++ overriddenSymbols case None => Set.empty } } diff --git a/metals/src/main/scala/scala/meta/internal/rename/RenameProvider.scala b/metals/src/main/scala/scala/meta/internal/rename/RenameProvider.scala index b58898c54df..d10bb397c83 100644 --- a/metals/src/main/scala/scala/meta/internal/rename/RenameProvider.scala +++ b/metals/src/main/scala/scala/meta/internal/rename/RenameProvider.scala @@ -40,8 +40,6 @@ import org.eclipse.lsp4j.Location import org.eclipse.lsp4j.MessageParams import org.eclipse.lsp4j.MessageType import org.eclipse.lsp4j.Position -import org.eclipse.lsp4j.ReferenceContext -import org.eclipse.lsp4j.ReferenceParams import org.eclipse.lsp4j.RenameFile import org.eclipse.lsp4j.RenameParams import org.eclipse.lsp4j.ResourceOperation @@ -224,10 +222,8 @@ final class RenameProvider( * isJava - in Java we can include declarations safely and we * also need to include contructors. */ - toReferenceParams( - txtParams, - includeDeclaration = isJava, - ), + txtParams + .toReferenceParams(includeDeclaration = isJava), findRealRange = AdjustRange(findRealRange(newName)), includeSynthetic, ) @@ -456,7 +452,7 @@ final class RenameProvider( } yield { referenceProvider .references( - toReferenceParams(loc, includeDeclaration = false), + loc.toReferenceParams(includeDeclaration = false), findRealRange = AdjustRange(findRealRange(newName)), ) .map(_.flatMap(_.locations :+ loc)) @@ -507,7 +503,7 @@ final class RenameProvider( result <- { val result = for { implLoc <- implLocs - locParams = toReferenceParams(implLoc, includeDeclaration = true) + locParams = implLoc.toReferenceParams(includeDeclaration = true) } yield { referenceProvider .references( @@ -706,44 +702,6 @@ final class RenameProvider( } } - private def toReferenceParams( - textDoc: TextDocumentIdentifier, - pos: Position, - includeDeclaration: Boolean, - ): ReferenceParams = { - val referenceParams = new ReferenceParams() - referenceParams.setPosition(pos) - referenceParams.setTextDocument(textDoc) - val context = new ReferenceContext() - context.setIncludeDeclaration(includeDeclaration) - referenceParams.setContext(context) - referenceParams - } - - private def toReferenceParams( - location: Location, - includeDeclaration: Boolean, - ): ReferenceParams = { - val textDoc = new TextDocumentIdentifier() - textDoc.setUri(location.getUri()) - toReferenceParams( - textDoc, - location.getRange().getStart(), - includeDeclaration, - ) - } - - private def toReferenceParams( - params: TextDocumentPositionParams, - includeDeclaration: Boolean, - ): ReferenceParams = { - toReferenceParams( - params.getTextDocument(), - params.getPosition(), - includeDeclaration, - ) - } - private def toTextParams(location: Location): TextDocumentPositionParams = { new TextDocumentPositionParams( new TextDocumentIdentifier(location.getUri()), diff --git a/tests/unit/src/test/scala/tests/ReferenceLspSuite.scala b/tests/unit/src/test/scala/tests/ReferenceLspSuite.scala index 5c300595c63..4e3b4063728 100644 --- a/tests/unit/src/test/scala/tests/ReferenceLspSuite.scala +++ b/tests/unit/src/test/scala/tests/ReferenceLspSuite.scala @@ -559,6 +559,83 @@ class ReferenceLspSuite extends BaseRangesSuite("reference") { } yield () } + check( + "override-references", + """|/a/src/main/scala/a/Main.scala + |package a + | + |trait Animal { + | def <>(): Unit + |} + | + |class Dog extends Animal { + | override def <>(): Unit = println("Bark") + |} + | + |object Main { + | def test(a: Animal, d: Dog): Unit = { + | a.<>() + | d.<>() + | } + |} + |""".stripMargin, + ) + + check( + "find-references-includes-implementations", + """|/a/src/main/scala/a/Main.scala + |package a + | + |trait Processor { + | def <>(input: String): String + |} + | + |class StringProcessor extends Processor { + | override def <>(input: String): String = input.toUpperCase + |} + | + |class NumberProcessor extends Processor { + | override def <>(input: String): String = input.filter(_.isDigit) + |} + | + |object Main { + | def use(p: Processor): String = p.<>("test") + |} + |""".stripMargin, + ) + + check( + "diamond-inheritance-references", + """|/a/src/main/scala/a/Main.scala + |package a + | + |trait A { + | def <>(): Unit + |} + | + |trait B extends A { + | override def <>(): Unit = println("B") + |} + | + |trait C extends A { + | override def <>(): Unit = println("C") + |} + | + |class D extends B with C { + | override def <>(): Unit = println("D") + |} + | + |object Main { + | def test(a: A, b: B, c: C, d: D): Unit = { + | a.<>() + | b.<>() + | c.<>() + | d.<>() + | } + |} + |""".stripMargin, + ) + override def assertCheck( filename: String, edit: String,