1919import lombok .Value ;
2020import org .openrewrite .Tree ;
2121import org .openrewrite .internal .lang .Nullable ;
22- import org .openrewrite .java .JavaIsoVisitor ;
2322import org .openrewrite .java .JavaVisitor ;
2423import org .openrewrite .java .MethodMatcher ;
2524import org .openrewrite .java .style .EqualsAvoidsNullStyle ;
26- import org .openrewrite .java .tree .Expression ;
27- import org .openrewrite .java .tree .J ;
28- import org .openrewrite .java .tree .JavaType ;
29- import org .openrewrite .java .tree .Space ;
25+ import org .openrewrite .java .tree .*;
26+ import org .openrewrite .marker .Markers ;
3027
3128import static java .util .Collections .singletonList ;
3229
3330@ Value
3431@ EqualsAndHashCode (callSuper = false )
35- public class EqualsAvoidsNullVisitor <P > extends JavaIsoVisitor <P > {
32+ public class EqualsAvoidsNullVisitor <P > extends JavaVisitor <P > {
3633 private static final MethodMatcher STRING_EQUALS = new MethodMatcher ("String equals(java.lang.Object)" );
3734 private static final MethodMatcher STRING_EQUALS_IGNORE_CASE = new MethodMatcher ("String equalsIgnoreCase(java.lang.String)" );
3835
3936 EqualsAvoidsNullStyle style ;
4037
4138 @ Override
42- public J .MethodInvocation visitMethodInvocation (J .MethodInvocation method , P p ) {
43- J .MethodInvocation m = super .visitMethodInvocation (method , p );
44-
39+ public J visitMethodInvocation (J .MethodInvocation method , P p ) {
40+ J j = super .visitMethodInvocation (method , p );
41+ if (!(j instanceof J .MethodInvocation )) {
42+ return j ;
43+ }
44+ J .MethodInvocation m = (J .MethodInvocation ) j ;
4545 if (m .getSelect () == null ) {
4646 return m ;
4747 }
4848
4949 if ((STRING_EQUALS .matches (m ) || (!Boolean .TRUE .equals (style .getIgnoreEqualsIgnoreCase ()) && STRING_EQUALS_IGNORE_CASE .matches (m ))) &&
5050 m .getArguments ().get (0 ) instanceof J .Literal &&
51- m .getArguments ().get (0 ).getType () != JavaType .Primitive .Null &&
5251 !(m .getSelect () instanceof J .Literal )) {
5352 Tree parent = getCursor ().getParentTreeCursor ().getValue ();
5453 if (parent instanceof J .Binary ) {
@@ -62,8 +61,16 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, P p)
6261 }
6362 }
6463
65- m = m .withSelect (((J .Literal ) m .getArguments ().get (0 )).withPrefix (m .getSelect ().getPrefix ()))
66- .withArguments (singletonList (m .getSelect ().withPrefix (Space .EMPTY )));
64+ if (m .getArguments ().get (0 ).getType () == JavaType .Primitive .Null ) {
65+ return new J .Binary (Tree .randomId (), m .getPrefix (), Markers .EMPTY ,
66+ m .getSelect (),
67+ JLeftPadded .build (J .Binary .Type .Equal ).withBefore (Space .SINGLE_SPACE ),
68+ m .getArguments ().get (0 ).withPrefix (Space .SINGLE_SPACE ),
69+ JavaType .Primitive .Boolean );
70+ } else {
71+ m = m .withSelect (((J .Literal ) m .getArguments ().get (0 )).withPrefix (m .getSelect ().getPrefix ()))
72+ .withArguments (singletonList (m .getSelect ().withPrefix (Space .EMPTY )));
73+ }
6774 }
6875
6976 return m ;
0 commit comments