diff --git a/src/main/java/org/openrewrite/java/testing/assertj/SimplifyChainedAssertJAssertion.java b/src/main/java/org/openrewrite/java/testing/assertj/SimplifyChainedAssertJAssertion.java index 7ab2d3a79..c05d6a732 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/SimplifyChainedAssertJAssertion.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/SimplifyChainedAssertJAssertion.java @@ -29,6 +29,7 @@ import org.openrewrite.java.MethodMatcher; import org.openrewrite.java.tree.Expression; import org.openrewrite.java.tree.J; +import org.openrewrite.java.tree.JavaType; import org.openrewrite.java.tree.TypeUtils; import java.util.ArrayList; @@ -112,6 +113,24 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation methodInvocat if (!TypeUtils.isAssignableTo(requiredType, actual.getType())) { return mi; } + + // Skip transformation when actual type has wildcard type parameters and the + // dedicated assertion takes arguments (e.g. containsEntry on Map, contains on Optional). + // When both arguments are empty the dedicated assertion has no parameters, so wildcards are not an issue. + boolean assertThatArgumentIsEmpty = assertThatArg.getArguments().get(0) instanceof J.Empty; + boolean methodToReplaceArgumentIsEmpty = mi.getArguments().get(0) instanceof J.Empty; + if (!(assertThatArgumentIsEmpty && methodToReplaceArgumentIsEmpty)) { + JavaType.Parameterized parameterized = TypeUtils.asParameterized(actual.getType()); + if (parameterized != null) { + for (JavaType typeParam : parameterized.getTypeParameters()) { + if (typeParam instanceof JavaType.GenericTypeVariable && + "?".equals(((JavaType.GenericTypeVariable) typeParam).getName())) { + return mi; + } + } + } + } + List arguments = new ArrayList<>(); arguments.add(actual); diff --git a/src/test/java/org/openrewrite/java/testing/assertj/SimplifyChainedAssertJAssertionTest.java b/src/test/java/org/openrewrite/java/testing/assertj/SimplifyChainedAssertJAssertionTest.java index 2117c3366..2ab17e05f 100644 --- a/src/test/java/org/openrewrite/java/testing/assertj/SimplifyChainedAssertJAssertionTest.java +++ b/src/test/java/org/openrewrite/java/testing/assertj/SimplifyChainedAssertJAssertionTest.java @@ -408,6 +408,127 @@ Map getMap() { ); } + @Test + void mapGetIsEqualToWithBoundedWildcardTypeIsNotConverted() { + rewriteRun( + spec -> spec.recipe(new SimplifyChainedAssertJAssertion("get", "isEqualTo", "containsEntry", "java.util.Map")), + //language=java + java( + """ + import java.util.Map; + + import static org.assertj.core.api.Assertions.assertThat; + + class MyTest { + void testExtendsWildcard(Map map) { + assertThat(map.get("key")).isEqualTo(42); + } + void testSuperWildcard(Map map) { + assertThat(map.get("key")).isEqualTo(42); + } + } + """ + ) + ); + } + + @Test + void mapGetIsEqualToWithNestedWildcardTypeIsConverted() { + rewriteRun( + spec -> spec.recipe(new SimplifyChainedAssertJAssertion("get", "isEqualTo", "containsEntry", "java.util.Map")), + //language=java + java( + """ + import java.util.List; + import java.util.Map; + + import static org.assertj.core.api.Assertions.assertThat; + + class MyTest { + void testMethod(Map> map, List value) { + assertThat(map.get("key")).isEqualTo(value); + } + } + """, + """ + import java.util.List; + import java.util.Map; + + import static org.assertj.core.api.Assertions.assertThat; + + class MyTest { + void testMethod(Map> map, List value) { + assertThat(map).containsEntry("key", value); + } + } + """ + ) + ); + } + + @Test + void mapGetIsEqualToWithWildcardTypeIsNotConverted() { + rewriteRun( + spec -> spec.recipe(new SimplifyChainedAssertJAssertion("get", "isEqualTo", "containsEntry", "java.util.Map")), + //language=java + java( + """ + import java.util.Map; + + import static org.assertj.core.api.Assertions.assertThat; + + class MyTest { + void testMethod(Map map) { + assertThat(map.get("key")).isEqualTo("value"); + } + } + """ + ) + ); + } + + @Test + void optionalGetIsEqualToWithWildcardTypeIsNotConverted() { + rewriteRun( + spec -> spec.recipe(new SimplifyChainedAssertJAssertion("get", "isEqualTo", "contains", "java.util.Optional")), + //language=java + java( + """ + import java.util.Optional; + + import static org.assertj.core.api.Assertions.assertThat; + + class MyTest { + void testMethod(Optional opt) { + assertThat(opt.get()).isEqualTo("value"); + } + } + """ + ) + ); + } + + @Test + void collectionContainsWithWildcardTypeIsNotConverted() { + rewriteRun( + spec -> spec.recipe(new SimplifyChainedAssertJAssertion("contains", "isTrue", "contains", "java.util.Collection")), + //language=java + java( + """ + import java.util.Collection; + + import static org.assertj.core.api.Assertions.assertThat; + + class MyTest { + void testMethod(Collection coll) { + assertThat(coll.contains("element")).isTrue(); + } + } + """ + ) + ); + } + @Test void keySetContainsWithMultipleArguments() { rewriteRun(