Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1387,6 +1387,9 @@ object MetalsEnrichments
new l.TextDocumentIdentifier(location.getUri()),
location.getRange().getStart(),
)

def toReferenceParams(includeDeclaration: Boolean): l.ReferenceParams =
toTextDocumentPositionParams.toReferenceParams(includeDeclaration)
}

implicit class XtensionTextDocumentPositionParams(
Expand All @@ -1407,6 +1410,16 @@ object MetalsEnrichments
).toOption.flatten
.getOrElse(input.text)
}

def toReferenceParams(includeDeclaration: Boolean): l.ReferenceParams = {
val referenceParams = new l.ReferenceParams()
referenceParams.setPosition(params.getPosition())
referenceParams.setTextDocument(params.getTextDocument())
val context = new l.ReferenceContext()
context.setIncludeDeclaration(includeDeclaration)
referenceParams.setContext(context)
referenceParams
}
}

implicit class XtensionDebugSessionParams(params: b.DebugSessionParams) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,7 @@ abstract class MetalsLspService(
buildTargets,
compilers,
scalaVersionSelector,
() => implementationProvider,
)

protected val packageProvider: PackageProvider =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import scala.util.control.NonFatal
import scala.util.matching.Regex

import scala.meta.Importee
import scala.meta.internal.implementation.ImplementationProvider
import scala.meta.internal.implementation.SuperMethodProvider
import scala.meta.internal.metals.MetalsEnrichments._
import scala.meta.internal.metals.ResolvedSymbolOccurrence
import scala.meta.internal.mtags.DefinitionAlternatives.GlobalSymbol
Expand All @@ -37,6 +39,7 @@ import com.google.common.hash.BloomFilter
import com.google.common.hash.Funnels
import org.eclipse.lsp4j.Location
import org.eclipse.lsp4j.ReferenceParams
import org.eclipse.lsp4j.TextDocumentPositionParams

final class ReferenceProvider(
workspace: AbsolutePath,
Expand All @@ -47,6 +50,7 @@ final class ReferenceProvider(
buildTargets: BuildTargets,
compilers: Compilers,
scalaVersionSelector: ScalaVersionSelector,
implementationProvider: () => ImplementationProvider,
)(implicit ec: ExecutionContext)
extends SemanticdbFeatureProvider
with CompletionItemPriority {
Expand Down Expand Up @@ -176,6 +180,7 @@ final class ReferenceProvider(
params: ReferenceParams,
findRealRange: AdjustRange = noAdjustRange,
includeSynthetics: Synthetic => Boolean = _ => true,
includeImplementations: Boolean = true,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we ever set it to false?

)(implicit report: ReportContext): Future[List[ReferencesResult]] = {
val source = params.getTextDocument.getUri.toAbsolutePath
val textDoc = semanticdbs().textDocument(source)
Expand Down Expand Up @@ -250,21 +255,53 @@ final class ReferenceProvider(
}
}
}
val includeDeclaration = params.getContext().isIncludeDeclaration()
val pcResult =
pcReferences(
source,
results.flatMap(_.occurrence).map(_.symbol),
params.getContext().isIncludeDeclaration(),
includeDeclaration,
findRealRange,
)

val implementationRefs: Future[List[ReferencesResult]] =
if (includeImplementations) {
val hasMethodSymbol = results.exists(
_.occurrence.exists(_.symbol.desc.isMethod)
)
if (hasMethodSymbol) {
val textParams = new TextDocumentPositionParams(
params.getTextDocument(),
params.getPosition(),
)
implementationProvider().implementations(textParams).flatMap {
implLocs =>
val implRefs = implLocs.map { implLoc =>
val implParams =
implLoc.toReferenceParams(includeDeclaration)
references(
implParams,
findRealRange,
includeSynthetics,
includeImplementations = false,
)
}
Future.sequence(implRefs).map(_.flatten)
}
} else {
Future.successful(Nil)
}
} else {
Future.successful(Nil)
}

Future
.sequence(List(semanticdbResult, pcResult))
.sequence(List(semanticdbResult, pcResult, implementationRefs))
.map(
_.flatten
.groupBy(_.symbol)
.collect { case (symbol, refs) =>
ReferencesResult(symbol, refs.flatMap(_.locations))
ReferencesResult(symbol, refs.flatMap(_.locations).distinct)
}
.toList
)
Expand Down Expand Up @@ -379,6 +416,12 @@ final class ReferenceProvider(
alternatives.isContructorParam(info)
}.toSet

val overriddenSymbols = definitionDoc.symbols
.find(_.symbol == symbol)
.map(SuperMethodProvider.getSuperMethodHierarchy)
.getOrElse(Nil)
.toSet

val nonSyntheticSymbols = for {
occ <- definitionDoc.occurrences
if isCandidate(occ.symbol) || occ.symbol == symbol
Expand All @@ -397,11 +440,11 @@ final class ReferenceProvider(
} yield info.symbol

if (defPath.isJava)
isCandidate
isCandidate ++ overriddenSymbols
else if (isSyntheticSymbol)
isCandidate ++ additionalAlternativesForSynthetic
isCandidate ++ additionalAlternativesForSynthetic ++ overriddenSymbols
else
isCandidate
isCandidate ++ overriddenSymbols
case None => Set.empty
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ import org.eclipse.lsp4j.Location
import org.eclipse.lsp4j.MessageParams
import org.eclipse.lsp4j.MessageType
import org.eclipse.lsp4j.Position
import org.eclipse.lsp4j.ReferenceContext
import org.eclipse.lsp4j.ReferenceParams
import org.eclipse.lsp4j.RenameFile
import org.eclipse.lsp4j.RenameParams
import org.eclipse.lsp4j.ResourceOperation
Expand Down Expand Up @@ -224,10 +222,8 @@ final class RenameProvider(
* isJava - in Java we can include declarations safely and we
* also need to include contructors.
*/
toReferenceParams(
txtParams,
includeDeclaration = isJava,
),
txtParams
.toReferenceParams(includeDeclaration = isJava),
findRealRange = AdjustRange(findRealRange(newName)),
includeSynthetic,
)
Expand Down Expand Up @@ -456,7 +452,7 @@ final class RenameProvider(
} yield {
referenceProvider
.references(
toReferenceParams(loc, includeDeclaration = false),
loc.toReferenceParams(includeDeclaration = false),
findRealRange = AdjustRange(findRealRange(newName)),
)
.map(_.flatMap(_.locations :+ loc))
Expand Down Expand Up @@ -507,7 +503,7 @@ final class RenameProvider(
result <- {
val result = for {
implLoc <- implLocs
locParams = toReferenceParams(implLoc, includeDeclaration = true)
locParams = implLoc.toReferenceParams(includeDeclaration = true)
} yield {
referenceProvider
.references(
Expand Down Expand Up @@ -706,44 +702,6 @@ final class RenameProvider(
}
}

private def toReferenceParams(
textDoc: TextDocumentIdentifier,
pos: Position,
includeDeclaration: Boolean,
): ReferenceParams = {
val referenceParams = new ReferenceParams()
referenceParams.setPosition(pos)
referenceParams.setTextDocument(textDoc)
val context = new ReferenceContext()
context.setIncludeDeclaration(includeDeclaration)
referenceParams.setContext(context)
referenceParams
}

private def toReferenceParams(
location: Location,
includeDeclaration: Boolean,
): ReferenceParams = {
val textDoc = new TextDocumentIdentifier()
textDoc.setUri(location.getUri())
toReferenceParams(
textDoc,
location.getRange().getStart(),
includeDeclaration,
)
}

private def toReferenceParams(
params: TextDocumentPositionParams,
includeDeclaration: Boolean,
): ReferenceParams = {
toReferenceParams(
params.getTextDocument(),
params.getPosition(),
includeDeclaration,
)
}

private def toTextParams(location: Location): TextDocumentPositionParams = {
new TextDocumentPositionParams(
new TextDocumentIdentifier(location.getUri()),
Expand Down
77 changes: 77 additions & 0 deletions tests/unit/src/test/scala/tests/ReferenceLspSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,83 @@ class ReferenceLspSuite extends BaseRangesSuite("reference") {
} yield ()
}

check(
"override-references",
"""|/a/src/main/scala/a/Main.scala
|package a
|
|trait Animal {
| def <<speak>>(): Unit
|}
|
|class Dog extends Animal {
| override def <<sp@@eak>>(): Unit = println("Bark")
|}
|
|object Main {
| def test(a: Animal, d: Dog): Unit = {
| a.<<speak>>()
| d.<<speak>>()
| }
|}
|""".stripMargin,
)

check(
"find-references-includes-implementations",
"""|/a/src/main/scala/a/Main.scala
|package a
|
|trait Processor {
| def <<pr@@ocess>>(input: String): String
|}
|
|class StringProcessor extends Processor {
| override def <<process>>(input: String): String = input.toUpperCase
|}
|
|class NumberProcessor extends Processor {
| override def <<process>>(input: String): String = input.filter(_.isDigit)
|}
|
|object Main {
| def use(p: Processor): String = p.<<process>>("test")
|}
|""".stripMargin,
)

check(
"diamond-inheritance-references",
"""|/a/src/main/scala/a/Main.scala
|package a
|
|trait A {
| def <<fo@@o>>(): Unit
|}
|
|trait B extends A {
| override def <<foo>>(): Unit = println("B")
|}
|
|trait C extends A {
| override def <<foo>>(): Unit = println("C")
|}
|
|class D extends B with C {
| override def <<foo>>(): Unit = println("D")
|}
|
|object Main {
| def test(a: A, b: B, c: C, d: D): Unit = {
| a.<<foo>>()
| b.<<foo>>()
| c.<<foo>>()
| d.<<foo>>()
| }
|}
|""".stripMargin,
)

override def assertCheck(
filename: String,
edit: String,
Expand Down
Loading