Skip to content

Commit 595660d

Browse files
zielinskytgodzik
authored andcommitted
Add InferredMethodProvider for automatic method signature generation (scala#23563)
Porting scalameta/metals#6877 Related discussion and feature request: scalameta/metals-feature-requests#298 [Cherry-picked 405e2bd]
1 parent 2c71cb8 commit 595660d

File tree

3 files changed

+897
-0
lines changed

3 files changed

+897
-0
lines changed
Lines changed: 362 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,362 @@
1+
package dotty.tools.pc
2+
3+
import java.nio.file.Paths
4+
5+
import scala.annotation.tailrec
6+
7+
import scala.meta.pc.OffsetParams
8+
import scala.meta.pc.PresentationCompilerConfig
9+
import scala.meta.pc.SymbolSearch
10+
import scala.meta.pc.reports.ReportContext
11+
12+
import dotty.tools.dotc.ast.tpd.*
13+
import dotty.tools.dotc.core.Contexts.*
14+
import dotty.tools.dotc.core.Names.Name
15+
import dotty.tools.dotc.core.Symbols.*
16+
import dotty.tools.dotc.core.Symbols.defn
17+
import dotty.tools.dotc.core.Types.*
18+
import dotty.tools.dotc.interactive.Interactive
19+
import dotty.tools.dotc.interactive.InteractiveDriver
20+
import dotty.tools.dotc.util.SourceFile
21+
import dotty.tools.dotc.util.SourcePosition
22+
import dotty.tools.pc.printer.ShortenedTypePrinter
23+
import dotty.tools.pc.printer.ShortenedTypePrinter.IncludeDefaultParam
24+
import dotty.tools.pc.utils.InteractiveEnrichments.*
25+
26+
import org.eclipse.lsp4j.TextEdit
27+
import org.eclipse.lsp4j as l
28+
29+
/**
30+
* Tries to calculate edits needed to create a method that will fix missing symbol
31+
* in all the places that it is possible such as:
32+
* - apply inside method invocation `method(.., nonExistent(param), ...)` and `method(.., nonExistent, ...)`
33+
* - method in val definition `val value: DefinedType = nonExistent(param)` and `val value: DefinedType = nonExistent`
34+
* - simple method call `nonExistent(param)` and `nonExistent`
35+
* - method call inside a container `container.nonExistent(param)` and `container.nonExistent`
36+
*
37+
* @param params position and actual source
38+
* @param driver Scala 3 interactive compiler driver
39+
* @param config presentation compiler configuration
40+
* @param symbolSearch symbol search
41+
*/
42+
final class InferredMethodProvider(
43+
params: OffsetParams,
44+
driver: InteractiveDriver,
45+
config: PresentationCompilerConfig,
46+
symbolSearch: SymbolSearch
47+
)(using ReportContext):
48+
49+
case class AdjustTypeOpts(
50+
text: String,
51+
adjustedEndPos: l.Position
52+
)
53+
54+
def inferredMethodEdits(
55+
adjustOpt: Option[AdjustTypeOpts] = None
56+
): List[TextEdit] =
57+
val uri = params.uri().nn
58+
val filePath = Paths.get(uri).nn
59+
60+
val sourceText = adjustOpt.map(_.text).getOrElse(params.text().nn)
61+
val source =
62+
SourceFile.virtual(filePath.toString(), sourceText)
63+
driver.run(uri, source)
64+
val unit = driver.currentCtx.run.nn.units.head
65+
val pos = driver.sourcePosition(params)
66+
val path =
67+
Interactive.pathTo(driver.openedTrees(uri), pos)(using driver.currentCtx)
68+
69+
given locatedCtx: Context = driver.localContext(params)
70+
val indexedCtx = IndexedContext(pos)(using locatedCtx)
71+
72+
val autoImportsGen = AutoImports.generator(
73+
pos,
74+
sourceText,
75+
unit.tpdTree,
76+
unit.comments,
77+
indexedCtx,
78+
config
79+
)
80+
81+
val printer = ShortenedTypePrinter(
82+
symbolSearch,
83+
includeDefaultParam = IncludeDefaultParam.ResolveLater,
84+
isTextEdit = true
85+
)(using indexedCtx)
86+
87+
def imports: List[TextEdit] =
88+
printer.imports(autoImportsGen)
89+
90+
def printType(tpe: Type): String =
91+
printer.tpe(tpe)
92+
93+
def printName(name: Name): String =
94+
printer.nameString(name)
95+
96+
def printParams(params: List[Type], startIndex: Int = 0): String =
97+
params.zipWithIndex
98+
.map { case (p, index) =>
99+
s"arg${index + startIndex}: ${printType(p)}"
100+
}
101+
.mkString(", ")
102+
103+
def printSignature(
104+
methodName: Name,
105+
params: List[List[Type]],
106+
retTypeOpt: Option[Type]
107+
): String =
108+
val retTypeString = retTypeOpt match
109+
case Some(retType) =>
110+
val printRetType = printType(retType)
111+
if retType.isAny then ""
112+
else s": $printRetType"
113+
case _ => ""
114+
115+
val (paramsString, _) = params.foldLeft(("", 0)){
116+
case ((acc, startIdx), paramList) =>
117+
val printed = s"(${printParams(paramList, startIdx)})"
118+
(acc + printed, startIdx + paramList.size)
119+
}
120+
121+
s"def ${printName(methodName)}$paramsString$retTypeString = ???"
122+
123+
@tailrec
124+
def countIndent(text: String, index: Int, acc: Int): Int =
125+
if index > 0 && text(index) != '\n' then countIndent(text, index - 1, acc + 1)
126+
else acc
127+
128+
def indentation(text: String, pos: Int): String =
129+
if pos > 0 then
130+
val isSpace = text(pos) == ' '
131+
val isTab = text(pos) == '\t'
132+
val indent = countIndent(params.text(), pos, 0)
133+
134+
if isSpace then " " * indent else if isTab then "\t" * indent else ""
135+
else ""
136+
137+
def insertPosition() =
138+
val blockOrTemplateIndex =
139+
path.tail.indexWhere {
140+
case _: Block | _: Template => true
141+
case _ => false
142+
}
143+
path(blockOrTemplateIndex).sourcePos
144+
145+
/**
146+
* Returns the position to insert the method signature for a container.
147+
* If the container has an empty body, the position is the end of the container.
148+
* If the container has a non-empty body, the position is the end of the last element in the body.
149+
*
150+
* @param container the container to insert the method signature for
151+
* @return the position to insert the method signature for the container and a boolean indicating if the container has an empty body
152+
*/
153+
def insertPositionFor(container: Tree): Option[(SourcePosition, Boolean)] =
154+
val typeSymbol = container.tpe.widenDealias.typeSymbol
155+
if typeSymbol.exists then
156+
val trees = driver.openedTrees(params.uri().nn)
157+
val include = Interactive.Include.definitions | Interactive.Include.local
158+
Interactive.findTreesMatching(trees, include, typeSymbol).headOption match
159+
case Some(srcTree) =>
160+
srcTree.tree match
161+
case classDef: TypeDef if classDef.rhs.isInstanceOf[Template] =>
162+
val template = classDef.rhs.asInstanceOf[Template]
163+
val (pos, hasEmptyBody) = template.body.lastOption match
164+
case Some(last) => (last.sourcePos, false)
165+
case None => (classDef.sourcePos, true)
166+
Some((pos, hasEmptyBody))
167+
case _ => None
168+
case None => None
169+
else None
170+
171+
/**
172+
* Extracts type information for a specific parameter in a method signature.
173+
* If the parameter is a function type, extracts both the function's argument types
174+
* and return type. Otherwise, extracts just the parameter type.
175+
*
176+
* @param methodType the method type to analyze
177+
* @param argIndex the index of the parameter to extract information for
178+
* @return a tuple of (argument types, return type) where:
179+
* - argument types: Some(List[Type]) if parameter is a function, None otherwise
180+
* - return type: Some(Type) representing either the function's return type or the parameter type itself
181+
*/
182+
def extractParameterTypeInfo(methodType: Type, argIndex: Int): (Option[List[Type]], Option[Type]) =
183+
methodType match
184+
case m @ MethodType(param) =>
185+
val expectedFunctionType = m.paramInfos(argIndex)
186+
if defn.isFunctionType(expectedFunctionType) then
187+
expectedFunctionType match
188+
case defn.FunctionOf(argTypes, retType, _) =>
189+
(Some(argTypes), Some(retType))
190+
case _ =>
191+
(None, Some(expectedFunctionType))
192+
else
193+
(None, Some(m.paramInfos(argIndex)))
194+
case _ => (None, None)
195+
196+
def signatureEdits(signature: String): List[TextEdit] =
197+
val pos = insertPosition()
198+
val indent = indentation(params.text(), pos.start - 1)
199+
val lspPos = pos.toLsp
200+
lspPos.setEnd(lspPos.getStart())
201+
202+
List(
203+
TextEdit(
204+
lspPos,
205+
s"$signature\n$indent",
206+
)
207+
) ::: imports
208+
209+
def signatureEditsForContainer(signature: String, container: Tree): List[TextEdit] =
210+
insertPositionFor(container) match
211+
case Some((pos, hasEmptyBody)) =>
212+
val lspPos = pos.toLsp
213+
lspPos.setStart(lspPos.getEnd())
214+
val indent = indentation(params.text(), pos.start - 1)
215+
216+
if hasEmptyBody then
217+
List(
218+
TextEdit(
219+
lspPos,
220+
s":\n $indent$signature",
221+
)
222+
) ::: imports
223+
else
224+
List(
225+
TextEdit(
226+
lspPos,
227+
s"\n$indent$signature",
228+
)
229+
) ::: imports
230+
case None => Nil
231+
232+
path match
233+
/**
234+
* outerArgs
235+
* ---------------------------
236+
* method(..., errorMethod(args), ...)
237+
*
238+
*/
239+
case (id @ Ident(errorMethod)) ::
240+
(apply @ Apply(func, args)) ::
241+
Apply(method, outerArgs) ::
242+
_ if id.symbol == NoSymbol && func == id && method != apply =>
243+
244+
val argTypes = args.map(_.typeOpt.widenDealias)
245+
246+
val argIndex = outerArgs.indexOf(apply)
247+
val (allArgTypes, retTypeOpt) =
248+
extractParameterTypeInfo(method.tpe.widenDealias, argIndex) match
249+
case (Some(argTypes2), retTypeOpt) => (List(argTypes, argTypes2), retTypeOpt)
250+
case (None, retTypeOpt) => (List(argTypes), retTypeOpt)
251+
252+
val signature = printSignature(errorMethod, allArgTypes, retTypeOpt)
253+
254+
signatureEdits(signature)
255+
256+
/**
257+
* outerArgs
258+
* ---------------------
259+
* method(..., errorMethod, ...)
260+
*
261+
*/
262+
case (id @ Ident(errorMethod)) ::
263+
Apply(method, outerArgs) ::
264+
_ if id.symbol == NoSymbol && method != id =>
265+
266+
val argIndex = outerArgs.indexOf(id)
267+
268+
val (argTypes, retTypeOpt) = extractParameterTypeInfo(method.tpe.widenDealias, argIndex)
269+
270+
val allArgTypes = argTypes match
271+
case Some(argTypes) => List(argTypes)
272+
case None => Nil
273+
274+
val signature = printSignature(errorMethod, allArgTypes, retTypeOpt)
275+
276+
signatureEdits(signature)
277+
278+
/**
279+
* tpt body
280+
* ----------- ----------------
281+
* val value: DefinedType = errorMethod(args)
282+
*
283+
*/
284+
case (id @ Ident(errorMethod)) ::
285+
(apply @ Apply(func, args)) ::
286+
ValDef(_, tpt, body) ::
287+
_ if id.symbol == NoSymbol && func == id && apply == body =>
288+
289+
val retType = tpt.tpe.widenDealias
290+
val argTypes = args.map(_.typeOpt.widenDealias)
291+
292+
val signature = printSignature(errorMethod, List(argTypes), Some(retType))
293+
signatureEdits(signature)
294+
295+
/**
296+
* tpt body
297+
* ----------- -----------
298+
* val value: DefinedType = errorMethod
299+
*
300+
*/
301+
case (id @ Ident(errorMethod)) ::
302+
ValDef(_, tpt, body) ::
303+
_ if id.symbol == NoSymbol && id == body =>
304+
305+
val retType = tpt.tpe.widenDealias
306+
307+
val signature = printSignature(errorMethod, Nil, Some(retType))
308+
signatureEdits(signature)
309+
310+
/**
311+
*
312+
* errorMethod(args)
313+
*
314+
*/
315+
case (id @ Ident(errorMethod)) ::
316+
(apply @ Apply(func, args)) ::
317+
_ if id.symbol == NoSymbol && func == id =>
318+
319+
val argTypes = args.map(_.typeOpt.widenDealias)
320+
321+
val signature = printSignature(errorMethod, List(argTypes), None)
322+
signatureEdits(signature)
323+
324+
/**
325+
*
326+
* errorMethod
327+
*
328+
*/
329+
case (id @ Ident(errorMethod)) ::
330+
_ if id.symbol == NoSymbol =>
331+
332+
val signature = printSignature(errorMethod, Nil, None)
333+
signatureEdits(signature)
334+
335+
/**
336+
*
337+
* container.errorMethod(args)
338+
*
339+
*/
340+
case (select @ Select(container, errorMethod)) ::
341+
(apply @ Apply(func, args)) ::
342+
_ if select.symbol == NoSymbol && func == select =>
343+
344+
val argTypes = args.map(_.typeOpt.widenDealias)
345+
val signature = printSignature(errorMethod, List(argTypes), None)
346+
signatureEditsForContainer(signature, container)
347+
348+
/**
349+
*
350+
* container.errorMethod
351+
*
352+
*/
353+
case (select @ Select(container, errorMethod)) ::
354+
_ if select.symbol == NoSymbol =>
355+
356+
val signature = printSignature(errorMethod, Nil, None)
357+
signatureEditsForContainer(signature, container)
358+
359+
case _ => Nil
360+
361+
end inferredMethodEdits
362+
end InferredMethodProvider

presentation-compiler/src/main/dotty/tools/pc/ScalaPresentationCompiler.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ case class ScalaPresentationCompiler(
6464
CodeActionId.ExtractMethod,
6565
CodeActionId.InlineValue,
6666
CodeActionId.InsertInferredType,
67+
CodeActionId.InsertInferredMethod,
6768
PcConvertToNamedLambdaParameters.codeActionId
6869
).asJava
6970

@@ -92,6 +93,8 @@ case class ScalaPresentationCompiler(
9293
implementAbstractMembers(params)
9394
case (CodeActionId.InsertInferredType, _) =>
9495
insertInferredType(params)
96+
case (CodeActionId.InsertInferredMethod, _) =>
97+
insertInferredMethod(params)
9598
case (CodeActionId.InlineValue, _) =>
9699
inlineValue(params)
97100
case (CodeActionId.ExtractMethod, Some(extractionPos: OffsetParams)) =>
@@ -352,6 +355,19 @@ case class ScalaPresentationCompiler(
352355
.asJava
353356
}(params.toQueryContext)
354357

358+
def insertInferredMethod(
359+
params: OffsetParams
360+
): CompletableFuture[ju.List[l.TextEdit]] =
361+
val empty: ju.List[l.TextEdit] = new ju.ArrayList[l.TextEdit]()
362+
compilerAccess.withNonInterruptableCompiler(
363+
empty,
364+
params.token()
365+
) { pc =>
366+
new InferredMethodProvider(params, pc.compiler(), config, search)
367+
.inferredMethodEdits()
368+
.asJava
369+
}(params.toQueryContext)
370+
355371
override def inlineValue(
356372
params: OffsetParams
357373
): CompletableFuture[ju.List[l.TextEdit]] =

0 commit comments

Comments
 (0)