From c5d641e7bb381708f9d78ce928b367854c4a5c97 Mon Sep 17 00:00:00 2001 From: "Craig P. Motlin" Date: Wed, 28 Jan 2026 10:12:49 -0500 Subject: [PATCH] Preserve throws clause when checked exceptions exist outside assertThrows. --- .../JUnitTryFailToAssertThatThrownBy.java | 2 +- .../ExpectedExceptionToAssertThrows.java | 88 ++++++++++++++++++- .../ExpectedExceptionToAssertThrowsTest.java | 59 +++++++++++++ 3 files changed, 147 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/openrewrite/java/testing/assertj/JUnitTryFailToAssertThatThrownBy.java b/src/main/java/org/openrewrite/java/testing/assertj/JUnitTryFailToAssertThatThrownBy.java index f55b39970..1d32214cc 100644 --- a/src/main/java/org/openrewrite/java/testing/assertj/JUnitTryFailToAssertThatThrownBy.java +++ b/src/main/java/org/openrewrite/java/testing/assertj/JUnitTryFailToAssertThatThrownBy.java @@ -111,7 +111,7 @@ public J visitTry(J.Try tryBlock, ExecutionContext ctx) { return JavaTemplate.builder(template) .contextSensitive() .staticImports("org.assertj.core.api.Assertions.assertThatThrownBy") - .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3")) + .javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "junit-jupiter-api-5", "assertj-core-3")) .build() .apply(getCursor(), try_.getCoordinates().replace(), lambdaStatements.toArray()); } diff --git a/src/main/java/org/openrewrite/java/testing/junit5/ExpectedExceptionToAssertThrows.java b/src/main/java/org/openrewrite/java/testing/junit5/ExpectedExceptionToAssertThrows.java index 209cfc1b7..23b9cb913 100644 --- a/src/main/java/org/openrewrite/java/testing/junit5/ExpectedExceptionToAssertThrows.java +++ b/src/main/java/org/openrewrite/java/testing/junit5/ExpectedExceptionToAssertThrows.java @@ -63,6 +63,7 @@ public TreeVisitor getVisitor() { private static class ExpectedExceptionToAssertThrowsVisitor extends JavaIsoVisitor { private static final String FIRST_EXPECTED_EXCEPTION_METHOD_INVOCATION = "firstExpectedExceptionMethodInvocation"; + private static final String STATEMENTS_BEFORE_EXPECT_EXCEPTION = "statementsBeforeExpectException"; private static final String STATEMENTS_AFTER_EXPECT_EXCEPTION = "statementsAfterExpectException"; private static final String HAS_MATCHER = "hasMatcher"; private static final String EXCEPTION_CLASS = "exceptionClass"; @@ -100,6 +101,10 @@ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, Ex if (getCursor().pollMessage("hasExpectException") != null) { List thrown = m.getThrows(); if (thrown != null && !thrown.isEmpty()) { + List statementsBeforeExpect = getCursor().pollMessage(STATEMENTS_BEFORE_EXPECT_EXCEPTION); + if (statementsBeforeExpect != null && statementsBeforeExpectThrowCheckedException(statementsBeforeExpect)) { + return m; + } assert m.getBody() != null; return m.withBody(m.getBody().withPrefix(thrown.get(0).getPrefix())).withThrows(emptyList()); } @@ -107,6 +112,62 @@ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, Ex return m; } + private boolean statementsBeforeExpectThrowCheckedException(List statements) { + return statements.stream().anyMatch(this::statementThrowsCheckedException); + } + + private boolean statementThrowsCheckedException(Statement statement) { + AtomicBoolean throwsChecked = new AtomicBoolean(false); + new JavaIsoVisitor() { + @Override + public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, AtomicBoolean found) { + if (found.get()) { + return method; + } + JavaType.Method methodType = method.getMethodType(); + if (methodType == null) { + return super.visitMethodInvocation(method, found); + } + List thrownExceptions = methodType.getThrownExceptions(); + for (JavaType thrownException : thrownExceptions) { + if (isCheckedException(thrownException)) { + found.set(true); + return method; + } + } + return super.visitMethodInvocation(method, found); + } + + @Override + public J.NewClass visitNewClass(J.NewClass newClass, AtomicBoolean found) { + if (found.get()) { + return newClass; + } + JavaType.Method constructorType = newClass.getConstructorType(); + if (constructorType == null) { + return super.visitNewClass(newClass, found); + } + List thrownExceptions = constructorType.getThrownExceptions(); + for (JavaType thrownException : thrownExceptions) { + if (isCheckedException(thrownException)) { + found.set(true); + return newClass; + } + } + return super.visitNewClass(newClass, found); + } + }.visit(statement, throwsChecked); + return throwsChecked.get(); + } + + private boolean isCheckedException(JavaType exceptionType) { + if (exceptionType == null) { + return false; + } + return !TypeUtils.isAssignableTo("java.lang.RuntimeException", exceptionType) && + !TypeUtils.isAssignableTo("java.lang.Error", exceptionType); + } + @Override public J.Block visitBlock(J.Block block, ExecutionContext ctx) { J.Block b = super.visitBlock(block, ctx); @@ -175,7 +236,13 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu return method; } getCursor().dropParentUntil(J.MethodDeclaration.class::isInstance).putMessage("hasExpectException", true); - getCursor().dropParentUntil(J.Block.class::isInstance).computeMessageIfAbsent(FIRST_EXPECTED_EXCEPTION_METHOD_INVOCATION, k -> method); + Cursor blockCursor = getCursor().dropParentUntil(J.Block.class::isInstance); + blockCursor.computeMessageIfAbsent(FIRST_EXPECTED_EXCEPTION_METHOD_INVOCATION, k -> method); + + List predecessorStatements = findPredecessorStatements(getCursor()); + getCursor().dropParentUntil(J.MethodDeclaration.class::isInstance) + .computeMessageIfAbsent(STATEMENTS_BEFORE_EXPECT_EXCEPTION, k -> predecessorStatements); + List successorStatements = findSuccessorStatements(getCursor()); getCursor().putMessageOnFirstEnclosing(J.Block.class, STATEMENTS_AFTER_EXPECT_EXCEPTION, successorStatements); if (EXPECTED_EXCEPTION_CLASS_MATCHER.matches(method)) { @@ -186,6 +253,25 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu return method; } + /** + * From the current cursor point find all preceding statements in the method body. + */ + private List findPredecessorStatements(Cursor cursor) { + J.MethodDeclaration methodDecl = cursor.firstEnclosing(J.MethodDeclaration.class); + if (methodDecl == null || methodDecl.getBody() == null) { + return emptyList(); + } + List predecessorStatements = new ArrayList<>(); + Statement currentStatement = cursor.firstEnclosing(Statement.class); + for (Statement statement : methodDecl.getBody().getStatements()) { + if (statement == currentStatement) { + break; + } + predecessorStatements.add(statement); + } + return predecessorStatements; + } + /** * From the current cursor point find all the next statements that can be executed in the current path. */ diff --git a/src/test/java/org/openrewrite/java/testing/junit5/ExpectedExceptionToAssertThrowsTest.java b/src/test/java/org/openrewrite/java/testing/junit5/ExpectedExceptionToAssertThrowsTest.java index f07923935..14bbc147e 100644 --- a/src/test/java/org/openrewrite/java/testing/junit5/ExpectedExceptionToAssertThrowsTest.java +++ b/src/test/java/org/openrewrite/java/testing/junit5/ExpectedExceptionToAssertThrowsTest.java @@ -424,6 +424,65 @@ public void expectExceptionUseCases() { ); } + @Issue("https://github.com/openrewrite/rewrite-testing-frameworks/issues/55") + @Test + void preserveThrowsWhenCodeBeforeExpectThrowsCheckedException() { + //language=java + rewriteRun( + java( + """ + import org.junit.Rule; + import org.junit.Test; + import org.junit.rules.ExpectedException; + + class MyTest { + + @Rule + ExpectedException thrown = ExpectedException.none(); + + @Test + public void testMethod() throws InterruptedException { + setup(); + this.thrown.expect(IllegalArgumentException.class); + doSomething(); + } + + void setup() throws InterruptedException { + Thread.sleep(100); + } + + void doSomething() { + throw new IllegalArgumentException(); + } + } + """, + """ + import org.junit.Test; + + import static org.junit.jupiter.api.Assertions.assertThrows; + + class MyTest { + + @Test + public void testMethod() throws InterruptedException { + setup(); + assertThrows(IllegalArgumentException.class, () -> + doSomething()); + } + + void setup() throws InterruptedException { + Thread.sleep(100); + } + + void doSomething() { + throw new IllegalArgumentException(); + } + } + """ + ) + ); + } + @Issue("https://github.com/openrewrite/rewrite-testing-frameworks/issues/563") @Test void expectedCheckedExceptionThrowsRemoved() {