From d274c980b447d7d21a1755ca61ded911d4372ca6 Mon Sep 17 00:00:00 2001 From: Laurens Westerlaken Date: Thu, 12 Sep 2024 16:54:15 +0200 Subject: [PATCH 01/15] WIP --- .../MockitoWhenOnStaticToMockStatic.java | 92 ++++++++++++++++ .../MockitoWhenOnStaticToMockStaticTest.java | 100 ++++++++++++++++++ 2 files changed, 192 insertions(+) create mode 100644 src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java create mode 100644 src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java diff --git a/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java b/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java new file mode 100644 index 000000000..00336424b --- /dev/null +++ b/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java @@ -0,0 +1,92 @@ +/* + * Copyright 2024 the original author or authors. + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.java.testing.mockito; + +import org.openrewrite.ExecutionContext; +import org.openrewrite.Preconditions; +import org.openrewrite.Recipe; +import org.openrewrite.TreeVisitor; +import org.openrewrite.internal.ListUtils; +import org.openrewrite.java.JavaIsoVisitor; +import org.openrewrite.java.JavaParser; +import org.openrewrite.java.JavaTemplate; +import org.openrewrite.java.MethodMatcher; +import org.openrewrite.java.search.UsesType; +import org.openrewrite.java.tree.*; + +import java.util.ArrayList; +import java.util.List; + +public class MockitoWhenOnStaticToMockStatic extends Recipe { + + private static final MethodMatcher MOCKITO_WHEN = new MethodMatcher("org.mockito.Mockito when(..)"); + + @Override + public String getDisplayName() { + return ""; + } + + @Override + public String getDescription() { + return "."; + } + + @Override + public TreeVisitor getVisitor() { + return Preconditions.check(new UsesType<>("org.mockito.Mockito", true), + new JavaIsoVisitor() { + @Override + public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext ctx) { + J.MethodDeclaration m = super.visitMethodDeclaration(method, ctx); + boolean rewrittenWhen = false; + List statementsBeforeWhen = new ArrayList<>(); + List statementsAfterWhen = new ArrayList<>(); + for (Statement stmt : m.getBody().getStatements()) { + if (stmt instanceof J.MethodInvocation && + MOCKITO_WHEN.matches(((J.MethodInvocation)stmt).getSelect())) { + J.MethodInvocation when = (J.MethodInvocation)((J.MethodInvocation) stmt).getSelect(); + if (when.getArguments().get(0) instanceof J.MethodInvocation && ((J.MethodInvocation)when.getArguments().get(0)).getMethodType().getFlags().contains(Flag.Static)){ + JavaType.FullyQualified arg_fq = TypeUtils.asFullyQualified(when.getArguments().get(0).getType()); + String template = String.format("try(MockedStatic<%s> mock%s = mockStatic(%s.class)){\n" + + "mock%s.when(%s::%s).thenReturn(%s);\n" + + "}", arg_fq.getClassName(), arg_fq.getClassName(), arg_fq.getClassName(), arg_fq.getClassName(), arg_fq.getClassName(), ((J.MethodInvocation)when.getArguments().get(0)).getSimpleName(), ((J.MethodInvocation) stmt).getArguments().get(0)); + m = JavaTemplate.builder(template) + .contextSensitive() + .javaParser(JavaParser.fromJavaVersion()) + .staticImports("org.mockito.Mockito.mockStatic") + .build() + .apply(getCursor(), stmt.getCoordinates().replace()); + rewrittenWhen = true; + maybeAddImport("org.mockito.Mockito", "mockStatic"); + continue; + } + } + if (rewrittenWhen) { + statementsAfterWhen.add(stmt); + } else { + statementsBeforeWhen.add(stmt); + } + } + if (rewrittenWhen) { + J.Try try_catch = (J.Try) m.getBody().getStatements().get(statementsBeforeWhen.size()); + return maybeAutoFormat(method, m.withBody(m.getBody().withStatements(ListUtils.concat(statementsBeforeWhen, try_catch.withBody(try_catch.getBody().withStatements(ListUtils.concatAll(try_catch.getBody().getStatements(), statementsAfterWhen)))))), ctx); + } else { + return m; + } + } + }); + } +} diff --git a/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java b/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java new file mode 100644 index 000000000..3cd7cc424 --- /dev/null +++ b/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java @@ -0,0 +1,100 @@ +/* + * Copyright 2024 the original author or authors. + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.java.testing.mockito; + +import org.junit.jupiter.api.Test; +import org.openrewrite.InMemoryExecutionContext; +import org.openrewrite.java.JavaParser; +import org.openrewrite.test.RecipeSpec; +import org.openrewrite.test.RewriteTest; +import org.openrewrite.test.SourceSpec; + +import static org.openrewrite.java.Assertions.java; + +public class MockitoWhenOnStaticToMockStaticTest implements RewriteTest { + + @Override + public void defaults(RecipeSpec spec) { + spec + .parser(JavaParser.fromJavaVersion() + .classpathFromResources(new InMemoryExecutionContext(), + "junit-4.13", + "junit-jupiter-api-5.9", + "mockito-core-3.12", + "mockito-junit-jupiter-3.12" + )) + .recipe(new MockitoWhenOnStaticToMockStatic()); + } + + @Test + void shouldRefactorMockito_When() { + //language=java + rewriteRun( + java( + """ + package a.b; + + public class A { + + public A() { + } + + public static A getA() { + return new A(); + } + } + """, + SourceSpec::skip + ), + java( + """ + import a.b.A; + + import static org.junit.Assert.assertEquals; + import static org.mockito.Mockito.*; + + class Test { + + private A aMock = mock(A.class); + + void test() { + when(A.getA()).thenReturn(aMock); + assertEquals(A.getA(), aMock); + } + } + """, + """ + import a.b.A; + + import static org.junit.Assert.assertEquals; + import static org.mockito.Mockito.*; + + class Test { + + private A aMock = mock(A.class); + + void test() { + try (MockedStatic mockA = mockStatic(A.class)) { + mockA.when(A::getA).thenReturn(aMock); + assertEquals(A.getA(), aMock); + } + } + } + """ + ) + ); + } +} From 497252dd2141e9f05338edd8d9fe933eab661041 Mon Sep 17 00:00:00 2001 From: Laurens Westerlaken Date: Thu, 12 Sep 2024 18:00:15 +0200 Subject: [PATCH 02/15] Format and suggestion --- .../testing/mockito/MockitoWhenOnStaticToMockStatic.java | 8 ++++---- .../mockito/MockitoWhenOnStaticToMockStaticTest.java | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java b/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java index 00336424b..353336398 100644 --- a/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java +++ b/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java @@ -56,13 +56,13 @@ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, Ex List statementsAfterWhen = new ArrayList<>(); for (Statement stmt : m.getBody().getStatements()) { if (stmt instanceof J.MethodInvocation && - MOCKITO_WHEN.matches(((J.MethodInvocation)stmt).getSelect())) { - J.MethodInvocation when = (J.MethodInvocation)((J.MethodInvocation) stmt).getSelect(); - if (when.getArguments().get(0) instanceof J.MethodInvocation && ((J.MethodInvocation)when.getArguments().get(0)).getMethodType().getFlags().contains(Flag.Static)){ + MOCKITO_WHEN.matches(((J.MethodInvocation) stmt).getSelect())) { + J.MethodInvocation when = (J.MethodInvocation) ((J.MethodInvocation) stmt).getSelect(); + if (when.getArguments().get(0) instanceof J.MethodInvocation && ((J.MethodInvocation) when.getArguments().get(0)).getMethodType().getFlags().contains(Flag.Static)) { JavaType.FullyQualified arg_fq = TypeUtils.asFullyQualified(when.getArguments().get(0).getType()); String template = String.format("try(MockedStatic<%s> mock%s = mockStatic(%s.class)){\n" + "mock%s.when(%s::%s).thenReturn(%s);\n" + - "}", arg_fq.getClassName(), arg_fq.getClassName(), arg_fq.getClassName(), arg_fq.getClassName(), arg_fq.getClassName(), ((J.MethodInvocation)when.getArguments().get(0)).getSimpleName(), ((J.MethodInvocation) stmt).getArguments().get(0)); + "}", arg_fq.getClassName(), arg_fq.getClassName(), arg_fq.getClassName(), arg_fq.getClassName(), arg_fq.getClassName(), ((J.MethodInvocation) when.getArguments().get(0)).getSimpleName(), ((J.MethodInvocation) stmt).getArguments().get(0)); m = JavaTemplate.builder(template) .contextSensitive() .javaParser(JavaParser.fromJavaVersion()) diff --git a/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java b/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java index 3cd7cc424..00cdbc240 100644 --- a/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java +++ b/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java @@ -24,7 +24,7 @@ import static org.openrewrite.java.Assertions.java; -public class MockitoWhenOnStaticToMockStaticTest implements RewriteTest { +class MockitoWhenOnStaticToMockStaticTest implements RewriteTest { @Override public void defaults(RecipeSpec spec) { From 3d4b4d8eec1da4df6da37a31bead73ead184bb87 Mon Sep 17 00:00:00 2001 From: Laurens Westerlaken Date: Fri, 13 Sep 2024 15:17:30 +0200 Subject: [PATCH 03/15] Add recipe to mockito recipes Add missing imports Fill out name and desc --- .../mockito/MockitoWhenOnStaticToMockStatic.java | 15 +++++++++------ src/main/resources/META-INF/rewrite/mockito.yml | 1 + .../MockitoWhenOnStaticToMockStaticTest.java | 5 +++-- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java b/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java index 353336398..2f0849184 100644 --- a/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java +++ b/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java @@ -36,12 +36,12 @@ public class MockitoWhenOnStaticToMockStatic extends Recipe { @Override public String getDisplayName() { - return ""; + return "Replace `Mockito.when` on static (non mock) with try-with-resource with MockedStatic"; } @Override public String getDescription() { - return "."; + return "Replace `Mockito.when` on static (non mock) with try-with-resource with MockedStatic as Mockito4 no longer allows this."; } @Override @@ -60,16 +60,19 @@ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, Ex J.MethodInvocation when = (J.MethodInvocation) ((J.MethodInvocation) stmt).getSelect(); if (when.getArguments().get(0) instanceof J.MethodInvocation && ((J.MethodInvocation) when.getArguments().get(0)).getMethodType().getFlags().contains(Flag.Static)) { JavaType.FullyQualified arg_fq = TypeUtils.asFullyQualified(when.getArguments().get(0).getType()); - String template = String.format("try(MockedStatic<%s> mock%s = mockStatic(%s.class)){\n" + - "mock%s.when(%s::%s).thenReturn(%s);\n" + - "}", arg_fq.getClassName(), arg_fq.getClassName(), arg_fq.getClassName(), arg_fq.getClassName(), arg_fq.getClassName(), ((J.MethodInvocation) when.getArguments().get(0)).getSimpleName(), ((J.MethodInvocation) stmt).getArguments().get(0)); + J.Identifier ident = (J.Identifier) ((J.MethodInvocation)when.getArguments().get(0)).getSelect(); + String template = String.format("try(MockedStatic<#{}> mock%s = mockStatic(#{}.class)){\n" + + " mock%s.when(#{any()}).thenReturn(#{any()});\n" + + "}", arg_fq.getClassName(), arg_fq.getClassName()); m = JavaTemplate.builder(template) .contextSensitive() .javaParser(JavaParser.fromJavaVersion()) + .imports("org.mockito.MockedStatic") .staticImports("org.mockito.Mockito.mockStatic") .build() - .apply(getCursor(), stmt.getCoordinates().replace()); + .apply(getCursor(), stmt.getCoordinates().replace(), ident.getType(), ident.getType(), when.getArguments().get(0), ((J.MethodInvocation) stmt).getArguments().get(0)); rewrittenWhen = true; + maybeAddImport("org.mockito.MockedStatic", false); maybeAddImport("org.mockito.Mockito", "mockStatic"); continue; } diff --git a/src/main/resources/META-INF/rewrite/mockito.yml b/src/main/resources/META-INF/rewrite/mockito.yml index 49a5b193a..7a059c678 100644 --- a/src/main/resources/META-INF/rewrite/mockito.yml +++ b/src/main/resources/META-INF/rewrite/mockito.yml @@ -67,6 +67,7 @@ recipeList: groupId: net.bytebuddy artifactId: byte-buddy* newVersion: 1.12.19 + - org.openrewrite.java.testing.mockito.MockitoWhenOnStaticToMockStatic --- type: specs.openrewrite.org/v1beta/recipe name: org.openrewrite.java.testing.mockito.Mockito1to3Migration diff --git a/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java b/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java index 00cdbc240..803d2e249 100644 --- a/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java +++ b/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java @@ -78,6 +78,7 @@ void test() { """, """ import a.b.A; + import org.mockito.MockedStatic; import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.*; @@ -87,8 +88,8 @@ class Test { private A aMock = mock(A.class); void test() { - try (MockedStatic mockA = mockStatic(A.class)) { - mockA.when(A::getA).thenReturn(aMock); + try (MockedStatic mockA = mockStatic(a.b.A.class)) { + mockA.when(A.getA()).thenReturn(aMock); assertEquals(A.getA(), aMock); } } From 8c9d1d48467144250ecc20da3d58becc365bd5fb Mon Sep 17 00:00:00 2001 From: Laurens Westerlaken Date: Mon, 16 Sep 2024 16:34:09 +0200 Subject: [PATCH 04/15] Update test with TypeValidation --- .../testing/mockito/MockitoWhenOnStaticToMockStaticTest.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java b/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java index 803d2e249..2037b131c 100644 --- a/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java +++ b/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java @@ -21,6 +21,7 @@ import org.openrewrite.test.RecipeSpec; import org.openrewrite.test.RewriteTest; import org.openrewrite.test.SourceSpec; +import org.openrewrite.test.TypeValidation; import static org.openrewrite.java.Assertions.java; @@ -43,6 +44,7 @@ public void defaults(RecipeSpec spec) { void shouldRefactorMockito_When() { //language=java rewriteRun( + spec -> spec.typeValidationOptions(TypeValidation.builder().methodInvocations(false).identifiers(false).build()), java( """ package a.b; From c25c40dcfd5a8aeac3b55370ff22e2b5054c2a67 Mon Sep 17 00:00:00 2001 From: Laurens Westerlaken Date: Mon, 16 Sep 2024 16:42:26 +0200 Subject: [PATCH 05/15] Format and refactor to remove errors --- .../MockitoWhenOnStaticToMockStatic.java | 69 +++++++++++-------- .../MockitoWhenOnStaticToMockStaticTest.java | 38 +++++----- 2 files changed, 58 insertions(+), 49 deletions(-) diff --git a/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java b/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java index 2f0849184..f40795a38 100644 --- a/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java +++ b/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java @@ -54,41 +54,50 @@ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, Ex boolean rewrittenWhen = false; List statementsBeforeWhen = new ArrayList<>(); List statementsAfterWhen = new ArrayList<>(); - for (Statement stmt : m.getBody().getStatements()) { - if (stmt instanceof J.MethodInvocation && - MOCKITO_WHEN.matches(((J.MethodInvocation) stmt).getSelect())) { - J.MethodInvocation when = (J.MethodInvocation) ((J.MethodInvocation) stmt).getSelect(); - if (when.getArguments().get(0) instanceof J.MethodInvocation && ((J.MethodInvocation) when.getArguments().get(0)).getMethodType().getFlags().contains(Flag.Static)) { - JavaType.FullyQualified arg_fq = TypeUtils.asFullyQualified(when.getArguments().get(0).getType()); - J.Identifier ident = (J.Identifier) ((J.MethodInvocation)when.getArguments().get(0)).getSelect(); - String template = String.format("try(MockedStatic<#{}> mock%s = mockStatic(#{}.class)){\n" + - " mock%s.when(#{any()}).thenReturn(#{any()});\n" + - "}", arg_fq.getClassName(), arg_fq.getClassName()); - m = JavaTemplate.builder(template) - .contextSensitive() - .javaParser(JavaParser.fromJavaVersion()) - .imports("org.mockito.MockedStatic") - .staticImports("org.mockito.Mockito.mockStatic") - .build() - .apply(getCursor(), stmt.getCoordinates().replace(), ident.getType(), ident.getType(), when.getArguments().get(0), ((J.MethodInvocation) stmt).getArguments().get(0)); - rewrittenWhen = true; - maybeAddImport("org.mockito.MockedStatic", false); - maybeAddImport("org.mockito.Mockito", "mockStatic"); - continue; + if (m.getBody() != null) { + for (Statement stmt : m.getBody().getStatements()) { + if (stmt instanceof J.MethodInvocation && + MOCKITO_WHEN.matches(((J.MethodInvocation) stmt).getSelect())) { + J.MethodInvocation when = (J.MethodInvocation) ((J.MethodInvocation) stmt).getSelect(); + if (when != null && when.getArguments().get(0) instanceof J.MethodInvocation) { + J.MethodInvocation whenArg = (J.MethodInvocation) when.getArguments().get(0); + if (whenArg.getMethodType() != null && whenArg.getMethodType().getFlags().contains(Flag.Static)) { + JavaType.FullyQualified arg_fq = TypeUtils.asFullyQualified(whenArg.getType()); + J.Identifier ident = (J.Identifier) whenArg.getSelect(); + if (arg_fq != null && ident != null && ident.getType() != null) { + String template = String.format("try(MockedStatic<#{}> mock%s = mockStatic(#{}.class)){\n" + + " mock%s.when(#{any()}).thenReturn(#{any()});\n" + + "}", arg_fq.getClassName(), arg_fq.getClassName()); + m = JavaTemplate.builder(template) + .contextSensitive() + .javaParser(JavaParser.fromJavaVersion()) + .imports("org.mockito.MockedStatic") + .staticImports("org.mockito.Mockito.mockStatic") + .build() + .apply(getCursor(), stmt.getCoordinates().replace(), ident.getType(), ident.getType(), whenArg, ((J.MethodInvocation) stmt).getArguments().get(0)); + rewrittenWhen = true; + maybeAddImport("org.mockito.MockedStatic", false); + maybeAddImport("org.mockito.Mockito", "mockStatic"); + continue; + } + } + + } + } + if (rewrittenWhen) { + statementsAfterWhen.add(stmt); + } else { + statementsBeforeWhen.add(stmt); } - } - if (rewrittenWhen) { - statementsAfterWhen.add(stmt); - } else { - statementsBeforeWhen.add(stmt); } } if (rewrittenWhen) { - J.Try try_catch = (J.Try) m.getBody().getStatements().get(statementsBeforeWhen.size()); - return maybeAutoFormat(method, m.withBody(m.getBody().withStatements(ListUtils.concat(statementsBeforeWhen, try_catch.withBody(try_catch.getBody().withStatements(ListUtils.concatAll(try_catch.getBody().getStatements(), statementsAfterWhen)))))), ctx); - } else { - return m; + if (m.getBody() != null) { + J.Try try_catch = (J.Try) m.getBody().getStatements().get(statementsBeforeWhen.size()); + return maybeAutoFormat(method, m.withBody(m.getBody().withStatements(ListUtils.concat(statementsBeforeWhen, try_catch.withBody(try_catch.getBody().withStatements(ListUtils.concatAll(try_catch.getBody().getStatements(), statementsAfterWhen)))))), ctx); + } } + return m; } }); } diff --git a/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java b/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java index 2037b131c..6c35fea8d 100644 --- a/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java +++ b/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java @@ -47,31 +47,31 @@ void shouldRefactorMockito_When() { spec -> spec.typeValidationOptions(TypeValidation.builder().methodInvocations(false).identifiers(false).build()), java( """ - package a.b; - - public class A { - - public A() { - } - - public static A getA() { - return new A(); - } - } - """, + package a.b; + + public class A { + + public A() { + } + + public static A getA() { + return new A(); + } + } + """, SourceSpec::skip ), java( """ import a.b.A; - + import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.*; - + class Test { - + private A aMock = mock(A.class); - + void test() { when(A.getA()).thenReturn(aMock); assertEquals(A.getA(), aMock); @@ -84,11 +84,11 @@ void test() { import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.*; - + class Test { - + private A aMock = mock(A.class); - + void test() { try (MockedStatic mockA = mockStatic(a.b.A.class)) { mockA.when(A.getA()).thenReturn(aMock); From 903899edc549d0489825f16b25d31e94adf22e58 Mon Sep 17 00:00:00 2001 From: Tim te Beek Date: Mon, 16 Sep 2024 21:40:56 +0200 Subject: [PATCH 06/15] Reduce and flatten ahead of further changes --- .../MockitoWhenOnStaticToMockStatic.java | 70 ++++++++-------- .../resources/META-INF/rewrite/mockito.yml | 2 +- .../MockitoWhenOnStaticToMockStaticTest.java | 83 +++++++++---------- 3 files changed, 77 insertions(+), 78 deletions(-) diff --git a/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java b/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java index f40795a38..e9a571320 100644 --- a/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java +++ b/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java @@ -24,6 +24,7 @@ import org.openrewrite.java.JavaParser; import org.openrewrite.java.JavaTemplate; import org.openrewrite.java.MethodMatcher; +import org.openrewrite.java.search.UsesMethod; import org.openrewrite.java.search.UsesType; import org.openrewrite.java.tree.*; @@ -46,49 +47,52 @@ public String getDescription() { @Override public TreeVisitor getVisitor() { - return Preconditions.check(new UsesType<>("org.mockito.Mockito", true), + return Preconditions.check( + new UsesMethod<>(MOCKITO_WHEN), new JavaIsoVisitor() { @Override public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext ctx) { J.MethodDeclaration m = super.visitMethodDeclaration(method, ctx); + if (m.getBody() == null) { + return m; + } + boolean rewrittenWhen = false; List statementsBeforeWhen = new ArrayList<>(); List statementsAfterWhen = new ArrayList<>(); - if (m.getBody() != null) { - for (Statement stmt : m.getBody().getStatements()) { - if (stmt instanceof J.MethodInvocation && - MOCKITO_WHEN.matches(((J.MethodInvocation) stmt).getSelect())) { - J.MethodInvocation when = (J.MethodInvocation) ((J.MethodInvocation) stmt).getSelect(); - if (when != null && when.getArguments().get(0) instanceof J.MethodInvocation) { - J.MethodInvocation whenArg = (J.MethodInvocation) when.getArguments().get(0); - if (whenArg.getMethodType() != null && whenArg.getMethodType().getFlags().contains(Flag.Static)) { - JavaType.FullyQualified arg_fq = TypeUtils.asFullyQualified(whenArg.getType()); - J.Identifier ident = (J.Identifier) whenArg.getSelect(); - if (arg_fq != null && ident != null && ident.getType() != null) { - String template = String.format("try(MockedStatic<#{}> mock%s = mockStatic(#{}.class)){\n" + - " mock%s.when(#{any()}).thenReturn(#{any()});\n" + - "}", arg_fq.getClassName(), arg_fq.getClassName()); - m = JavaTemplate.builder(template) - .contextSensitive() - .javaParser(JavaParser.fromJavaVersion()) - .imports("org.mockito.MockedStatic") - .staticImports("org.mockito.Mockito.mockStatic") - .build() - .apply(getCursor(), stmt.getCoordinates().replace(), ident.getType(), ident.getType(), whenArg, ((J.MethodInvocation) stmt).getArguments().get(0)); - rewrittenWhen = true; - maybeAddImport("org.mockito.MockedStatic", false); - maybeAddImport("org.mockito.Mockito", "mockStatic"); - continue; - } + for (Statement stmt : m.getBody().getStatements()) { + if (stmt instanceof J.MethodInvocation && + MOCKITO_WHEN.matches(((J.MethodInvocation) stmt).getSelect())) { + J.MethodInvocation when = (J.MethodInvocation) ((J.MethodInvocation) stmt).getSelect(); + if (when != null && when.getArguments().get(0) instanceof J.MethodInvocation) { + J.MethodInvocation whenArg = (J.MethodInvocation) when.getArguments().get(0); + if (whenArg.getMethodType() != null && whenArg.getMethodType().getFlags().contains(Flag.Static)) { + JavaType.FullyQualified arg_fq = TypeUtils.asFullyQualified(whenArg.getType()); + J.Identifier ident = (J.Identifier) whenArg.getSelect(); + if (arg_fq != null && ident != null && ident.getType() != null) { + String template = String.format("try(MockedStatic<#{}> mock%s = mockStatic(#{}.class)){\n" + + " mock%s.when(#{any()}).thenReturn(#{any()});\n" + + "}", arg_fq.getClassName(), arg_fq.getClassName()); + m = JavaTemplate.builder(template) + .contextSensitive() + .javaParser(JavaParser.fromJavaVersion()) + .imports("org.mockito.MockedStatic") + .staticImports("org.mockito.Mockito.mockStatic") + .build() + .apply(getCursor(), stmt.getCoordinates().replace(), ident.getType(), ident.getType(), whenArg, ((J.MethodInvocation) stmt).getArguments().get(0)); + rewrittenWhen = true; + maybeAddImport("org.mockito.MockedStatic", false); + maybeAddImport("org.mockito.Mockito", "mockStatic"); + continue; } - } + } - if (rewrittenWhen) { - statementsAfterWhen.add(stmt); - } else { - statementsBeforeWhen.add(stmt); - } + } + if (rewrittenWhen) { + statementsAfterWhen.add(stmt); + } else { + statementsBeforeWhen.add(stmt); } } if (rewrittenWhen) { diff --git a/src/main/resources/META-INF/rewrite/mockito.yml b/src/main/resources/META-INF/rewrite/mockito.yml index d584b8ea8..7e60ad788 100644 --- a/src/main/resources/META-INF/rewrite/mockito.yml +++ b/src/main/resources/META-INF/rewrite/mockito.yml @@ -59,6 +59,7 @@ tags: - mockito recipeList: - org.openrewrite.java.testing.mockito.Mockito1to3Migration + - org.openrewrite.java.testing.mockito.MockitoWhenOnStaticToMockStatic - org.openrewrite.java.dependencies.UpgradeDependencyVersion: groupId: org.mockito artifactId: "*" @@ -67,7 +68,6 @@ recipeList: groupId: net.bytebuddy artifactId: byte-buddy* newVersion: 1.12.19 - - org.openrewrite.java.testing.mockito.MockitoWhenOnStaticToMockStatic --- type: specs.openrewrite.org/v1beta/recipe name: org.openrewrite.java.testing.mockito.Mockito1to3Migration diff --git a/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java b/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java index 6c35fea8d..0ec2906a6 100644 --- a/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java +++ b/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java @@ -16,6 +16,7 @@ package org.openrewrite.java.testing.mockito; import org.junit.jupiter.api.Test; +import org.openrewrite.DocumentExample; import org.openrewrite.InMemoryExecutionContext; import org.openrewrite.java.JavaParser; import org.openrewrite.test.RecipeSpec; @@ -29,31 +30,25 @@ class MockitoWhenOnStaticToMockStaticTest implements RewriteTest { @Override public void defaults(RecipeSpec spec) { - spec + spec.recipe(new MockitoWhenOnStaticToMockStatic()) .parser(JavaParser.fromJavaVersion() .classpathFromResources(new InMemoryExecutionContext(), "junit-4.13", - "junit-jupiter-api-5.9", "mockito-core-3.12", "mockito-junit-jupiter-3.12" - )) - .recipe(new MockitoWhenOnStaticToMockStatic()); + )); } + @DocumentExample @Test void shouldRefactorMockito_When() { //language=java rewriteRun( - spec -> spec.typeValidationOptions(TypeValidation.builder().methodInvocations(false).identifiers(false).build()), + spec -> spec.afterTypeValidationOptions(TypeValidation.builder().methodInvocations(false).identifiers(false).build()), java( """ - package a.b; - + package com.foo; public class A { - - public A() { - } - public static A getA() { return new A(); } @@ -63,40 +58,40 @@ public static A getA() { ), java( """ - import a.b.A; - - import static org.junit.Assert.assertEquals; - import static org.mockito.Mockito.*; - - class Test { - - private A aMock = mock(A.class); - - void test() { - when(A.getA()).thenReturn(aMock); - assertEquals(A.getA(), aMock); - } - } - """, - """ - import a.b.A; - import org.mockito.MockedStatic; - - import static org.junit.Assert.assertEquals; - import static org.mockito.Mockito.*; - - class Test { - - private A aMock = mock(A.class); - - void test() { - try (MockedStatic mockA = mockStatic(a.b.A.class)) { - mockA.when(A.getA()).thenReturn(aMock); - assertEquals(A.getA(), aMock); - } - } - } + import com.foo.A; + + import static org.junit.Assert.assertEquals; + import static org.mockito.Mockito.*; + + class Test { + + private A aMock = mock(A.class); + + void test() { + when(A.getA()).thenReturn(aMock); + assertEquals(A.getA(), aMock); + } + } + """, """ + import com.foo.A; + import org.mockito.MockedStatic; + + import static org.junit.Assert.assertEquals; + import static org.mockito.Mockito.*; + + class Test { + + private A aMock = mock(A.class); + + void test() { + try (MockedStatic mockA = mockStatic(com.foo.A.class)) { + mockA.when(A.getA()).thenReturn(aMock); + assertEquals(A.getA(), aMock); + } + } + } + """ ) ); } From 92da6ad7dbe048adcf33bd6e8387d399f111b5eb Mon Sep 17 00:00:00 2001 From: Tim te Beek Date: Mon, 16 Sep 2024 21:43:39 +0200 Subject: [PATCH 07/15] Show problematic case with new unit test --- .../MockitoWhenOnStaticToMockStaticTest.java | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java b/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java index 0ec2906a6..c6ba952f2 100644 --- a/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java +++ b/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java @@ -95,4 +95,47 @@ void test() { ) ); } + + @Test + void shouldHandleMultipleStaticMocks() { + //language=java + rewriteRun( + spec -> spec.afterTypeValidationOptions(TypeValidation.builder().methodInvocations(false).identifiers(false).build()), + java( + """ + package com.foo; + public class A { + public static A getA() { + return new A(); + } + } + """, + SourceSpec::skip + ), + java( + """ + import com.foo.A; + + import static org.junit.Assert.assertEquals; + import static org.mockito.Mockito.*; + + class Test { + + private A aMock = mock(A.class); + + void test() { + when(A.getA()).thenReturn(aMock); + assertEquals(A.getA(), aMock); + + when(A.getA()).thenReturn(aMock); + assertEquals(A.getA(), aMock); + } + } + """, + """ + class TODO {} + """ + ) + ); + } } From 59b721a986180d8f79410293d313faf6d004f91f Mon Sep 17 00:00:00 2001 From: Tim te Beek Date: Mon, 16 Sep 2024 22:14:16 +0200 Subject: [PATCH 08/15] Generate a non-conflicting variable name of the right type --- .../MockitoWhenOnStaticToMockStatic.java | 25 ++++++++++--------- .../MockitoWhenOnStaticToMockStaticTest.java | 12 ++++----- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java b/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java index e9a571320..cbe0f6d5b 100644 --- a/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java +++ b/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java @@ -20,12 +20,8 @@ import org.openrewrite.Recipe; import org.openrewrite.TreeVisitor; import org.openrewrite.internal.ListUtils; -import org.openrewrite.java.JavaIsoVisitor; -import org.openrewrite.java.JavaParser; -import org.openrewrite.java.JavaTemplate; -import org.openrewrite.java.MethodMatcher; +import org.openrewrite.java.*; import org.openrewrite.java.search.UsesMethod; -import org.openrewrite.java.search.UsesType; import org.openrewrite.java.tree.*; import java.util.ArrayList; @@ -67,19 +63,24 @@ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, Ex if (when != null && when.getArguments().get(0) instanceof J.MethodInvocation) { J.MethodInvocation whenArg = (J.MethodInvocation) when.getArguments().get(0); if (whenArg.getMethodType() != null && whenArg.getMethodType().getFlags().contains(Flag.Static)) { - JavaType.FullyQualified arg_fq = TypeUtils.asFullyQualified(whenArg.getType()); + JavaType.FullyQualified argFq = TypeUtils.asFullyQualified(whenArg.getType()); J.Identifier ident = (J.Identifier) whenArg.getSelect(); - if (arg_fq != null && ident != null && ident.getType() != null) { - String template = String.format("try(MockedStatic<#{}> mock%s = mockStatic(#{}.class)){\n" + - " mock%s.when(#{any()}).thenReturn(#{any()});\n" + - "}", arg_fq.getClassName(), arg_fq.getClassName()); - m = JavaTemplate.builder(template) + if (argFq != null && ident != null && ident.getType() != null) { + String mockName = VariableNameUtils.generateVariableName("mock" + ident.getSimpleName(), getCursor(), VariableNameUtils.GenerationStrategy.INCREMENT_NUMBER); + m = JavaTemplate.builder(String.format( + "try(MockedStatic<#{}> %1$s = mockStatic(#{}.class)) {\n" + + " %1$s.when(#{any()}).thenReturn(#{any()});\n" + + "}", mockName)) .contextSensitive() .javaParser(JavaParser.fromJavaVersion()) .imports("org.mockito.MockedStatic") .staticImports("org.mockito.Mockito.mockStatic") .build() - .apply(getCursor(), stmt.getCoordinates().replace(), ident.getType(), ident.getType(), whenArg, ((J.MethodInvocation) stmt).getArguments().get(0)); + .apply(getCursor(), stmt.getCoordinates().replace(), + ident.getType(), + ident.getType(), + whenArg, + ((J.MethodInvocation) stmt).getArguments().get(0)); rewrittenWhen = true; maybeAddImport("org.mockito.MockedStatic", false); maybeAddImport("org.mockito.Mockito", "mockStatic"); diff --git a/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java b/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java index c6ba952f2..cf08a19b5 100644 --- a/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java +++ b/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java @@ -49,8 +49,8 @@ void shouldRefactorMockito_When() { """ package com.foo; public class A { - public static A getA() { - return new A(); + public static Integer getNumber() { + return 42; } } """, @@ -68,8 +68,8 @@ class Test { private A aMock = mock(A.class); void test() { - when(A.getA()).thenReturn(aMock); - assertEquals(A.getA(), aMock); + when(A.getNumber()).thenReturn(-1); + assertEquals(A.getNumber(), -1); } } """, @@ -86,8 +86,8 @@ class Test { void test() { try (MockedStatic mockA = mockStatic(com.foo.A.class)) { - mockA.when(A.getA()).thenReturn(aMock); - assertEquals(A.getA(), aMock); + mockA.when(A.getNumber()).thenReturn(-1); + assertEquals(A.getNumber(), -1); } } } From fa1ad4bd19b357e389ef78954dfcb2da8e1fed6a Mon Sep 17 00:00:00 2001 From: Tim te Beek Date: Mon, 16 Sep 2024 22:57:51 +0200 Subject: [PATCH 09/15] Switch to `ListUtils.flatMap` without `mock` field in test --- .../MockitoWhenOnStaticToMockStatic.java | 85 +++++++++---------- .../MockitoWhenOnStaticToMockStaticTest.java | 41 +++++---- 2 files changed, 65 insertions(+), 61 deletions(-) diff --git a/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java b/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java index cbe0f6d5b..289ef0e6c 100644 --- a/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java +++ b/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java @@ -15,17 +15,17 @@ */ package org.openrewrite.java.testing.mockito; -import org.openrewrite.ExecutionContext; -import org.openrewrite.Preconditions; -import org.openrewrite.Recipe; -import org.openrewrite.TreeVisitor; +import org.openrewrite.*; import org.openrewrite.internal.ListUtils; import org.openrewrite.java.*; import org.openrewrite.java.search.UsesMethod; -import org.openrewrite.java.tree.*; +import org.openrewrite.java.tree.Flag; +import org.openrewrite.java.tree.J; +import org.openrewrite.java.tree.Statement; -import java.util.ArrayList; +import java.util.Collections; import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; public class MockitoWhenOnStaticToMockStatic extends Recipe { @@ -53,56 +53,51 @@ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, Ex return m; } - boolean rewrittenWhen = false; - List statementsBeforeWhen = new ArrayList<>(); - List statementsAfterWhen = new ArrayList<>(); - for (Statement stmt : m.getBody().getStatements()) { - if (stmt instanceof J.MethodInvocation && - MOCKITO_WHEN.matches(((J.MethodInvocation) stmt).getSelect())) { - J.MethodInvocation when = (J.MethodInvocation) ((J.MethodInvocation) stmt).getSelect(); + AtomicBoolean restInTry = new AtomicBoolean(false); + List originalStatements = m.getBody().getStatements(); + List newStatements = ListUtils.flatMap(originalStatements, (index, statement) -> { + if (restInTry.get()) { + // Rest of the statements have ended up in the try block + return Collections.emptyList(); + } + + if (statement instanceof J.MethodInvocation && + MOCKITO_WHEN.matches(((J.MethodInvocation) statement).getSelect())) { + J.MethodInvocation when = (J.MethodInvocation) ((J.MethodInvocation) statement).getSelect(); if (when != null && when.getArguments().get(0) instanceof J.MethodInvocation) { J.MethodInvocation whenArg = (J.MethodInvocation) when.getArguments().get(0); if (whenArg.getMethodType() != null && whenArg.getMethodType().getFlags().contains(Flag.Static)) { - JavaType.FullyQualified argFq = TypeUtils.asFullyQualified(whenArg.getType()); - J.Identifier ident = (J.Identifier) whenArg.getSelect(); - if (argFq != null && ident != null && ident.getType() != null) { - String mockName = VariableNameUtils.generateVariableName("mock" + ident.getSimpleName(), getCursor(), VariableNameUtils.GenerationStrategy.INCREMENT_NUMBER); - m = JavaTemplate.builder(String.format( - "try(MockedStatic<#{}> %1$s = mockStatic(#{}.class)) {\n" + - " %1$s.when(#{any()}).thenReturn(#{any()});\n" + - "}", mockName)) + J.Identifier clazz = (J.Identifier) whenArg.getSelect(); + if (clazz != null && clazz.getType() != null) { + String mockName = VariableNameUtils.generateVariableName("mock" + clazz.getSimpleName(), getCursor(), VariableNameUtils.GenerationStrategy.INCREMENT_NUMBER); + maybeAddImport("org.mockito.MockedStatic", false); + maybeAddImport("org.mockito.Mockito", "mockStatic"); + String template = String.format( + "try(MockedStatic<#{}> %1$s = mockStatic(#{}.class)) {\n" + + " %1$s.when(#{any()}).thenReturn(#{any()});\n" + + "}", mockName); + J.Try try_ = (J.Try) ((J.MethodDeclaration) JavaTemplate.builder(template) .contextSensitive() - .javaParser(JavaParser.fromJavaVersion()) .imports("org.mockito.MockedStatic") .staticImports("org.mockito.Mockito.mockStatic") .build() - .apply(getCursor(), stmt.getCoordinates().replace(), - ident.getType(), - ident.getType(), - whenArg, - ((J.MethodInvocation) stmt).getArguments().get(0)); - rewrittenWhen = true; - maybeAddImport("org.mockito.MockedStatic", false); - maybeAddImport("org.mockito.Mockito", "mockStatic"); - continue; + .apply(getCursor(), m.getCoordinates().replaceBody(), + clazz.getType(), clazz.getType(), + whenArg, ((J.MethodInvocation) statement).getArguments().get(0))) + .getBody().getStatements().get(0); + + restInTry.set(true); + return try_.withBody(try_.getBody().withStatements(ListUtils.concatAll( + try_.getBody().getStatements(), + originalStatements.subList(index + 1, originalStatements.size())))); } } - } } - if (rewrittenWhen) { - statementsAfterWhen.add(stmt); - } else { - statementsBeforeWhen.add(stmt); - } - } - if (rewrittenWhen) { - if (m.getBody() != null) { - J.Try try_catch = (J.Try) m.getBody().getStatements().get(statementsBeforeWhen.size()); - return maybeAutoFormat(method, m.withBody(m.getBody().withStatements(ListUtils.concat(statementsBeforeWhen, try_catch.withBody(try_catch.getBody().withStatements(ListUtils.concatAll(try_catch.getBody().getStatements(), statementsAfterWhen)))))), ctx); - } - } - return m; + return statement; + }); + + return maybeAutoFormat(m, m.withBody(m.getBody().withStatements(newStatements)), ctx); } }); } diff --git a/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java b/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java index cf08a19b5..562ed414e 100644 --- a/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java +++ b/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java @@ -64,9 +64,6 @@ public static Integer getNumber() { import static org.mockito.Mockito.*; class Test { - - private A aMock = mock(A.class); - void test() { when(A.getNumber()).thenReturn(-1); assertEquals(A.getNumber(), -1); @@ -81,9 +78,6 @@ void test() { import static org.mockito.Mockito.*; class Test { - - private A aMock = mock(A.class); - void test() { try (MockedStatic mockA = mockStatic(com.foo.A.class)) { mockA.when(A.getNumber()).thenReturn(-1); @@ -105,8 +99,8 @@ void shouldHandleMultipleStaticMocks() { """ package com.foo; public class A { - public static A getA() { - return new A(); + public static Integer getNumber() { + return 42; } } """, @@ -120,20 +114,35 @@ public static A getA() { import static org.mockito.Mockito.*; class Test { - - private A aMock = mock(A.class); - void test() { - when(A.getA()).thenReturn(aMock); - assertEquals(A.getA(), aMock); + when(A.getNumber()).thenReturn(-1); + assertEquals(A.getNumber(), -1); - when(A.getA()).thenReturn(aMock); - assertEquals(A.getA(), aMock); + when(A.getNumber()).thenReturn(-2); + assertEquals(A.getNumber(), -2); } } """, """ - class TODO {} + import com.foo.A; + import org.mockito.MockedStatic; + + import static org.junit.Assert.assertEquals; + import static org.mockito.Mockito.*; + + class Test { + void test() { + try (MockedStatic mockA = mockStatic(com.foo.A.class)) { + mockA.when(A.getNumber()).thenReturn(-1); + assertEquals(A.getNumber(), -1); + + try (MockedStatic mockA2 = mockStatic(com.foo.A.class)) { + mockA2.when(A.getNumber()).thenReturn(-2); + assertEquals(A.getNumber(), -2); + } + } + } + } """ ) ); From 011362cf64d9807f06d5923a22b378b2399b046c Mon Sep 17 00:00:00 2001 From: Tim te Beek Date: Mon, 16 Sep 2024 23:02:47 +0200 Subject: [PATCH 10/15] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .../testing/mockito/MockitoWhenOnStaticToMockStatic.java | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java b/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java index 289ef0e6c..606c4dd46 100644 --- a/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java +++ b/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java @@ -15,7 +15,10 @@ */ package org.openrewrite.java.testing.mockito; -import org.openrewrite.*; +import org.openrewrite.ExecutionContext; +import org.openrewrite.Preconditions; +import org.openrewrite.Recipe; +import org.openrewrite.TreeVisitor; import org.openrewrite.internal.ListUtils; import org.openrewrite.java.*; import org.openrewrite.java.search.UsesMethod; From 04ca397e7d5f13c0e23e97439549420c3fa97801 Mon Sep 17 00:00:00 2001 From: Laurens Westerlaken Date: Tue, 17 Sep 2024 11:18:02 +0200 Subject: [PATCH 11/15] Make recursive in case multiple occurrences exist --- .../MockitoWhenOnStaticToMockStatic.java | 22 +++++++++++------ .../MockitoWhenOnStaticToMockStaticTest.java | 24 +++++++++---------- 2 files changed, 27 insertions(+), 19 deletions(-) diff --git a/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java b/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java index 606c4dd46..eac784593 100644 --- a/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java +++ b/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java @@ -20,7 +20,10 @@ import org.openrewrite.Recipe; import org.openrewrite.TreeVisitor; import org.openrewrite.internal.ListUtils; -import org.openrewrite.java.*; +import org.openrewrite.java.JavaIsoVisitor; +import org.openrewrite.java.JavaTemplate; +import org.openrewrite.java.MethodMatcher; +import org.openrewrite.java.VariableNameUtils; import org.openrewrite.java.search.UsesMethod; import org.openrewrite.java.tree.Flag; import org.openrewrite.java.tree.J; @@ -56,8 +59,12 @@ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, Ex return m; } + List newStatements = getStatements(m.getBody().getStatements(), m); + return maybeAutoFormat(m, m.withBody(m.getBody().withStatements(newStatements)), ctx); + } + + private List getStatements(List originalStatements, J.MethodDeclaration m) { AtomicBoolean restInTry = new AtomicBoolean(false); - List originalStatements = m.getBody().getStatements(); List newStatements = ListUtils.flatMap(originalStatements, (index, statement) -> { if (restInTry.get()) { // Rest of the statements have ended up in the try block @@ -72,14 +79,14 @@ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, Ex if (whenArg.getMethodType() != null && whenArg.getMethodType().getFlags().contains(Flag.Static)) { J.Identifier clazz = (J.Identifier) whenArg.getSelect(); if (clazz != null && clazz.getType() != null) { - String mockName = VariableNameUtils.generateVariableName("mock" + clazz.getSimpleName(), getCursor(), VariableNameUtils.GenerationStrategy.INCREMENT_NUMBER); + String mockName = VariableNameUtils.generateVariableName("mock" + clazz.getSimpleName(), updateCursor(m), VariableNameUtils.GenerationStrategy.INCREMENT_NUMBER); maybeAddImport("org.mockito.MockedStatic", false); maybeAddImport("org.mockito.Mockito", "mockStatic"); String template = String.format( "try(MockedStatic<#{}> %1$s = mockStatic(#{}.class)) {\n" + " %1$s.when(#{any()}).thenReturn(#{any()});\n" + "}", mockName); - J.Try try_ = (J.Try) ((J.MethodDeclaration) JavaTemplate.builder(template) + J.Try try_ = (J.Try) ((J.MethodDeclaration) JavaTemplate.builder(template) .contextSensitive() .imports("org.mockito.MockedStatic") .staticImports("org.mockito.Mockito.mockStatic") @@ -90,17 +97,18 @@ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, Ex .getBody().getStatements().get(0); restInTry.set(true); + return try_.withBody(try_.getBody().withStatements(ListUtils.concatAll( try_.getBody().getStatements(), - originalStatements.subList(index + 1, originalStatements.size())))); + getStatements(originalStatements.subList(index + 1, originalStatements.size()), m.withBody(m.getBody().withStatements(ListUtils.concat(m.getBody().getStatements(), try_))))))) + .withPrefix(statement.getPrefix()); } } } } return statement; }); - - return maybeAutoFormat(m, m.withBody(m.getBody().withStatements(newStatements)), ctx); + return newStatements; } }); } diff --git a/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java b/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java index 562ed414e..19f1683d0 100644 --- a/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java +++ b/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java @@ -59,10 +59,10 @@ public static Integer getNumber() { java( """ import com.foo.A; - + import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.*; - + class Test { void test() { when(A.getNumber()).thenReturn(-1); @@ -73,10 +73,10 @@ void test() { """ import com.foo.A; import org.mockito.MockedStatic; - + import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.*; - + class Test { void test() { try (MockedStatic mockA = mockStatic(com.foo.A.class)) { @@ -109,15 +109,15 @@ public static Integer getNumber() { java( """ import com.foo.A; - + import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.*; - + class Test { void test() { when(A.getNumber()).thenReturn(-1); assertEquals(A.getNumber(), -1); - + when(A.getNumber()).thenReturn(-2); assertEquals(A.getNumber(), -2); } @@ -126,18 +126,18 @@ void test() { """ import com.foo.A; import org.mockito.MockedStatic; - + import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.*; - + class Test { void test() { try (MockedStatic mockA = mockStatic(com.foo.A.class)) { mockA.when(A.getNumber()).thenReturn(-1); assertEquals(A.getNumber(), -1); - - try (MockedStatic mockA2 = mockStatic(com.foo.A.class)) { - mockA2.when(A.getNumber()).thenReturn(-2); + + try (MockedStatic mockA1 = mockStatic(com.foo.A.class)) { + mockA1.when(A.getNumber()).thenReturn(-2); assertEquals(A.getNumber(), -2); } } From 69d25223fa28baac5af98d509401a8fd176475e3 Mon Sep 17 00:00:00 2001 From: Laurens Westerlaken Date: Tue, 17 Sep 2024 11:45:30 +0200 Subject: [PATCH 12/15] Return immediately --- .../java/testing/mockito/MockitoWhenOnStaticToMockStatic.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java b/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java index eac784593..702ba0bed 100644 --- a/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java +++ b/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java @@ -65,7 +65,7 @@ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, Ex private List getStatements(List originalStatements, J.MethodDeclaration m) { AtomicBoolean restInTry = new AtomicBoolean(false); - List newStatements = ListUtils.flatMap(originalStatements, (index, statement) -> { + return ListUtils.flatMap(originalStatements, (index, statement) -> { if (restInTry.get()) { // Rest of the statements have ended up in the try block return Collections.emptyList(); @@ -108,7 +108,6 @@ private List getStatements(List originalStatements, J.Meth } return statement; }); - return newStatements; } }); } From 50c9e0caec160cc102ce194426e3573a62233980 Mon Sep 17 00:00:00 2001 From: Tim te Beek Date: Sat, 21 Sep 2024 12:52:42 +0200 Subject: [PATCH 13/15] Reduce nesting and clarify intent through method rename --- .../MockitoWhenOnStaticToMockStatic.java | 103 +++++++++--------- 1 file changed, 52 insertions(+), 51 deletions(-) diff --git a/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java b/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java index 702ba0bed..66f881f66 100644 --- a/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java +++ b/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java @@ -49,66 +49,67 @@ public String getDescription() { @Override public TreeVisitor getVisitor() { - return Preconditions.check( - new UsesMethod<>(MOCKITO_WHEN), - new JavaIsoVisitor() { - @Override - public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext ctx) { - J.MethodDeclaration m = super.visitMethodDeclaration(method, ctx); - if (m.getBody() == null) { - return m; - } + return Preconditions.check(new UsesMethod<>(MOCKITO_WHEN), new JavaIsoVisitor() { + @Override + public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext ctx) { + J.MethodDeclaration m = super.visitMethodDeclaration(method, ctx); + if (m.getBody() == null) { + return m; + } - List newStatements = getStatements(m.getBody().getStatements(), m); - return maybeAutoFormat(m, m.withBody(m.getBody().withStatements(newStatements)), ctx); - } + List newStatements = maybeWrapStatementsInTryWithResourcesMockedStatic(m, m.getBody().getStatements()); + return maybeAutoFormat(m, m.withBody(m.getBody().withStatements(newStatements)), ctx); + } - private List getStatements(List originalStatements, J.MethodDeclaration m) { - AtomicBoolean restInTry = new AtomicBoolean(false); - return ListUtils.flatMap(originalStatements, (index, statement) -> { - if (restInTry.get()) { - // Rest of the statements have ended up in the try block - return Collections.emptyList(); - } + private List maybeWrapStatementsInTryWithResourcesMockedStatic(J.MethodDeclaration m, List originalStatements) { + AtomicBoolean restInTry = new AtomicBoolean(false); + return ListUtils.flatMap(originalStatements, (index, statement) -> { + if (restInTry.get()) { + // Rest of the statements have ended up in the try block + return Collections.emptyList(); + } - if (statement instanceof J.MethodInvocation && - MOCKITO_WHEN.matches(((J.MethodInvocation) statement).getSelect())) { - J.MethodInvocation when = (J.MethodInvocation) ((J.MethodInvocation) statement).getSelect(); - if (when != null && when.getArguments().get(0) instanceof J.MethodInvocation) { - J.MethodInvocation whenArg = (J.MethodInvocation) when.getArguments().get(0); - if (whenArg.getMethodType() != null && whenArg.getMethodType().getFlags().contains(Flag.Static)) { - J.Identifier clazz = (J.Identifier) whenArg.getSelect(); - if (clazz != null && clazz.getType() != null) { - String mockName = VariableNameUtils.generateVariableName("mock" + clazz.getSimpleName(), updateCursor(m), VariableNameUtils.GenerationStrategy.INCREMENT_NUMBER); - maybeAddImport("org.mockito.MockedStatic", false); - maybeAddImport("org.mockito.Mockito", "mockStatic"); - String template = String.format( - "try(MockedStatic<#{}> %1$s = mockStatic(#{}.class)) {\n" + - " %1$s.when(#{any()}).thenReturn(#{any()});\n" + - "}", mockName); - J.Try try_ = (J.Try) ((J.MethodDeclaration) JavaTemplate.builder(template) - .contextSensitive() - .imports("org.mockito.MockedStatic") - .staticImports("org.mockito.Mockito.mockStatic") - .build() - .apply(getCursor(), m.getCoordinates().replaceBody(), - clazz.getType(), clazz.getType(), - whenArg, ((J.MethodInvocation) statement).getArguments().get(0))) - .getBody().getStatements().get(0); + if (statement instanceof J.MethodInvocation && + MOCKITO_WHEN.matches(((J.MethodInvocation) statement).getSelect())) { + J.MethodInvocation when = (J.MethodInvocation) ((J.MethodInvocation) statement).getSelect(); + if (when != null && when.getArguments().get(0) instanceof J.MethodInvocation) { + J.MethodInvocation whenArg = (J.MethodInvocation) when.getArguments().get(0); + if (whenArg.getMethodType() != null && whenArg.getMethodType().hasFlags(Flag.Static)) { + J.Identifier clazz = (J.Identifier) whenArg.getSelect(); + if (clazz != null && clazz.getType() != null) { + String mockName = VariableNameUtils.generateVariableName("mock" + clazz.getSimpleName(), updateCursor(m), VariableNameUtils.GenerationStrategy.INCREMENT_NUMBER); + maybeAddImport("org.mockito.MockedStatic", false); + maybeAddImport("org.mockito.Mockito", "mockStatic"); + String template = String.format( + "try(MockedStatic<#{}> %1$s = mockStatic(#{}.class)) {\n" + + " %1$s.when(#{any()}).thenReturn(#{any()});\n" + + "}", mockName); + J.Try try_ = (J.Try) ((J.MethodDeclaration) JavaTemplate.builder(template) + .contextSensitive() + .imports("org.mockito.MockedStatic") + .staticImports("org.mockito.Mockito.mockStatic") + .build() + .apply(getCursor(), m.getCoordinates().replaceBody(), + clazz.getType(), clazz.getType(), + whenArg, ((J.MethodInvocation) statement).getArguments().get(0))) + .getBody().getStatements().get(0); - restInTry.set(true); + restInTry.set(true); - return try_.withBody(try_.getBody().withStatements(ListUtils.concatAll( + return try_.withBody(try_.getBody().withStatements(ListUtils.concatAll( try_.getBody().getStatements(), - getStatements(originalStatements.subList(index + 1, originalStatements.size()), m.withBody(m.getBody().withStatements(ListUtils.concat(m.getBody().getStatements(), try_))))))) - .withPrefix(statement.getPrefix()); - } - } + maybeWrapStatementsInTryWithResourcesMockedStatic( + m.withBody(m.getBody().withStatements(ListUtils.concat(m.getBody().getStatements(), try_))), + originalStatements.subList(index + 1, originalStatements.size()) + )))) + .withPrefix(statement.getPrefix()); } } - return statement; - }); + } } + return statement; }); + } + }); } } From 1976e5120bbbad81f22478aae1c7bbccbb9fd72d Mon Sep 17 00:00:00 2001 From: Tim te Beek Date: Sat, 21 Sep 2024 13:15:45 +0200 Subject: [PATCH 14/15] Only add preceding statements before try, not all --- .../testing/mockito/MockitoWhenOnStaticToMockStatic.java | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java b/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java index 66f881f66..f393c7f94 100644 --- a/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java +++ b/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java @@ -61,9 +61,9 @@ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, Ex return maybeAutoFormat(m, m.withBody(m.getBody().withStatements(newStatements)), ctx); } - private List maybeWrapStatementsInTryWithResourcesMockedStatic(J.MethodDeclaration m, List originalStatements) { + private List maybeWrapStatementsInTryWithResourcesMockedStatic(J.MethodDeclaration m, List remainingStatements) { AtomicBoolean restInTry = new AtomicBoolean(false); - return ListUtils.flatMap(originalStatements, (index, statement) -> { + return ListUtils.flatMap(remainingStatements, (index, statement) -> { if (restInTry.get()) { // Rest of the statements have ended up in the try block return Collections.emptyList(); @@ -96,11 +96,12 @@ private List maybeWrapStatementsInTryWithResourcesMockedStatic(J.Meth restInTry.set(true); + List precedingStatements = remainingStatements.subList(0, index); return try_.withBody(try_.getBody().withStatements(ListUtils.concatAll( try_.getBody().getStatements(), maybeWrapStatementsInTryWithResourcesMockedStatic( - m.withBody(m.getBody().withStatements(ListUtils.concat(m.getBody().getStatements(), try_))), - originalStatements.subList(index + 1, originalStatements.size()) + m.withBody(m.getBody().withStatements(ListUtils.concat(precedingStatements, try_))), + remainingStatements.subList(index + 1, remainingStatements.size()) )))) .withPrefix(statement.getPrefix()); } From fdd26acf11be876e346fcc7dbbd83d32a24ea42a Mon Sep 17 00:00:00 2001 From: Tim te Beek Date: Sat, 21 Sep 2024 13:30:22 +0200 Subject: [PATCH 15/15] Shorten qualified class refs & enable method type validation --- .../mockito/MockitoWhenOnStaticToMockStatic.java | 7 +++---- .../mockito/MockitoWhenOnStaticToMockStaticTest.java | 10 +++++----- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java b/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java index f393c7f94..7836ca0e9 100644 --- a/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java +++ b/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java @@ -81,16 +81,15 @@ private List maybeWrapStatementsInTryWithResourcesMockedStatic(J.Meth maybeAddImport("org.mockito.MockedStatic", false); maybeAddImport("org.mockito.Mockito", "mockStatic"); String template = String.format( - "try(MockedStatic<#{}> %1$s = mockStatic(#{}.class)) {\n" + - " %1$s.when(#{any()}).thenReturn(#{any()});\n" + - "}", mockName); + "try(MockedStatic<%1$s> %2$s = mockStatic(%1$s.class)) {\n" + + " %2$s.when(#{any()}).thenReturn(#{any()});\n" + + "}", clazz.getSimpleName(), mockName); J.Try try_ = (J.Try) ((J.MethodDeclaration) JavaTemplate.builder(template) .contextSensitive() .imports("org.mockito.MockedStatic") .staticImports("org.mockito.Mockito.mockStatic") .build() .apply(getCursor(), m.getCoordinates().replaceBody(), - clazz.getType(), clazz.getType(), whenArg, ((J.MethodInvocation) statement).getArguments().get(0))) .getBody().getStatements().get(0); diff --git a/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java b/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java index 19f1683d0..a7eb4b724 100644 --- a/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java +++ b/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java @@ -44,7 +44,7 @@ public void defaults(RecipeSpec spec) { void shouldRefactorMockito_When() { //language=java rewriteRun( - spec -> spec.afterTypeValidationOptions(TypeValidation.builder().methodInvocations(false).identifiers(false).build()), + spec -> spec.afterTypeValidationOptions(TypeValidation.builder().identifiers(false).build()), java( """ package com.foo; @@ -79,7 +79,7 @@ void test() { class Test { void test() { - try (MockedStatic mockA = mockStatic(com.foo.A.class)) { + try (MockedStatic mockA = mockStatic(A.class)) { mockA.when(A.getNumber()).thenReturn(-1); assertEquals(A.getNumber(), -1); } @@ -94,7 +94,7 @@ void test() { void shouldHandleMultipleStaticMocks() { //language=java rewriteRun( - spec -> spec.afterTypeValidationOptions(TypeValidation.builder().methodInvocations(false).identifiers(false).build()), + spec -> spec.afterTypeValidationOptions(TypeValidation.builder().identifiers(false).build()), java( """ package com.foo; @@ -132,11 +132,11 @@ void test() { class Test { void test() { - try (MockedStatic mockA = mockStatic(com.foo.A.class)) { + try (MockedStatic mockA = mockStatic(A.class)) { mockA.when(A.getNumber()).thenReturn(-1); assertEquals(A.getNumber(), -1); - try (MockedStatic mockA1 = mockStatic(com.foo.A.class)) { + try (MockedStatic mockA1 = mockStatic(A.class)) { mockA1.when(A.getNumber()).thenReturn(-2); assertEquals(A.getNumber(), -2); }