diff --git a/repl/src/main/scala/org/apache/livy/repl/Session.scala b/repl/src/main/scala/org/apache/livy/repl/Session.scala index 262c811c7..ba2be36e6 100644 --- a/repl/src/main/scala/org/apache/livy/repl/Session.scala +++ b/repl/src/main/scala/org/apache/livy/repl/Session.scala @@ -19,7 +19,7 @@ package org.apache.livy.repl import java.util.{LinkedHashMap => JLinkedHashMap} import java.util.Map.Entry -import java.util.concurrent.Executors +import java.util.concurrent.{ConcurrentHashMap, Executors} import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConverters._ @@ -63,6 +63,8 @@ class Session( private val cancelExecutor = ExecutionContext.fromExecutorService( Executors.newSingleThreadExecutor()) + private val statementThreads = new ConcurrentHashMap[Int, Thread]() + private implicit val formats = DefaultFormats private var _state: SessionState = SessionState.NotStarted @@ -161,18 +163,29 @@ class Session( _statements.synchronized { _statements(statementId) = statement } Future { - setJobGroup(tpe, statementId) - statement.compareAndTransit(StatementState.Waiting, StatementState.Running) + val currentThread = Thread.currentThread() + statementThreads.put(statementId, currentThread) + try { + setJobGroup(tpe, statementId) + statement.compareAndTransit(StatementState.Waiting, StatementState.Running) - if (statement.state.get() == StatementState.Running) { - statement.started = System.currentTimeMillis() - statement.output = executeCode(interpreter(tpe), statementId, code) - } + if (statement.state.get() == StatementState.Running) { + statement.started = System.currentTimeMillis() + statement.output = executeCode(interpreter(tpe), statementId, code) + } - statement.compareAndTransit(StatementState.Running, StatementState.Available) - statement.compareAndTransit(StatementState.Cancelling, StatementState.Cancelled) - statement.updateProgress(1.0) - statement.completed = System.currentTimeMillis() + statement.compareAndTransit(StatementState.Running, StatementState.Available) + statement.compareAndTransit(StatementState.Cancelling, StatementState.Cancelled) + statement.updateProgress(1.0) + statement.completed = System.currentTimeMillis() + } finally { + statementThreads.remove(statementId, currentThread) + // Clear the interrupt flag, but log if the thread was interrupted. + if (Thread.interrupted()) { + warn(s"Thread was interrupted during execution of statement $statementId; " + + "interrupt flag cleared.") + } + } }(interpreterExecutor) statementId @@ -212,6 +225,7 @@ class Session( info(s"Failed to cancel statement $statementId.") statement.compareAndTransit(StatementState.Cancelling, StatementState.Cancelled) } else { + Option(statementThreads.get(statementId)).foreach(_.interrupt()) sc.cancelJobGroup(statementId.toString) if (statement.state.get() == StatementState.Cancelling) { Thread.sleep(livyConf.getTimeAsMs(RSCConf.Entry.JOB_CANCEL_TRIGGER_INTERVAL)) diff --git a/repl/src/test/scala/org/apache/livy/repl/SparkSessionSpec.scala b/repl/src/test/scala/org/apache/livy/repl/SparkSessionSpec.scala index 90e282839..49f03d537 100644 --- a/repl/src/test/scala/org/apache/livy/repl/SparkSessionSpec.scala +++ b/repl/src/test/scala/org/apache/livy/repl/SparkSessionSpec.scala @@ -236,6 +236,33 @@ class SparkSessionSpec extends BaseSessionSpec(Spark) { } } + it should "cancel driver code without spark jobs" in withSession { session => + val stmtId = session.execute( + """ + |Thread.sleep(5000) + |val r = 1 + 1 + |r + """.stripMargin) + + eventually(timeout(30 seconds), interval(100 millis)) { + assert(session.statements(stmtId).state.get() == StatementState.Running) + } + + session.cancel(stmtId) + + eventually(timeout(30 seconds), interval(100 millis)) { + val statement = session.statements(stmtId) + assert(statement.state.get() == StatementState.Cancelled) + val resultJson = parse(statement.output) + (resultJson \ "status").extract[String] should equal ("error") + statement.output should not include ("r: Int = 2") + } + + val followUp = execute(session)("r") + val followUpResult = parse(followUp.output) + (followUpResult \ "status").extract[String] should equal ("error") + } + it should "correctly calculate progress" in withSession { session => val executeCode = """