Skip to content

Commit 3c3fded

Browse files
committed
Fix a false positive of LambdaBlockToExpression recipe to not rewrite when there are ambiguous method overloading
1 parent e1cfd80 commit 3c3fded

File tree

2 files changed

+59
-32
lines changed

2 files changed

+59
-32
lines changed

src/main/java/org/openrewrite/staticanalysis/LambdaBlockToExpression.java

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,11 @@
2020
import org.openrewrite.Recipe;
2121
import org.openrewrite.TreeVisitor;
2222
import org.openrewrite.java.JavaIsoVisitor;
23-
import org.openrewrite.java.tree.J;
24-
import org.openrewrite.java.tree.Space;
25-
import org.openrewrite.java.tree.Statement;
26-
import org.openrewrite.marker.SearchResult;
23+
import org.openrewrite.java.tree.*;
24+
import org.openrewrite.staticanalysis.java.JavaFileChecker;
2725

2826
import java.util.List;
27+
import java.util.Optional;
2928

3029
public class LambdaBlockToExpression extends Recipe {
3130
@Override
@@ -40,13 +39,7 @@ public String getDescription() {
4039

4140
@Override
4241
public TreeVisitor<?, ExecutionContext> getVisitor() {
43-
return Preconditions.check(
44-
new JavaIsoVisitor<ExecutionContext>() {
45-
@Override
46-
public J.CompilationUnit visitCompilationUnit(J.CompilationUnit cu, ExecutionContext executionContext) {
47-
return SearchResult.found(cu);
48-
}
49-
},
42+
return Preconditions.check(new JavaFileChecker<>(),
5043
new JavaIsoVisitor<ExecutionContext>() {
5144
@Override
5245
public J.Lambda visitLambda(J.Lambda lambda, ExecutionContext executionContext) {
@@ -64,7 +57,50 @@ public J.Lambda visitLambda(J.Lambda lambda, ExecutionContext executionContext)
6457
}
6558
return l;
6659
}
60+
61+
@Override
62+
public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, ExecutionContext executionContext) {
63+
if (hasLambdaArgument(method) && hasMethodOverloading(method)) {
64+
return method;
65+
}
66+
return super.visitMethodInvocation(method, executionContext);
67+
}
6768
}
6869
);
6970
}
71+
72+
// Check whether a method has overloading methods in the declaring class
73+
private static boolean hasMethodOverloading(J.MethodInvocation method) {
74+
String methodName = method.getSimpleName();
75+
return Optional.ofNullable(method.getMethodType())
76+
.map(JavaType.Method::getDeclaringType)
77+
.filter(JavaType.Class.class::isInstance)
78+
.map(JavaType.Class.class::cast)
79+
.map(JavaType.Class::getMethods)
80+
.map(methods -> {
81+
int overloadingCount = 0;
82+
for (JavaType.Method dm : methods) {
83+
if (dm.getName().equals(methodName)) {
84+
overloadingCount++;
85+
if (overloadingCount > 1) {
86+
87+
return true;
88+
}
89+
}
90+
}
91+
return false;
92+
})
93+
.orElse(false);
94+
}
95+
96+
private static boolean hasLambdaArgument(J.MethodInvocation method) {
97+
boolean hasLambdaArgument = false;
98+
for (Expression arg : method.getArguments()) {
99+
if (arg instanceof J.Lambda) {
100+
hasLambdaArgument = true;
101+
break;
102+
}
103+
}
104+
return hasLambdaArgument;
105+
}
70106
}

src/test/java/org/openrewrite/staticanalysis/LambdaBlockToExpressionTest.java

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -84,22 +84,21 @@ class Test {
8484
);
8585
}
8686

87-
@Test
8887
@Issue("https://github.com/openrewrite/rewrite-static-analysis/issues/162")
89-
void simplifyLambdaBlockWithAmbiguousMethod() {
88+
@Test
89+
void noChangeIfLambdaBlockWithAmbiguousMethod() {
9090
//language=java
9191
rewriteRun(
92-
java("""
93-
import java.util.function.Function;
94-
import java.util.function.Consumer;
95-
class A {
96-
void aMethod(Consumer<Integer> consumer){
97-
}
98-
99-
void aMethod(Function<Integer,String> function){
100-
}
101-
}
102-
"""),
92+
java(
93+
"""
94+
import java.util.function.Function;
95+
import java.util.function.Consumer;
96+
class A {
97+
void aMethod(Consumer<Integer> consumer) {}
98+
void aMethod(Function<Integer,String> function) {}
99+
}
100+
"""
101+
),
103102
java(
104103
"""
105104
class Test {
@@ -110,14 +109,6 @@ void doTest() {
110109
});
111110
}
112111
}
113-
""",
114-
"""
115-
class Test {
116-
void doTest() {
117-
A a = new A();
118-
a.aMethod(value -> value.toString());
119-
}
120-
}
121112
"""
122113
)
123114
);

0 commit comments

Comments
 (0)