Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ public J visitTry(J.Try tryBlock, ExecutionContext ctx) {
return JavaTemplate.builder(template)
.contextSensitive()
.staticImports("org.assertj.core.api.Assertions.assertThatThrownBy")
.javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "assertj-core-3"))
.javaParser(JavaParser.fromJavaVersion().classpathFromResources(ctx, "junit-jupiter-api-5", "assertj-core-3"))
.build()
.<J.MethodInvocation>apply(getCursor(), try_.getCoordinates().replace(), lambdaStatements.toArray());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ public TreeVisitor<?, ExecutionContext> getVisitor() {
private static class ExpectedExceptionToAssertThrowsVisitor extends JavaIsoVisitor<ExecutionContext> {

private static final String FIRST_EXPECTED_EXCEPTION_METHOD_INVOCATION = "firstExpectedExceptionMethodInvocation";
private static final String STATEMENTS_BEFORE_EXPECT_EXCEPTION = "statementsBeforeExpectException";
private static final String STATEMENTS_AFTER_EXPECT_EXCEPTION = "statementsAfterExpectException";
private static final String HAS_MATCHER = "hasMatcher";
private static final String EXCEPTION_CLASS = "exceptionClass";
Expand Down Expand Up @@ -100,13 +101,73 @@ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, Ex
if (getCursor().pollMessage("hasExpectException") != null) {
List<NameTree> thrown = m.getThrows();
if (thrown != null && !thrown.isEmpty()) {
List<Statement> statementsBeforeExpect = getCursor().pollMessage(STATEMENTS_BEFORE_EXPECT_EXCEPTION);
if (statementsBeforeExpect != null && statementsBeforeExpectThrowCheckedException(statementsBeforeExpect)) {
return m;
}
assert m.getBody() != null;
return m.withBody(m.getBody().withPrefix(thrown.get(0).getPrefix())).withThrows(emptyList());
}
}
return m;
}

private boolean statementsBeforeExpectThrowCheckedException(List<Statement> statements) {
return statements.stream().anyMatch(this::statementThrowsCheckedException);
}

private boolean statementThrowsCheckedException(Statement statement) {
AtomicBoolean throwsChecked = new AtomicBoolean(false);
new JavaIsoVisitor<AtomicBoolean>() {
@Override
public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, AtomicBoolean found) {
if (found.get()) {
return method;
}
JavaType.Method methodType = method.getMethodType();
if (methodType == null) {
return super.visitMethodInvocation(method, found);
}
List<JavaType> thrownExceptions = methodType.getThrownExceptions();
for (JavaType thrownException : thrownExceptions) {
if (isCheckedException(thrownException)) {
found.set(true);
return method;
}
}
return super.visitMethodInvocation(method, found);
}

@Override
public J.NewClass visitNewClass(J.NewClass newClass, AtomicBoolean found) {
if (found.get()) {
return newClass;
}
JavaType.Method constructorType = newClass.getConstructorType();
if (constructorType == null) {
return super.visitNewClass(newClass, found);
}
List<JavaType> thrownExceptions = constructorType.getThrownExceptions();
for (JavaType thrownException : thrownExceptions) {
if (isCheckedException(thrownException)) {
found.set(true);
return newClass;
}
}
return super.visitNewClass(newClass, found);
}
}.visit(statement, throwsChecked);
return throwsChecked.get();
}

private boolean isCheckedException(JavaType exceptionType) {
if (exceptionType == null) {
return false;
}
return !TypeUtils.isAssignableTo("java.lang.RuntimeException", exceptionType) &&
!TypeUtils.isAssignableTo("java.lang.Error", exceptionType);
}

@Override
public J.Block visitBlock(J.Block block, ExecutionContext ctx) {
J.Block b = super.visitBlock(block, ctx);
Expand Down Expand Up @@ -175,7 +236,13 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
return method;
}
getCursor().dropParentUntil(J.MethodDeclaration.class::isInstance).putMessage("hasExpectException", true);
getCursor().dropParentUntil(J.Block.class::isInstance).computeMessageIfAbsent(FIRST_EXPECTED_EXCEPTION_METHOD_INVOCATION, k -> method);
Cursor blockCursor = getCursor().dropParentUntil(J.Block.class::isInstance);
blockCursor.computeMessageIfAbsent(FIRST_EXPECTED_EXCEPTION_METHOD_INVOCATION, k -> method);

List<Statement> predecessorStatements = findPredecessorStatements(getCursor());
getCursor().dropParentUntil(J.MethodDeclaration.class::isInstance)
.computeMessageIfAbsent(STATEMENTS_BEFORE_EXPECT_EXCEPTION, k -> predecessorStatements);

List<Statement> successorStatements = findSuccessorStatements(getCursor());
getCursor().putMessageOnFirstEnclosing(J.Block.class, STATEMENTS_AFTER_EXPECT_EXCEPTION, successorStatements);
if (EXPECTED_EXCEPTION_CLASS_MATCHER.matches(method)) {
Expand All @@ -186,6 +253,25 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
return method;
}

/**
* From the current cursor point find all preceding statements in the method body.
*/
private List<Statement> findPredecessorStatements(Cursor cursor) {
J.MethodDeclaration methodDecl = cursor.firstEnclosing(J.MethodDeclaration.class);
if (methodDecl == null || methodDecl.getBody() == null) {
return emptyList();
}
List<Statement> predecessorStatements = new ArrayList<>();
Statement currentStatement = cursor.firstEnclosing(Statement.class);
for (Statement statement : methodDecl.getBody().getStatements()) {
if (statement == currentStatement) {
break;
}
predecessorStatements.add(statement);
}
return predecessorStatements;
}

/**
* From the current cursor point find all the next statements that can be executed in the current path.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,65 @@ public void expectExceptionUseCases() {
);
}

@Issue("https://github.com/openrewrite/rewrite-testing-frameworks/issues/55")
@Test
void preserveThrowsWhenCodeBeforeExpectThrowsCheckedException() {
//language=java
rewriteRun(
java(
"""
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;

class MyTest {

@Rule
ExpectedException thrown = ExpectedException.none();

@Test
public void testMethod() throws InterruptedException {
setup();
this.thrown.expect(IllegalArgumentException.class);
doSomething();
}

void setup() throws InterruptedException {
Thread.sleep(100);
}

void doSomething() {
throw new IllegalArgumentException();
}
}
""",
"""
import org.junit.Test;

import static org.junit.jupiter.api.Assertions.assertThrows;

class MyTest {

@Test
public void testMethod() throws InterruptedException {
setup();
assertThrows(IllegalArgumentException.class, () ->
doSomething());
}

void setup() throws InterruptedException {
Thread.sleep(100);
}

void doSomething() {
throw new IllegalArgumentException();
}
}
"""
)
);
}

@Issue("https://github.com/openrewrite/rewrite-testing-frameworks/issues/563")
@Test
void expectedCheckedExceptionThrowsRemoved() {
Expand Down