diff --git a/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java b/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java new file mode 100644 index 000000000..7836ca0e9 --- /dev/null +++ b/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java @@ -0,0 +1,115 @@ +/* + * Copyright 2024 the original author or authors. + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.java.testing.mockito; + +import org.openrewrite.ExecutionContext; +import org.openrewrite.Preconditions; +import org.openrewrite.Recipe; +import org.openrewrite.TreeVisitor; +import org.openrewrite.internal.ListUtils; +import org.openrewrite.java.JavaIsoVisitor; +import org.openrewrite.java.JavaTemplate; +import org.openrewrite.java.MethodMatcher; +import org.openrewrite.java.VariableNameUtils; +import org.openrewrite.java.search.UsesMethod; +import org.openrewrite.java.tree.Flag; +import org.openrewrite.java.tree.J; +import org.openrewrite.java.tree.Statement; + +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; + +public class MockitoWhenOnStaticToMockStatic extends Recipe { + + private static final MethodMatcher MOCKITO_WHEN = new MethodMatcher("org.mockito.Mockito when(..)"); + + @Override + public String getDisplayName() { + return "Replace `Mockito.when` on static (non mock) with try-with-resource with MockedStatic"; + } + + @Override + public String getDescription() { + return "Replace `Mockito.when` on static (non mock) with try-with-resource with MockedStatic as Mockito4 no longer allows this."; + } + + @Override + public TreeVisitor getVisitor() { + return Preconditions.check(new UsesMethod<>(MOCKITO_WHEN), new JavaIsoVisitor() { + @Override + public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext ctx) { + J.MethodDeclaration m = super.visitMethodDeclaration(method, ctx); + if (m.getBody() == null) { + return m; + } + + List newStatements = maybeWrapStatementsInTryWithResourcesMockedStatic(m, m.getBody().getStatements()); + return maybeAutoFormat(m, m.withBody(m.getBody().withStatements(newStatements)), ctx); + } + + private List maybeWrapStatementsInTryWithResourcesMockedStatic(J.MethodDeclaration m, List remainingStatements) { + AtomicBoolean restInTry = new AtomicBoolean(false); + return ListUtils.flatMap(remainingStatements, (index, statement) -> { + if (restInTry.get()) { + // Rest of the statements have ended up in the try block + return Collections.emptyList(); + } + + if (statement instanceof J.MethodInvocation && + MOCKITO_WHEN.matches(((J.MethodInvocation) statement).getSelect())) { + J.MethodInvocation when = (J.MethodInvocation) ((J.MethodInvocation) statement).getSelect(); + if (when != null && when.getArguments().get(0) instanceof J.MethodInvocation) { + J.MethodInvocation whenArg = (J.MethodInvocation) when.getArguments().get(0); + if (whenArg.getMethodType() != null && whenArg.getMethodType().hasFlags(Flag.Static)) { + J.Identifier clazz = (J.Identifier) whenArg.getSelect(); + if (clazz != null && clazz.getType() != null) { + String mockName = VariableNameUtils.generateVariableName("mock" + clazz.getSimpleName(), updateCursor(m), VariableNameUtils.GenerationStrategy.INCREMENT_NUMBER); + maybeAddImport("org.mockito.MockedStatic", false); + maybeAddImport("org.mockito.Mockito", "mockStatic"); + String template = String.format( + "try(MockedStatic<%1$s> %2$s = mockStatic(%1$s.class)) {\n" + + " %2$s.when(#{any()}).thenReturn(#{any()});\n" + + "}", clazz.getSimpleName(), mockName); + J.Try try_ = (J.Try) ((J.MethodDeclaration) JavaTemplate.builder(template) + .contextSensitive() + .imports("org.mockito.MockedStatic") + .staticImports("org.mockito.Mockito.mockStatic") + .build() + .apply(getCursor(), m.getCoordinates().replaceBody(), + whenArg, ((J.MethodInvocation) statement).getArguments().get(0))) + .getBody().getStatements().get(0); + + restInTry.set(true); + + List precedingStatements = remainingStatements.subList(0, index); + return try_.withBody(try_.getBody().withStatements(ListUtils.concatAll( + try_.getBody().getStatements(), + maybeWrapStatementsInTryWithResourcesMockedStatic( + m.withBody(m.getBody().withStatements(ListUtils.concat(precedingStatements, try_))), + remainingStatements.subList(index + 1, remainingStatements.size()) + )))) + .withPrefix(statement.getPrefix()); + } + } + } + } + return statement; + }); + } + }); + } +} diff --git a/src/main/resources/META-INF/rewrite/mockito.yml b/src/main/resources/META-INF/rewrite/mockito.yml index fd6028729..7e60ad788 100644 --- a/src/main/resources/META-INF/rewrite/mockito.yml +++ b/src/main/resources/META-INF/rewrite/mockito.yml @@ -59,6 +59,7 @@ tags: - mockito recipeList: - org.openrewrite.java.testing.mockito.Mockito1to3Migration + - org.openrewrite.java.testing.mockito.MockitoWhenOnStaticToMockStatic - org.openrewrite.java.dependencies.UpgradeDependencyVersion: groupId: org.mockito artifactId: "*" diff --git a/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java b/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java new file mode 100644 index 000000000..a7eb4b724 --- /dev/null +++ b/src/test/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStaticTest.java @@ -0,0 +1,150 @@ +/* + * Copyright 2024 the original author or authors. + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.java.testing.mockito; + +import org.junit.jupiter.api.Test; +import org.openrewrite.DocumentExample; +import org.openrewrite.InMemoryExecutionContext; +import org.openrewrite.java.JavaParser; +import org.openrewrite.test.RecipeSpec; +import org.openrewrite.test.RewriteTest; +import org.openrewrite.test.SourceSpec; +import org.openrewrite.test.TypeValidation; + +import static org.openrewrite.java.Assertions.java; + +class MockitoWhenOnStaticToMockStaticTest implements RewriteTest { + + @Override + public void defaults(RecipeSpec spec) { + spec.recipe(new MockitoWhenOnStaticToMockStatic()) + .parser(JavaParser.fromJavaVersion() + .classpathFromResources(new InMemoryExecutionContext(), + "junit-4.13", + "mockito-core-3.12", + "mockito-junit-jupiter-3.12" + )); + } + + @DocumentExample + @Test + void shouldRefactorMockito_When() { + //language=java + rewriteRun( + spec -> spec.afterTypeValidationOptions(TypeValidation.builder().identifiers(false).build()), + java( + """ + package com.foo; + public class A { + public static Integer getNumber() { + return 42; + } + } + """, + SourceSpec::skip + ), + java( + """ + import com.foo.A; + + import static org.junit.Assert.assertEquals; + import static org.mockito.Mockito.*; + + class Test { + void test() { + when(A.getNumber()).thenReturn(-1); + assertEquals(A.getNumber(), -1); + } + } + """, + """ + import com.foo.A; + import org.mockito.MockedStatic; + + import static org.junit.Assert.assertEquals; + import static org.mockito.Mockito.*; + + class Test { + void test() { + try (MockedStatic mockA = mockStatic(A.class)) { + mockA.when(A.getNumber()).thenReturn(-1); + assertEquals(A.getNumber(), -1); + } + } + } + """ + ) + ); + } + + @Test + void shouldHandleMultipleStaticMocks() { + //language=java + rewriteRun( + spec -> spec.afterTypeValidationOptions(TypeValidation.builder().identifiers(false).build()), + java( + """ + package com.foo; + public class A { + public static Integer getNumber() { + return 42; + } + } + """, + SourceSpec::skip + ), + java( + """ + import com.foo.A; + + import static org.junit.Assert.assertEquals; + import static org.mockito.Mockito.*; + + class Test { + void test() { + when(A.getNumber()).thenReturn(-1); + assertEquals(A.getNumber(), -1); + + when(A.getNumber()).thenReturn(-2); + assertEquals(A.getNumber(), -2); + } + } + """, + """ + import com.foo.A; + import org.mockito.MockedStatic; + + import static org.junit.Assert.assertEquals; + import static org.mockito.Mockito.*; + + class Test { + void test() { + try (MockedStatic mockA = mockStatic(A.class)) { + mockA.when(A.getNumber()).thenReturn(-1); + assertEquals(A.getNumber(), -1); + + try (MockedStatic mockA1 = mockStatic(A.class)) { + mockA1.when(A.getNumber()).thenReturn(-2); + assertEquals(A.getNumber(), -2); + } + } + } + } + """ + ) + ); + } +}