diff --git a/presentation-compiler/src/main/dotty/tools/pc/InferredMethodProvider.scala b/presentation-compiler/src/main/dotty/tools/pc/InferredMethodProvider.scala new file mode 100644 index 000000000000..e6f27781bc64 --- /dev/null +++ b/presentation-compiler/src/main/dotty/tools/pc/InferredMethodProvider.scala @@ -0,0 +1,362 @@ +package dotty.tools.pc + +import java.nio.file.Paths + +import scala.annotation.tailrec + +import scala.meta.pc.OffsetParams +import scala.meta.pc.PresentationCompilerConfig +import scala.meta.pc.SymbolSearch +import scala.meta.pc.reports.ReportContext + +import dotty.tools.dotc.ast.tpd.* +import dotty.tools.dotc.core.Contexts.* +import dotty.tools.dotc.core.Names.Name +import dotty.tools.dotc.core.Symbols.* +import dotty.tools.dotc.core.Symbols.defn +import dotty.tools.dotc.core.Types.* +import dotty.tools.dotc.interactive.Interactive +import dotty.tools.dotc.interactive.InteractiveDriver +import dotty.tools.dotc.util.SourceFile +import dotty.tools.dotc.util.SourcePosition +import dotty.tools.pc.printer.ShortenedTypePrinter +import dotty.tools.pc.printer.ShortenedTypePrinter.IncludeDefaultParam +import dotty.tools.pc.utils.InteractiveEnrichments.* + +import org.eclipse.lsp4j.TextEdit +import org.eclipse.lsp4j as l + +/** + * Tries to calculate edits needed to create a method that will fix missing symbol + * in all the places that it is possible such as: + * - apply inside method invocation `method(.., nonExistent(param), ...)` and `method(.., nonExistent, ...)` + * - method in val definition `val value: DefinedType = nonExistent(param)` and `val value: DefinedType = nonExistent` + * - simple method call `nonExistent(param)` and `nonExistent` + * - method call inside a container `container.nonExistent(param)` and `container.nonExistent` + * + * @param params position and actual source + * @param driver Scala 3 interactive compiler driver + * @param config presentation compiler configuration + * @param symbolSearch symbol search + */ +final class InferredMethodProvider( + params: OffsetParams, + driver: InteractiveDriver, + config: PresentationCompilerConfig, + symbolSearch: SymbolSearch +)(using ReportContext): + + case class AdjustTypeOpts( + text: String, + adjustedEndPos: l.Position + ) + + def inferredMethodEdits( + adjustOpt: Option[AdjustTypeOpts] = None + ): List[TextEdit] = + val uri = params.uri().nn + val filePath = Paths.get(uri).nn + + val sourceText = adjustOpt.map(_.text).getOrElse(params.text().nn) + val source = + SourceFile.virtual(filePath.toString(), sourceText) + driver.run(uri, source) + val unit = driver.currentCtx.run.nn.units.head + val pos = driver.sourcePosition(params) + val path = + Interactive.pathTo(driver.openedTrees(uri), pos)(using driver.currentCtx) + + given locatedCtx: Context = driver.localContext(params) + val indexedCtx = IndexedContext(pos)(using locatedCtx) + + val autoImportsGen = AutoImports.generator( + pos, + sourceText, + unit.tpdTree, + unit.comments, + indexedCtx, + config + ) + + val printer = ShortenedTypePrinter( + symbolSearch, + includeDefaultParam = IncludeDefaultParam.ResolveLater, + isTextEdit = true + )(using indexedCtx) + + def imports: List[TextEdit] = + printer.imports(autoImportsGen) + + def printType(tpe: Type): String = + printer.tpe(tpe) + + def printName(name: Name): String = + printer.nameString(name) + + def printParams(params: List[Type], startIndex: Int = 0): String = + params.zipWithIndex + .map { case (p, index) => + s"arg${index + startIndex}: ${printType(p)}" + } + .mkString(", ") + + def printSignature( + methodName: Name, + params: List[List[Type]], + retTypeOpt: Option[Type] + ): String = + val retTypeString = retTypeOpt match + case Some(retType) => + val printRetType = printType(retType) + if retType.isAny then "" + else s": $printRetType" + case _ => "" + + val (paramsString, _) = params.foldLeft(("", 0)){ + case ((acc, startIdx), paramList) => + val printed = s"(${printParams(paramList, startIdx)})" + (acc + printed, startIdx + paramList.size) + } + + s"def ${printName(methodName)}$paramsString$retTypeString = ???" + + @tailrec + def countIndent(text: String, index: Int, acc: Int): Int = + if index > 0 && text(index) != '\n' then countIndent(text, index - 1, acc + 1) + else acc + + def indentation(text: String, pos: Int): String = + if pos > 0 then + val isSpace = text(pos) == ' ' + val isTab = text(pos) == '\t' + val indent = countIndent(params.text(), pos, 0) + + if isSpace then " " * indent else if isTab then "\t" * indent else "" + else "" + + def insertPosition() = + val blockOrTemplateIndex = + path.tail.indexWhere { + case _: Block | _: Template => true + case _ => false + } + path(blockOrTemplateIndex).sourcePos + + /** + * Returns the position to insert the method signature for a container. + * If the container has an empty body, the position is the end of the container. + * If the container has a non-empty body, the position is the end of the last element in the body. + * + * @param container the container to insert the method signature for + * @return the position to insert the method signature for the container and a boolean indicating if the container has an empty body + */ + def insertPositionFor(container: Tree): Option[(SourcePosition, Boolean)] = + val typeSymbol = container.tpe.widenDealias.typeSymbol + if typeSymbol.exists then + val trees = driver.openedTrees(params.uri().nn) + val include = Interactive.Include.definitions | Interactive.Include.local + Interactive.findTreesMatching(trees, include, typeSymbol).headOption match + case Some(srcTree) => + srcTree.tree match + case classDef: TypeDef if classDef.rhs.isInstanceOf[Template] => + val template = classDef.rhs.asInstanceOf[Template] + val (pos, hasEmptyBody) = template.body.lastOption match + case Some(last) => (last.sourcePos, false) + case None => (classDef.sourcePos, true) + Some((pos, hasEmptyBody)) + case _ => None + case None => None + else None + + /** + * Extracts type information for a specific parameter in a method signature. + * If the parameter is a function type, extracts both the function's argument types + * and return type. Otherwise, extracts just the parameter type. + * + * @param methodType the method type to analyze + * @param argIndex the index of the parameter to extract information for + * @return a tuple of (argument types, return type) where: + * - argument types: Some(List[Type]) if parameter is a function, None otherwise + * - return type: Some(Type) representing either the function's return type or the parameter type itself + */ + def extractParameterTypeInfo(methodType: Type, argIndex: Int): (Option[List[Type]], Option[Type]) = + methodType match + case m @ MethodType(param) => + val expectedFunctionType = m.paramInfos(argIndex) + if defn.isFunctionType(expectedFunctionType) then + expectedFunctionType match + case defn.FunctionOf(argTypes, retType, _) => + (Some(argTypes), Some(retType)) + case _ => + (None, Some(expectedFunctionType)) + else + (None, Some(m.paramInfos(argIndex))) + case _ => (None, None) + + def signatureEdits(signature: String): List[TextEdit] = + val pos = insertPosition() + val indent = indentation(params.text(), pos.start - 1) + val lspPos = pos.toLsp + lspPos.setEnd(lspPos.getStart()) + + List( + TextEdit( + lspPos, + s"$signature\n$indent", + ) + ) ::: imports + + def signatureEditsForContainer(signature: String, container: Tree): List[TextEdit] = + insertPositionFor(container) match + case Some((pos, hasEmptyBody)) => + val lspPos = pos.toLsp + lspPos.setStart(lspPos.getEnd()) + val indent = indentation(params.text(), pos.start - 1) + + if hasEmptyBody then + List( + TextEdit( + lspPos, + s":\n $indent$signature", + ) + ) ::: imports + else + List( + TextEdit( + lspPos, + s"\n$indent$signature", + ) + ) ::: imports + case None => Nil + + path match + /** + * outerArgs + * --------------------------- + * method(..., errorMethod(args), ...) + * + */ + case (id @ Ident(errorMethod)) :: + (apply @ Apply(func, args)) :: + Apply(method, outerArgs) :: + _ if id.symbol == NoSymbol && func == id && method != apply => + + val argTypes = args.map(_.typeOpt.widenDealias) + + val argIndex = outerArgs.indexOf(apply) + val (allArgTypes, retTypeOpt) = + extractParameterTypeInfo(method.tpe.widenDealias, argIndex) match + case (Some(argTypes2), retTypeOpt) => (List(argTypes, argTypes2), retTypeOpt) + case (None, retTypeOpt) => (List(argTypes), retTypeOpt) + + val signature = printSignature(errorMethod, allArgTypes, retTypeOpt) + + signatureEdits(signature) + + /** + * outerArgs + * --------------------- + * method(..., errorMethod, ...) + * + */ + case (id @ Ident(errorMethod)) :: + Apply(method, outerArgs) :: + _ if id.symbol == NoSymbol && method != id => + + val argIndex = outerArgs.indexOf(id) + + val (argTypes, retTypeOpt) = extractParameterTypeInfo(method.tpe.widenDealias, argIndex) + + val allArgTypes = argTypes match + case Some(argTypes) => List(argTypes) + case None => Nil + + val signature = printSignature(errorMethod, allArgTypes, retTypeOpt) + + signatureEdits(signature) + + /** + * tpt body + * ----------- ---------------- + * val value: DefinedType = errorMethod(args) + * + */ + case (id @ Ident(errorMethod)) :: + (apply @ Apply(func, args)) :: + ValDef(_, tpt, body) :: + _ if id.symbol == NoSymbol && func == id && apply == body => + + val retType = tpt.tpe.widenDealias + val argTypes = args.map(_.typeOpt.widenDealias) + + val signature = printSignature(errorMethod, List(argTypes), Some(retType)) + signatureEdits(signature) + + /** + * tpt body + * ----------- ----------- + * val value: DefinedType = errorMethod + * + */ + case (id @ Ident(errorMethod)) :: + ValDef(_, tpt, body) :: + _ if id.symbol == NoSymbol && id == body => + + val retType = tpt.tpe.widenDealias + + val signature = printSignature(errorMethod, Nil, Some(retType)) + signatureEdits(signature) + + /** + * + * errorMethod(args) + * + */ + case (id @ Ident(errorMethod)) :: + (apply @ Apply(func, args)) :: + _ if id.symbol == NoSymbol && func == id => + + val argTypes = args.map(_.typeOpt.widenDealias) + + val signature = printSignature(errorMethod, List(argTypes), None) + signatureEdits(signature) + + /** + * + * errorMethod + * + */ + case (id @ Ident(errorMethod)) :: + _ if id.symbol == NoSymbol => + + val signature = printSignature(errorMethod, Nil, None) + signatureEdits(signature) + + /** + * + * container.errorMethod(args) + * + */ + case (select @ Select(container, errorMethod)) :: + (apply @ Apply(func, args)) :: + _ if select.symbol == NoSymbol && func == select => + + val argTypes = args.map(_.typeOpt.widenDealias) + val signature = printSignature(errorMethod, List(argTypes), None) + signatureEditsForContainer(signature, container) + + /** + * + * container.errorMethod + * + */ + case (select @ Select(container, errorMethod)) :: + _ if select.symbol == NoSymbol => + + val signature = printSignature(errorMethod, Nil, None) + signatureEditsForContainer(signature, container) + + case _ => Nil + + end inferredMethodEdits +end InferredMethodProvider diff --git a/presentation-compiler/src/main/dotty/tools/pc/ScalaPresentationCompiler.scala b/presentation-compiler/src/main/dotty/tools/pc/ScalaPresentationCompiler.scala index 2f218687296f..18311d1b7853 100644 --- a/presentation-compiler/src/main/dotty/tools/pc/ScalaPresentationCompiler.scala +++ b/presentation-compiler/src/main/dotty/tools/pc/ScalaPresentationCompiler.scala @@ -64,6 +64,7 @@ case class ScalaPresentationCompiler( CodeActionId.ExtractMethod, CodeActionId.InlineValue, CodeActionId.InsertInferredType, + CodeActionId.InsertInferredMethod, PcConvertToNamedLambdaParameters.codeActionId ).asJava @@ -92,6 +93,8 @@ case class ScalaPresentationCompiler( implementAbstractMembers(params) case (CodeActionId.InsertInferredType, _) => insertInferredType(params) + case (CodeActionId.InsertInferredMethod, _) => + insertInferredMethod(params) case (CodeActionId.InlineValue, _) => inlineValue(params) case (CodeActionId.ExtractMethod, Some(extractionPos: OffsetParams)) => @@ -352,6 +355,19 @@ case class ScalaPresentationCompiler( .asJava }(params.toQueryContext) + def insertInferredMethod( + params: OffsetParams + ): CompletableFuture[ju.List[l.TextEdit]] = + val empty: ju.List[l.TextEdit] = new ju.ArrayList[l.TextEdit]() + compilerAccess.withNonInterruptableCompiler( + empty, + params.token() + ) { pc => + new InferredMethodProvider(params, pc.compiler(), config, search) + .inferredMethodEdits() + .asJava + }(params.toQueryContext) + override def inlineValue( params: OffsetParams ): CompletableFuture[ju.List[l.TextEdit]] = diff --git a/presentation-compiler/test/dotty/tools/pc/tests/edit/InsertInferredMethodSuite.scala b/presentation-compiler/test/dotty/tools/pc/tests/edit/InsertInferredMethodSuite.scala new file mode 100644 index 000000000000..2b8e2ef32ef5 --- /dev/null +++ b/presentation-compiler/test/dotty/tools/pc/tests/edit/InsertInferredMethodSuite.scala @@ -0,0 +1,519 @@ +package dotty.tools.pc.tests.edit + +import java.net.URI +import java.util.Optional + +import scala.meta.internal.jdk.CollectionConverters.* +import scala.meta.internal.metals.CompilerOffsetParams +import scala.meta.pc.CodeActionId +import scala.language.unsafeNulls + +import dotty.tools.pc.base.BaseCodeActionSuite +import dotty.tools.pc.utils.TextEdits + +import org.eclipse.lsp4j as l +import org.junit.Test + +class InsertInferredMethodSuite extends BaseCodeActionSuite: + + @Test def `simple` = + checkEdit( + """| + |trait Main { + | def method1(s : String) = 123 + | + | method1(<>(1)) + |} + | + |""".stripMargin, + """|trait Main { + | def method1(s : String) = 123 + | + | def otherMethod(arg0: Int): String = ??? + | method1(otherMethod(1)) + |} + |""".stripMargin + ) + + @Test def `simple-2` = + checkEdit( + """| + |trait Main { + | def method1(s : String) = 123 + | + | <>(1) + |} + | + |""".stripMargin, + """|trait Main { + | def method1(s : String) = 123 + | + | def otherMethod(arg0: Int) = ??? + | otherMethod(1) + |} + |""".stripMargin + ) + + @Test def `simple-3` = + checkEdit( + """| + |trait Main { + | def method1(s : String) = 123 + | + | <>((1 + 123).toDouble) + |} + | + |""".stripMargin, + """|trait Main { + | def method1(s : String) = 123 + | + | def otherMethod(arg0: Double) = ??? + | otherMethod((1 + 123).toDouble) + |} + |""".stripMargin + ) + + @Test def `simple-4` = + checkEdit( + """| + |trait Main { + | def method1(s : String) = 123 + | + | method1(<>()) + |} + | + |""".stripMargin, + """|trait Main { + | def method1(s : String) = 123 + | + | def otherMethod(): String = ??? + | method1(otherMethod()) + |} + |""".stripMargin + ) + + @Test def `backtick-method-name` = + checkEdit( + """| + |trait Main { + | <<`met ? hod`>>(10) + |} + |""".stripMargin, + """|trait Main { + | def `met ? hod`(arg0: Int) = ??? + | `met ? hod`(10) + |} + |""".stripMargin + ) + + @Test def `custom-type` = + checkEdit( + """| + |trait Main { + | def method1(b: Double, s : String) = 123 + | + | case class User(i : Int) + | val user = User(1) + | + | method1(0.0, <>(user, 1)) + |} + |""".stripMargin, + """| + |trait Main { + | def method1(b: Double, s : String) = 123 + | + | case class User(i : Int) + | val user = User(1) + | + | def otherMethod(arg0: User, arg1: Int): String = ??? + | method1(0.0, otherMethod(user, 1)) + |} + |""".stripMargin + ) + + @Test def `custom-type-2` = + checkEdit( + """| + |trait Main { + | def method1(b: Double, s : String) = 123 + | + | case class User(i : Int) + | val user = User(1) + | <>(user, 1) + |} + |""".stripMargin, + """| + |trait Main { + | def method1(b: Double, s : String) = 123 + | + | case class User(i : Int) + | val user = User(1) + | def otherMethod(arg0: User, arg1: Int) = ??? + | otherMethod(user, 1) + |} + |""".stripMargin + ) + + @Test def `custom-type-advanced` = + checkEdit( + """| + |trait Main { + | def method1(b: Double, s : String) = 123 + | + | case class User(i : Int) + | + | <>(User(1), 1) + |} + | + |""".stripMargin, + """|trait Main { + | def method1(b: Double, s : String) = 123 + | + | case class User(i : Int) + | + | def otherMethod(arg0: User, arg1: Int) = ??? + | otherMethod(User(1), 1) + |} + |""".stripMargin + ) + + @Test def `custom-type-advanced-2` = + checkEdit( + """| + |trait Main { + | def method1(b: Double, s : String) = 123 + | + | case class User(i : Int) + | + | <>(List(Set(User(1))), Map("1" -> 1)) + |} + | + |""".stripMargin, + """|trait Main { + | def method1(b: Double, s : String) = 123 + | + | case class User(i : Int) + | + | def otherMethod(arg0: List[Set[User]], arg1: Map[String, Int]) = ??? + | otherMethod(List(Set(User(1))), Map("1" -> 1)) + |} + |""".stripMargin + ) + + @Test def `with-imports` = + checkEdit( + """|import java.nio.file.Files + | + |trait Main { + | def main() = { + | def method1(s : String) = 123 + | val path = Files.createTempDirectory("") + | method1(<>(path)) + | } + |} + | + |""".stripMargin, + """|import java.nio.file.Files + |import java.nio.file.Path + | + |trait Main { + | def main() = { + | def method1(s : String) = 123 + | val path = Files.createTempDirectory("") + | def otherMethod(arg0: Path): String = ??? + | method1(otherMethod(path)) + | } + |} + |""".stripMargin + ) + + @Test def `val-definition` = + checkEdit( + """| + |trait Main { + | val result: String = <>(42, "hello") + |} + | + |""".stripMargin, + """|trait Main { + | def nonExistent(arg0: Int, arg1: String): String = ??? + | val result: String = nonExistent(42, "hello") + |} + |""".stripMargin + ) + + @Test def `val-definition-no-args` = + checkEdit( + """| + |trait Main { + | val result: Int = <> + |} + | + |""".stripMargin, + """|trait Main { + | def getValue: Int = ??? + | val result: Int = getValue + |} + |""".stripMargin + ) + + @Test def `lambda-expression` = + checkEdit( + """| + |trait Main { + | val list = List(1, 2, 3) + | list.map(<>) + |} + | + |""".stripMargin, + """|trait Main { + | val list = List(1, 2, 3) + | def transform(arg0: Int) = ??? + | list.map(transform) + |} + |""".stripMargin + ) + + @Test def `lambda-expression-2` = + checkEdit( + """| + |trait Main { + | val list = List(1, 2, 3) + | list.map(<>(10, "test")) + |} + | + |""".stripMargin, + """|trait Main { + | val list = List(1, 2, 3) + | def transform(arg0: Int, arg1: String)(arg2: Int) = ??? + | list.map(transform(10, "test")) + |} + |""".stripMargin + ) + + @Test def `lambda-expression-3` = + checkEdit( + """| + |trait Main { + | val list = List("a", "b", "c") + | list.map(<>) + |} + | + |""".stripMargin, + """|trait Main { + | val list = List("a", "b", "c") + | def process(arg0: String) = ??? + | list.map(process) + |} + |""".stripMargin + ) + + @Test def `lambda-expression-4` = + checkEdit( + """| + |trait Main { + | List((1, 2, 3)).filter(_ => true).map(<>) + |} + | + |""".stripMargin, + """|trait Main { + | def otherMethod(arg0: (Int, Int, Int)) = ??? + | List((1, 2, 3)).filter(_ => true).map(otherMethod) + |} + |""".stripMargin + ) + + @Test def `lambda-expression-5` = + checkEdit( + """| + |trait Main { + | val list = List(1, 2, 3) + | list.filter(<>) + |} + | + |""".stripMargin, + """|trait Main { + | val list = List(1, 2, 3) + | def otherMethod(arg0: Int): Boolean = ??? + | list.filter(otherMethod) + |} + |""".stripMargin + ) + + @Test def `simple-method-no-args` = + checkEdit( + """| + |trait Main { + | <> + |} + | + |""".stripMargin, + """|trait Main { + | def missingMethod = ??? + | missingMethod + |} + |""".stripMargin + ) + + @Test def `simple-method-no-args-2` = + checkEdit( + """| + |trait Main { + | def method1(s : String) = 123 + | method1(<>) + |} + | + |""".stripMargin, + """|trait Main { + | def method1(s : String) = 123 + | def missingMethod: String = ??? + | method1(missingMethod) + |} + |""".stripMargin + ) + + @Test def `nested-val-definition` = + checkEdit( + """| + |trait Main { + | def someMethod(): Unit = { + | val data: List[String] = <>(10) + | } + |} + | + |""".stripMargin, + """|trait Main { + | def someMethod(): Unit = { + | def generateData(arg0: Int): List[String] = ??? + | val data: List[String] = generateData(10) + | } + |} + |""".stripMargin + ) + + @Test def `simple-class-definition` = + checkEdit( + """| + |class User: + | val name: String = "John" + | + |object Main: + | val user = User() + | user.<> + | + |""".stripMargin, + """| + |class User: + | val name: String = "John" + | def otherMethod = ??? + | + |object Main: + | val user = User() + | user.otherMethod + |""".stripMargin, + ) + + @Test def `simple-class-definition-2` = + checkEdit( + """| + |class User: + | val name: String = "John" + | + |object Main: + | val user = User() + | user.<>(10) + | + |""".stripMargin, + """| + |class User: + | val name: String = "John" + | def otherMethod(arg0: Int) = ??? + | + |object Main: + | val user = User() + | user.otherMethod(10) + |""".stripMargin, + ) + + @Test def `simple-object-definition` = + checkEdit( + """| + |object User: + | val name: String = "John" + | + |object Main: + | User.<> + | + |""".stripMargin, + """| + |object User: + | val name: String = "John" + | def otherMethod = ??? + | + |object Main: + | User.otherMethod + |""".stripMargin, + ) + + @Test def `simple-object-definition-2` = + checkEdit( + """| + |object User: + | val name: String = "John" + | + |object Main: + | User.<>(10) + | + |""".stripMargin, + """| + |object User: + | val name: String = "John" + | def otherMethod(arg0: Int) = ??? + | + |object Main: + | User.otherMethod(10) + |""".stripMargin, + ) + + @Test def `class-definition-without-body` = + checkEdit( + """| + |class User + | + |object Main: + | val user = User() + | user.<> + | + |""".stripMargin, + """| + |class User: + | def otherMethod = ??? + | + |object Main: + | val user = User() + | user.otherMethod + |""".stripMargin, + ) + + def checkEdit( + original: String, + expected: String + ): Unit = + val edits = getAutoImplement(original) + val (code, _, _) = params(original) + val obtained = TextEdits.applyEdits(code, edits) + assertNoDiff(expected, obtained) + + def getAutoImplement( + original: String, + filename: String = "file:/A.scala" + ): List[l.TextEdit] = + val (code, _, offset) = params(original) + val result = presentationCompiler + .codeAction( + CompilerOffsetParams(URI.create(filename), code, offset, cancelToken), + CodeActionId.InsertInferredMethod, + Optional.empty() + ) + .get() + result.asScala.toList