2020import org .openrewrite .Recipe ;
2121import org .openrewrite .TreeVisitor ;
2222import org .openrewrite .java .JavaIsoVisitor ;
23- import org .openrewrite .java .tree .J ;
24- import org .openrewrite .java .tree .Space ;
25- import org .openrewrite .java .tree .Statement ;
26- import org .openrewrite .marker .SearchResult ;
23+ import org .openrewrite .java .tree .*;
24+ import org .openrewrite .staticanalysis .java .JavaFileChecker ;
2725
2826import java .util .List ;
27+ import java .util .Optional ;
2928
3029public class LambdaBlockToExpression extends Recipe {
3130 @ Override
@@ -40,13 +39,7 @@ public String getDescription() {
4039
4140 @ Override
4241 public TreeVisitor <?, ExecutionContext > getVisitor () {
43- return Preconditions .check (
44- new JavaIsoVisitor <ExecutionContext >() {
45- @ Override
46- public J .CompilationUnit visitCompilationUnit (J .CompilationUnit cu , ExecutionContext executionContext ) {
47- return SearchResult .found (cu );
48- }
49- },
42+ return Preconditions .check (new JavaFileChecker <>(),
5043 new JavaIsoVisitor <ExecutionContext >() {
5144 @ Override
5245 public J .Lambda visitLambda (J .Lambda lambda , ExecutionContext executionContext ) {
@@ -64,7 +57,50 @@ public J.Lambda visitLambda(J.Lambda lambda, ExecutionContext executionContext)
6457 }
6558 return l ;
6659 }
60+
61+ @ Override
62+ public J .MethodInvocation visitMethodInvocation (J .MethodInvocation method , ExecutionContext executionContext ) {
63+ if (hasLambdaArgument (method ) && hasMethodOverloading (method )) {
64+ return method ;
65+ }
66+ return super .visitMethodInvocation (method , executionContext );
67+ }
6768 }
6869 );
6970 }
71+
72+ // Check whether a method has overloading methods in the declaring class
73+ private static boolean hasMethodOverloading (J .MethodInvocation method ) {
74+ String methodName = method .getSimpleName ();
75+ return Optional .ofNullable (method .getMethodType ())
76+ .map (JavaType .Method ::getDeclaringType )
77+ .filter (JavaType .Class .class ::isInstance )
78+ .map (JavaType .Class .class ::cast )
79+ .map (JavaType .Class ::getMethods )
80+ .map (methods -> {
81+ int overloadingCount = 0 ;
82+ for (JavaType .Method dm : methods ) {
83+ if (dm .getName ().equals (methodName )) {
84+ overloadingCount ++;
85+ if (overloadingCount > 1 ) {
86+
87+ return true ;
88+ }
89+ }
90+ }
91+ return false ;
92+ })
93+ .orElse (false );
94+ }
95+
96+ private static boolean hasLambdaArgument (J .MethodInvocation method ) {
97+ boolean hasLambdaArgument = false ;
98+ for (Expression arg : method .getArguments ()) {
99+ if (arg instanceof J .Lambda ) {
100+ hasLambdaArgument = true ;
101+ break ;
102+ }
103+ }
104+ return hasLambdaArgument ;
105+ }
70106}
0 commit comments