From 4f99fe5c4bfbe5d75dec1acea3a824c5f016680a Mon Sep 17 00:00:00 2001 From: arnavb Date: Sat, 13 Dec 2025 10:11:16 +0000 Subject: [PATCH 1/5] [LIVY-1026] Fix cancel issue and update dependencies --- pom.xml | 2 -- .../scala/org/apache/livy/repl/Session.scala | 34 +++++++++++++------ .../apache/livy/repl/SparkSessionSpec.scala | 27 +++++++++++++++ 3 files changed, 50 insertions(+), 13 deletions(-) diff --git a/pom.xml b/pom.xml index 4dc328e96..715de69e5 100644 --- a/pom.xml +++ b/pom.xml @@ -189,8 +189,6 @@ ${java.version} ${java.version} - - test,provided diff --git a/repl/src/main/scala/org/apache/livy/repl/Session.scala b/repl/src/main/scala/org/apache/livy/repl/Session.scala index 262c811c7..d315f4a09 100644 --- a/repl/src/main/scala/org/apache/livy/repl/Session.scala +++ b/repl/src/main/scala/org/apache/livy/repl/Session.scala @@ -19,7 +19,7 @@ package org.apache.livy.repl import java.util.{LinkedHashMap => JLinkedHashMap} import java.util.Map.Entry -import java.util.concurrent.Executors +import java.util.concurrent.{ConcurrentHashMap, Executors} import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConverters._ @@ -63,6 +63,8 @@ class Session( private val cancelExecutor = ExecutionContext.fromExecutorService( Executors.newSingleThreadExecutor()) + private val statementThreads = new ConcurrentHashMap[Int, Thread]() + private implicit val formats = DefaultFormats private var _state: SessionState = SessionState.NotStarted @@ -161,18 +163,27 @@ class Session( _statements.synchronized { _statements(statementId) = statement } Future { - setJobGroup(tpe, statementId) - statement.compareAndTransit(StatementState.Waiting, StatementState.Running) + val currentThread = Thread.currentThread() + statementThreads.put(statementId, currentThread) + try { + setJobGroup(tpe, statementId) + statement.compareAndTransit(StatementState.Waiting, StatementState.Running) - if (statement.state.get() == StatementState.Running) { - statement.started = System.currentTimeMillis() - statement.output = executeCode(interpreter(tpe), statementId, code) - } + if (statement.state.get() == StatementState.Running) { + statement.started = System.currentTimeMillis() + statement.output = executeCode(interpreter(tpe), statementId, code) + } - statement.compareAndTransit(StatementState.Running, StatementState.Available) - statement.compareAndTransit(StatementState.Cancelling, StatementState.Cancelled) - statement.updateProgress(1.0) - statement.completed = System.currentTimeMillis() + statement.compareAndTransit(StatementState.Running, StatementState.Available) + statement.compareAndTransit(StatementState.Cancelling, StatementState.Cancelled) + statement.updateProgress(1.0) + statement.completed = System.currentTimeMillis() + } finally { + statementThreads.remove(statementId, currentThread) + // Clear the interrupt flag, but log if the thread was interrupted. + if (Thread.interrupted()) { + logWarning(s"Thread was interrupted during execution of statement $statementId; interrupt flag cleared.") + } }(interpreterExecutor) statementId @@ -212,6 +223,7 @@ class Session( info(s"Failed to cancel statement $statementId.") statement.compareAndTransit(StatementState.Cancelling, StatementState.Cancelled) } else { + Option(statementThreads.get(statementId)).foreach(_.interrupt()) sc.cancelJobGroup(statementId.toString) if (statement.state.get() == StatementState.Cancelling) { Thread.sleep(livyConf.getTimeAsMs(RSCConf.Entry.JOB_CANCEL_TRIGGER_INTERVAL)) diff --git a/repl/src/test/scala/org/apache/livy/repl/SparkSessionSpec.scala b/repl/src/test/scala/org/apache/livy/repl/SparkSessionSpec.scala index 90e282839..49f03d537 100644 --- a/repl/src/test/scala/org/apache/livy/repl/SparkSessionSpec.scala +++ b/repl/src/test/scala/org/apache/livy/repl/SparkSessionSpec.scala @@ -236,6 +236,33 @@ class SparkSessionSpec extends BaseSessionSpec(Spark) { } } + it should "cancel driver code without spark jobs" in withSession { session => + val stmtId = session.execute( + """ + |Thread.sleep(5000) + |val r = 1 + 1 + |r + """.stripMargin) + + eventually(timeout(30 seconds), interval(100 millis)) { + assert(session.statements(stmtId).state.get() == StatementState.Running) + } + + session.cancel(stmtId) + + eventually(timeout(30 seconds), interval(100 millis)) { + val statement = session.statements(stmtId) + assert(statement.state.get() == StatementState.Cancelled) + val resultJson = parse(statement.output) + (resultJson \ "status").extract[String] should equal ("error") + statement.output should not include ("r: Int = 2") + } + + val followUp = execute(session)("r") + val followUpResult = parse(followUp.output) + (followUpResult \ "status").extract[String] should equal ("error") + } + it should "correctly calculate progress" in withSession { session => val executeCode = """ From adfb6198687beff0e85f80858be12f1d3f419f62 Mon Sep 17 00:00:00 2001 From: arnavb Date: Sat, 13 Dec 2025 10:32:13 +0000 Subject: [PATCH 2/5] update --- repl/src/main/scala/org/apache/livy/repl/Session.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/repl/src/main/scala/org/apache/livy/repl/Session.scala b/repl/src/main/scala/org/apache/livy/repl/Session.scala index d315f4a09..c65f2da2c 100644 --- a/repl/src/main/scala/org/apache/livy/repl/Session.scala +++ b/repl/src/main/scala/org/apache/livy/repl/Session.scala @@ -184,6 +184,7 @@ class Session( if (Thread.interrupted()) { logWarning(s"Thread was interrupted during execution of statement $statementId; interrupt flag cleared.") } + } }(interpreterExecutor) statementId From 890a779fdc13f40d4cc2428cd03f761d8d8baadc Mon Sep 17 00:00:00 2001 From: arnavb Date: Sat, 13 Dec 2025 10:35:03 +0000 Subject: [PATCH 3/5] update --- repl/src/main/scala/org/apache/livy/repl/Session.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/repl/src/main/scala/org/apache/livy/repl/Session.scala b/repl/src/main/scala/org/apache/livy/repl/Session.scala index c65f2da2c..71dfeb293 100644 --- a/repl/src/main/scala/org/apache/livy/repl/Session.scala +++ b/repl/src/main/scala/org/apache/livy/repl/Session.scala @@ -182,7 +182,8 @@ class Session( statementThreads.remove(statementId, currentThread) // Clear the interrupt flag, but log if the thread was interrupted. if (Thread.interrupted()) { - logWarning(s"Thread was interrupted during execution of statement $statementId; interrupt flag cleared.") + logWarning(s"Thread was interrupted during execution of statement $statementId; " + + "interrupt flag cleared.") } } }(interpreterExecutor) From 368bd182ed73bcd78d43cf841589016d468ff559 Mon Sep 17 00:00:00 2001 From: arnavb Date: Sat, 13 Dec 2025 10:38:11 +0000 Subject: [PATCH 4/5] update --- repl/src/main/scala/org/apache/livy/repl/Session.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/repl/src/main/scala/org/apache/livy/repl/Session.scala b/repl/src/main/scala/org/apache/livy/repl/Session.scala index 71dfeb293..ba2be36e6 100644 --- a/repl/src/main/scala/org/apache/livy/repl/Session.scala +++ b/repl/src/main/scala/org/apache/livy/repl/Session.scala @@ -182,7 +182,7 @@ class Session( statementThreads.remove(statementId, currentThread) // Clear the interrupt flag, but log if the thread was interrupted. if (Thread.interrupted()) { - logWarning(s"Thread was interrupted during execution of statement $statementId; " + + warn(s"Thread was interrupted during execution of statement $statementId; " + "interrupt flag cleared.") } } From ec9a5266ac2e8a4784366d5f59ba56574b69d124 Mon Sep 17 00:00:00 2001 From: arnavb Date: Sat, 13 Dec 2025 10:40:21 +0000 Subject: [PATCH 5/5] update --- pom.xml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pom.xml b/pom.xml index 715de69e5..4dc328e96 100644 --- a/pom.xml +++ b/pom.xml @@ -189,6 +189,8 @@ ${java.version} ${java.version} + + test,provided