Skip to content

Commit 87fe644

Browse files
committed
feat: implement Convert to collect code action
1 parent d42536c commit 87fe644

File tree

5 files changed

+439
-0
lines changed

5 files changed

+439
-0
lines changed

docs/features/code-actions.md

+6
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,12 @@ It converts a chain of `map`, `flatMap`, `filter` and `filterNot` methods into a
5555

5656
![To For Comprehension](./gifs/FlatMapToForComprehension.gif)
5757

58+
## filter then map to collect
59+
60+
It converts a chain of `filter` and `map` methods into a `collect` method.
61+
62+
![To Collect](./gifs/FilterMapToCollect.gif)
63+
5864
## Implement Abstract Members of the Parent Type
5965

6066
Upon inheriting from a type, you also have to implement its abstract members. But manually looking them all up and copying their signature is time consuming, isn't it? You can just use this code action instead.
69.5 KB
Loading

metals/src/main/scala/scala/meta/internal/metals/codeactions/CodeActionProvider.scala

+1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ final class CodeActionProvider(
4444
new InlineValueCodeAction(trees, compilers, languageClient),
4545
new ConvertToNamedArguments(trees, compilers, languageClient),
4646
new FlatMapToForComprehensionCodeAction(trees, buffers),
47+
new FilterMapToCollectCodeAction(trees),
4748
new MillifyDependencyCodeAction(buffers),
4849
new MillifyScalaCliDependencyCodeAction(buffers),
4950
new ConvertCommentCodeAction(buffers),
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
package scala.meta.internal.metals.codeactions
2+
3+
import scala.concurrent.ExecutionContext
4+
import scala.concurrent.Future
5+
6+
import scala.meta._
7+
import scala.meta.internal.metals.MetalsEnrichments._
8+
import scala.meta.internal.metals.codeactions.CodeAction
9+
import scala.meta.internal.metals.codeactions.CodeActionBuilder
10+
import scala.meta.internal.parsing.Trees
11+
import scala.meta.pc.CancelToken
12+
13+
import org.eclipse.lsp4j.CodeActionParams
14+
import org.eclipse.{lsp4j => l}
15+
16+
class FilterMapToCollectCodeAction(trees: Trees) extends CodeAction {
17+
override def kind: String = l.CodeActionKind.RefactorRewrite
18+
19+
override def contribute(params: CodeActionParams, token: CancelToken)(implicit
20+
ec: ExecutionContext
21+
): Future[Seq[l.CodeAction]] = Future {
22+
val uri = params.getTextDocument().getUri()
23+
24+
val path = uri.toAbsolutePath
25+
val range = params.getRange()
26+
27+
trees
28+
.findLastEnclosingAt[Term.Apply](path, range.getStart())
29+
.flatMap(findFilterMapChain)
30+
.map(toTextEdit(_))
31+
.map(toCodeAction(uri, _))
32+
.toSeq
33+
}
34+
35+
private def toTextEdit(chain: FilterMapChain) = {
36+
val param = chain.filterFn.params.head
37+
val paramName = Term.Name(param.name.value)
38+
val paramPatWithType = param.decltpe match {
39+
case Some(tpe) => Pat.Typed(Pat.Var(paramName), tpe)
40+
case None => Pat.Var(paramName)
41+
}
42+
43+
val collectCall = Term.Apply(
44+
fun = Term.Select(chain.qual, Term.Name("collect")),
45+
argClause = Term.ArgClause(
46+
values = List(
47+
Term.PartialFunction(
48+
cases = List(
49+
Case(
50+
pat = paramPatWithType,
51+
cond = Some(chain.filterFn.renameParam(paramName)),
52+
body = chain.mapFn.renameParam(paramName),
53+
)
54+
)
55+
)
56+
)
57+
),
58+
)
59+
val indented = collectCall.syntax.linesIterator.zipWithIndex
60+
.map {
61+
case (line, 0) => line
62+
case (line, _) => " " + line
63+
}
64+
.mkString("\n")
65+
66+
new l.TextEdit(chain.tree.pos.toLsp, indented)
67+
}
68+
69+
private def toCodeAction(uri: String, textEdit: l.TextEdit): l.CodeAction =
70+
CodeActionBuilder.build(
71+
title = FilterMapToCollectCodeAction.title,
72+
kind = this.kind,
73+
changes = List(uri.toAbsolutePath -> List(textEdit)),
74+
)
75+
76+
private implicit class FunctionOps(fn: Term.Function) {
77+
def renameParam(to: Term.Name): Term = {
78+
val fnParamName = fn.params.head.name.value
79+
fn.body
80+
.transform { case Term.Name(name) if name == fnParamName => to }
81+
.asInstanceOf[Term]
82+
}
83+
}
84+
85+
private def findFilterMapChain(tree: Term.Apply): Option[FilterMapChain] = {
86+
val x = Term.Name("x")
87+
def extractFunction(arg: Tree): Option[Term.Function] = arg match {
88+
case fn: Term.Function => Some(fn)
89+
case Term.Block(List(fn: Term.Function)) => extractFunction(fn)
90+
case ref: Term.Name => {
91+
Some(
92+
Term.Function(
93+
UnaryParameterList(x),
94+
Term.Apply(ref, Term.ArgClause(List(x))),
95+
)
96+
)
97+
}
98+
case _ => None
99+
}
100+
101+
def findChain(tree: Term.Apply): Option[FilterMapChain] =
102+
tree match {
103+
case MapFunctionApply(FilterFunctionApply(base, filterArg), mapArg) =>
104+
for {
105+
filterFn <- extractFunction(filterArg)
106+
mapFn <- extractFunction(mapArg)
107+
} yield FilterMapChain(tree, base, filterFn, mapFn)
108+
case _ => None
109+
}
110+
111+
findChain(tree).orElse {
112+
// If we're inside the chain, look at our parent
113+
tree.parent.flatMap {
114+
// We're in a method call or function, look at parent apply
115+
case Term.Select(_, Term.Name("map" | "filter")) | Term.Function(_) =>
116+
tree.parent
117+
.flatMap(_.parent)
118+
.collectFirst { case parent: Term.Apply => parent }
119+
.flatMap(findChain)
120+
case _ => None
121+
}
122+
}
123+
}
124+
125+
private object UnaryParameterList {
126+
def unapply(tree: Tree): Option[Name] = tree match {
127+
case Term.Param(_, name, _, _) => Some(name)
128+
case _ => None
129+
}
130+
def apply(name: Name): List[Term.Param] = List(
131+
Term.Param(Nil, name, None, None)
132+
)
133+
}
134+
135+
private case class FunctionApply(val name: String) {
136+
def unapply(tree: Tree): Option[(Term, Term)] = tree match {
137+
case Term.Apply(Term.Select(base, Term.Name(`name`)), List(args)) =>
138+
Some((base, args))
139+
case _ => None
140+
}
141+
}
142+
private val FilterFunctionApply = new FunctionApply("filter")
143+
private val MapFunctionApply = new FunctionApply("map")
144+
145+
private case class FilterMapChain(
146+
tree: Term.Apply,
147+
qual: Term,
148+
filterFn: Term.Function,
149+
mapFn: Term.Function,
150+
)
151+
}
152+
153+
object FilterMapToCollectCodeAction {
154+
val title = "Convert to collect"
155+
}

0 commit comments

Comments
 (0)