Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,26 @@ class SparkConnectConnection(val url: String, val info: Properties) extends Conn
throw new SQLFeatureNotSupportedException

override def createStatement(
resultSetType: Int, resultSetConcurrency: Int): Statement =
throw new SQLFeatureNotSupportedException
resultSetType: Int, resultSetConcurrency: Int): Statement = {
checkSupportedResultSet(resultSetType, resultSetConcurrency)
createStatement()
}

// Spark Connect results are forward-only and server-paginated, so only
// TYPE_FORWARD_ONLY result sets are supported.
private def checkSupportedResultSet(
resultSetType: Int, resultSetConcurrency: Int): Unit = {
if (resultSetType != ResultSet.TYPE_FORWARD_ONLY) {
throw new SQLFeatureNotSupportedException(
s"ResultSet type ${stringifyResultSetType(resultSetType)} is not supported; " +
"only TYPE_FORWARD_ONLY.")
}
if (resultSetConcurrency != ResultSet.CONCUR_READ_ONLY) {
throw new SQLFeatureNotSupportedException(
s"ResultSet concurrency ${stringifyResultSetConcurrency(resultSetConcurrency)} " +
"is not supported; only CONCUR_READ_ONLY.")
}
}

override def prepareStatement(
sql: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.connect.client.jdbc
import java.sql.{Array => _, _}

import org.apache.spark.sql.connect.client.SparkResult
import org.apache.spark.sql.connect.client.jdbc.util.JdbcErrorUtils

class SparkConnectStatement(conn: SparkConnectConnection) extends Statement {

Expand All @@ -28,6 +29,8 @@ class SparkConnectStatement(conn: SparkConnectConnection) extends Statement {

private var maxRows: Int = 0

private var resultsExhausted: Boolean = false

@volatile private var closed: Boolean = false

override def isClosed: Boolean = closed
Expand Down Expand Up @@ -94,6 +97,7 @@ class SparkConnectStatement(conn: SparkConnectConnection) extends Statement {
// reset before executing new query
operationId = null
resultSet = null
resultsExhausted = false

var df = conn.spark.sql(sql)
if (maxRows > 0) {
Expand Down Expand Up @@ -143,8 +147,13 @@ class SparkConnectStatement(conn: SparkConnectConnection) extends Statement {
0
}

override def setQueryTimeout(seconds: Int): Unit =
throw new SQLFeatureNotSupportedException
// This driver does not apply a query timeout; validate and silently drop the value.
override def setQueryTimeout(seconds: Int): Unit = {
checkOpen()
if (seconds < 0) {
throw new SQLException("Query timeout must be zero or a positive integer.")
}
}

override def cancel(): Unit = {
checkOpen()
Expand All @@ -164,35 +173,60 @@ class SparkConnectStatement(conn: SparkConnectConnection) extends Statement {
override def getUpdateCount: Int = {
checkOpen()

if (resultSet != null) {
if (resultsExhausted || resultSet != null) {
-1
} else {
0 // always return 0 because affected rows is not supported yet
}
}

override def getMoreResults: Boolean =
throw new SQLFeatureNotSupportedException
// a single result per execute(), so there is no next one: close the current
// ResultSet and mark exhausted, flipping getUpdateCount() to -1 so drain loops end
override def getMoreResults: Boolean = {
checkOpen()
if (resultSet != null) {
resultSet.close()
resultSet = null
}
resultsExhausted = true
false
}

override def setFetchDirection(direction: Int): Unit =
throw new SQLFeatureNotSupportedException
override def setFetchDirection(direction: Int): Unit = {
checkOpen()
if (direction != ResultSet.FETCH_FORWARD) {
throw new SQLException(
s"Fetch direction ${JdbcErrorUtils.stringifyFetchDirection(direction)} is not supported.")
}
}

override def getFetchDirection: Int =
throw new SQLFeatureNotSupportedException
override def getFetchDirection: Int = {
checkOpen()
ResultSet.FETCH_FORWARD
}

override def setFetchSize(rows: Int): Unit =
throw new SQLFeatureNotSupportedException
// This driver does not apply a fetch size hint; validate and silently drop the value.
override def setFetchSize(rows: Int): Unit = {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

similar to get/setQueryTimeout

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Same treatment as get/setQueryTimeout: getFetchSize returns 0 and setFetchSize validates then silently drops the value. Also removed the now-unused DEFAULT_FETCH_SIZE.

checkOpen()
if (rows < 0) {
throw new SQLException("Fetch size must be zero or a positive integer.")
}
}

override def getFetchSize: Int =
throw new SQLFeatureNotSupportedException
override def getFetchSize: Int = {
checkOpen()
0
}

override def getResultSetConcurrency: Int = {
checkOpen()
ResultSet.CONCUR_READ_ONLY
}

override def getResultSetType: Int =
throw new SQLFeatureNotSupportedException
override def getResultSetType: Int = {
checkOpen()
ResultSet.TYPE_FORWARD_ONLY
}

override def addBatch(sql: String): Unit =
throw new SQLFeatureNotSupportedException
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ private[jdbc] object JdbcErrorUtils {
throw new IllegalArgumentException(s"Invalid ResultSet type: $typ")
}

def stringifyResultSetConcurrency(concurrency: Int): String = concurrency match {
case ResultSet.CONCUR_READ_ONLY => "CONCUR_READ_ONLY"
case ResultSet.CONCUR_UPDATABLE => "CONCUR_UPDATABLE"
case _ =>
throw new IllegalArgumentException(s"Invalid ResultSet concurrency: $concurrency")
}

def stringifyFetchDirection(direction: Int): String = direction match {
case ResultSet.FETCH_FORWARD => "FETCH_FORWARD"
case ResultSet.FETCH_REVERSE => "FETCH_REVERSE"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,88 @@ class SparkConnectStatementSuite extends ConnectFunSuite with RemoteSparkSession
}
}
}

test("fetch size, fetch direction, result set type and query timeout accessors") {
withStatement { stmt =>
// fetch size: validated then silently dropped, always reads back as 0
assert(stmt.getFetchSize === 0)
stmt.setFetchSize(42)
assert(stmt.getFetchSize === 0)
val se1 = intercept[SQLException] {
stmt.setFetchSize(-1)
}
assert(se1.getMessage === "Fetch size must be zero or a positive integer.")

// fetch direction: only FETCH_FORWARD is supported
stmt.setFetchDirection(ResultSet.FETCH_FORWARD)
assert(stmt.getFetchDirection === ResultSet.FETCH_FORWARD)
intercept[SQLException] {
stmt.setFetchDirection(ResultSet.FETCH_REVERSE)
}

// result set type is forward-only
assert(stmt.getResultSetType === ResultSet.TYPE_FORWARD_ONLY)

// query timeout: validated then silently dropped, always reads back as 0
assert(stmt.getQueryTimeout === 0)
stmt.setQueryTimeout(30)
assert(stmt.getQueryTimeout === 0)
val se2 = intercept[SQLException] {
stmt.setQueryTimeout(-1)
}
assert(se2.getMessage === "Query timeout must be zero or a positive integer.")
}
}

test("getMoreResults terminates JDBC drain loops") {
// A typical JDBC result-draining loop. With getMoreResults throwing (or not
// flipping getUpdateCount to -1) this would spin forever; assert it returns.
def drain(stmt: Statement): Unit = {
while (stmt.getMoreResults || stmt.getUpdateCount != -1) {}
}

withTable("t_drain") {
withStatement { stmt =>
// result-bearing command
assert(stmt.execute("SELECT id FROM range(3)"))
assert(stmt.getUpdateCount === -1)
drain(stmt)
assert(stmt.getResultSet === null)
assert(stmt.getUpdateCount === -1)

// result-less command
assert(!stmt.execute("CREATE TABLE t_drain (id INT) USING Parquet"))
assert(stmt.getUpdateCount === 0)
drain(stmt)
assert(stmt.getUpdateCount === -1)
}
}
}

test("createStatement with result set type and concurrency") {
withConnection { conn =>
Using.resource(
conn.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)) { stmt =>
assert(stmt.getResultSetType === ResultSet.TYPE_FORWARD_ONLY)
}

// the holdability overload is not supported
intercept[SQLFeatureNotSupportedException] {
conn.createStatement(
ResultSet.TYPE_FORWARD_ONLY,
ResultSet.CONCUR_READ_ONLY,
ResultSet.CLOSE_CURSORS_AT_COMMIT)
}

// only TYPE_FORWARD_ONLY and CONCUR_READ_ONLY are supported
intercept[SQLFeatureNotSupportedException] {
conn.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_UPDATABLE)
}
Seq(ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.TYPE_SCROLL_SENSITIVE).foreach { typ =>
intercept[SQLFeatureNotSupportedException] {
conn.createStatement(typ, ResultSet.CONCUR_READ_ONLY)
}
}
}
}
}