diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectConnection.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectConnection.scala index 21b9471bb6069..fb6816302a296 100644 --- a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectConnection.scala +++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectConnection.scala @@ -128,8 +128,26 @@ class SparkConnectConnection(val url: String, val info: Properties) extends Conn throw new SQLFeatureNotSupportedException override def createStatement( - resultSetType: Int, resultSetConcurrency: Int): Statement = - throw new SQLFeatureNotSupportedException + resultSetType: Int, resultSetConcurrency: Int): Statement = { + checkSupportedResultSet(resultSetType, resultSetConcurrency) + createStatement() + } + + // Spark Connect results are forward-only and server-paginated, so only + // TYPE_FORWARD_ONLY result sets are supported. + private def checkSupportedResultSet( + resultSetType: Int, resultSetConcurrency: Int): Unit = { + if (resultSetType != ResultSet.TYPE_FORWARD_ONLY) { + throw new SQLFeatureNotSupportedException( + s"ResultSet type ${stringifyResultSetType(resultSetType)} is not supported; " + + "only TYPE_FORWARD_ONLY.") + } + if (resultSetConcurrency != ResultSet.CONCUR_READ_ONLY) { + throw new SQLFeatureNotSupportedException( + s"ResultSet concurrency ${stringifyResultSetConcurrency(resultSetConcurrency)} " + + "is not supported; only CONCUR_READ_ONLY.") + } + } override def prepareStatement( sql: String, diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectStatement.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectStatement.scala index d1947ae93a40c..245832087268e 100644 --- a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectStatement.scala +++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectStatement.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.connect.client.jdbc import java.sql.{Array => _, _} import org.apache.spark.sql.connect.client.SparkResult +import org.apache.spark.sql.connect.client.jdbc.util.JdbcErrorUtils class SparkConnectStatement(conn: SparkConnectConnection) extends Statement { @@ -28,6 +29,8 @@ class SparkConnectStatement(conn: SparkConnectConnection) extends Statement { private var maxRows: Int = 0 + private var resultsExhausted: Boolean = false + @volatile private var closed: Boolean = false override def isClosed: Boolean = closed @@ -94,6 +97,7 @@ class SparkConnectStatement(conn: SparkConnectConnection) extends Statement { // reset before executing new query operationId = null resultSet = null + resultsExhausted = false var df = conn.spark.sql(sql) if (maxRows > 0) { @@ -143,8 +147,13 @@ class SparkConnectStatement(conn: SparkConnectConnection) extends Statement { 0 } - override def setQueryTimeout(seconds: Int): Unit = - throw new SQLFeatureNotSupportedException + // This driver does not apply a query timeout; validate and silently drop the value. + override def setQueryTimeout(seconds: Int): Unit = { + checkOpen() + if (seconds < 0) { + throw new SQLException("Query timeout must be zero or a positive integer.") + } + } override def cancel(): Unit = { checkOpen() @@ -164,35 +173,60 @@ class SparkConnectStatement(conn: SparkConnectConnection) extends Statement { override def getUpdateCount: Int = { checkOpen() - if (resultSet != null) { + if (resultsExhausted || resultSet != null) { -1 } else { 0 // always return 0 because affected rows is not supported yet } } - override def getMoreResults: Boolean = - throw new SQLFeatureNotSupportedException + // a single result per execute(), so there is no next one: close the current + // ResultSet and mark exhausted, flipping getUpdateCount() to -1 so drain loops end + override def getMoreResults: Boolean = { + checkOpen() + if (resultSet != null) { + resultSet.close() + resultSet = null + } + resultsExhausted = true + false + } - override def setFetchDirection(direction: Int): Unit = - throw new SQLFeatureNotSupportedException + override def setFetchDirection(direction: Int): Unit = { + checkOpen() + if (direction != ResultSet.FETCH_FORWARD) { + throw new SQLException( + s"Fetch direction ${JdbcErrorUtils.stringifyFetchDirection(direction)} is not supported.") + } + } - override def getFetchDirection: Int = - throw new SQLFeatureNotSupportedException + override def getFetchDirection: Int = { + checkOpen() + ResultSet.FETCH_FORWARD + } - override def setFetchSize(rows: Int): Unit = - throw new SQLFeatureNotSupportedException + // This driver does not apply a fetch size hint; validate and silently drop the value. + override def setFetchSize(rows: Int): Unit = { + checkOpen() + if (rows < 0) { + throw new SQLException("Fetch size must be zero or a positive integer.") + } + } - override def getFetchSize: Int = - throw new SQLFeatureNotSupportedException + override def getFetchSize: Int = { + checkOpen() + 0 + } override def getResultSetConcurrency: Int = { checkOpen() ResultSet.CONCUR_READ_ONLY } - override def getResultSetType: Int = - throw new SQLFeatureNotSupportedException + override def getResultSetType: Int = { + checkOpen() + ResultSet.TYPE_FORWARD_ONLY + } override def addBatch(sql: String): Unit = throw new SQLFeatureNotSupportedException diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/util/JdbcErrorUtils.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/util/JdbcErrorUtils.scala index 6480c5d768f3f..a732dae0d6af5 100644 --- a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/util/JdbcErrorUtils.scala +++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/util/JdbcErrorUtils.scala @@ -46,6 +46,13 @@ private[jdbc] object JdbcErrorUtils { throw new IllegalArgumentException(s"Invalid ResultSet type: $typ") } + def stringifyResultSetConcurrency(concurrency: Int): String = concurrency match { + case ResultSet.CONCUR_READ_ONLY => "CONCUR_READ_ONLY" + case ResultSet.CONCUR_UPDATABLE => "CONCUR_UPDATABLE" + case _ => + throw new IllegalArgumentException(s"Invalid ResultSet concurrency: $concurrency") + } + def stringifyFetchDirection(direction: Int): String = direction match { case ResultSet.FETCH_FORWARD => "FETCH_FORWARD" case ResultSet.FETCH_REVERSE => "FETCH_REVERSE" diff --git a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectStatementSuite.scala b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectStatementSuite.scala index fa9df3f1247f7..1e768a5888e3b 100644 --- a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectStatementSuite.scala +++ b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectStatementSuite.scala @@ -116,4 +116,88 @@ class SparkConnectStatementSuite extends ConnectFunSuite with RemoteSparkSession } } } + + test("fetch size, fetch direction, result set type and query timeout accessors") { + withStatement { stmt => + // fetch size: validated then silently dropped, always reads back as 0 + assert(stmt.getFetchSize === 0) + stmt.setFetchSize(42) + assert(stmt.getFetchSize === 0) + val se1 = intercept[SQLException] { + stmt.setFetchSize(-1) + } + assert(se1.getMessage === "Fetch size must be zero or a positive integer.") + + // fetch direction: only FETCH_FORWARD is supported + stmt.setFetchDirection(ResultSet.FETCH_FORWARD) + assert(stmt.getFetchDirection === ResultSet.FETCH_FORWARD) + intercept[SQLException] { + stmt.setFetchDirection(ResultSet.FETCH_REVERSE) + } + + // result set type is forward-only + assert(stmt.getResultSetType === ResultSet.TYPE_FORWARD_ONLY) + + // query timeout: validated then silently dropped, always reads back as 0 + assert(stmt.getQueryTimeout === 0) + stmt.setQueryTimeout(30) + assert(stmt.getQueryTimeout === 0) + val se2 = intercept[SQLException] { + stmt.setQueryTimeout(-1) + } + assert(se2.getMessage === "Query timeout must be zero or a positive integer.") + } + } + + test("getMoreResults terminates JDBC drain loops") { + // A typical JDBC result-draining loop. With getMoreResults throwing (or not + // flipping getUpdateCount to -1) this would spin forever; assert it returns. + def drain(stmt: Statement): Unit = { + while (stmt.getMoreResults || stmt.getUpdateCount != -1) {} + } + + withTable("t_drain") { + withStatement { stmt => + // result-bearing command + assert(stmt.execute("SELECT id FROM range(3)")) + assert(stmt.getUpdateCount === -1) + drain(stmt) + assert(stmt.getResultSet === null) + assert(stmt.getUpdateCount === -1) + + // result-less command + assert(!stmt.execute("CREATE TABLE t_drain (id INT) USING Parquet")) + assert(stmt.getUpdateCount === 0) + drain(stmt) + assert(stmt.getUpdateCount === -1) + } + } + } + + test("createStatement with result set type and concurrency") { + withConnection { conn => + Using.resource( + conn.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)) { stmt => + assert(stmt.getResultSetType === ResultSet.TYPE_FORWARD_ONLY) + } + + // the holdability overload is not supported + intercept[SQLFeatureNotSupportedException] { + conn.createStatement( + ResultSet.TYPE_FORWARD_ONLY, + ResultSet.CONCUR_READ_ONLY, + ResultSet.CLOSE_CURSORS_AT_COMMIT) + } + + // only TYPE_FORWARD_ONLY and CONCUR_READ_ONLY are supported + intercept[SQLFeatureNotSupportedException] { + conn.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_UPDATABLE) + } + Seq(ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.TYPE_SCROLL_SENSITIVE).foreach { typ => + intercept[SQLFeatureNotSupportedException] { + conn.createStatement(typ, ResultSet.CONCUR_READ_ONLY) + } + } + } + } }