From 7ab416e7f0d1c84a5a6142bc93d027316b1698bb Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Thu, 9 Apr 2026 06:27:26 +0000 Subject: [PATCH 01/10] feat: make spark.read.csv accept DataFrame input --- python/pyspark/sql/connect/readwriter.py | 27 +++++++++++++- python/pyspark/sql/readwriter.py | 20 ++++++++-- .../tests/connect/test_connect_readwriter.py | 37 +++++++++++++++++++ python/pyspark/sql/tests/test_datasources.py | 37 +++++++++++++++++++ .../spark/sql/api/python/PythonSQLUtils.scala | 20 ++++++++++ 5 files changed, 136 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py index 027da28a31cb..725b2d491be6 100644 --- a/python/pyspark/sql/connect/readwriter.py +++ b/python/pyspark/sql/connect/readwriter.py @@ -371,7 +371,32 @@ def csv( ) if isinstance(path, str): path = [path] - return self.load(path=path, format="csv", schema=schema) + if isinstance(path, list): + return self.load(path=path, format="csv", schema=schema) + + from pyspark.sql.connect.dataframe import DataFrame + + if isinstance(path, DataFrame): + # Schema must be set explicitly here because the DataFrame path + # bypasses load(), which normally calls self.schema(schema). + if schema is not None: + self.schema(schema) + return self._df( + Parse( + child=path._plan, + format=proto.Parse.ParseFormat.PARSE_FORMAT_CSV, + schema=self._schema, + options=self._options, + ) + ) + raise PySparkTypeError( + errorClass="NOT_EXPECTED_TYPE", + messageParameters={ + "arg_name": "path", + "expected_type": "str, list, or DataFrame", + "arg_type": type(path).__name__, + }, + ) csv.__doc__ = PySparkDataFrameReader.csv.__doc__ diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 7ada41a71655..735ecfeb9d26 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -768,7 +768,7 @@ def text( def csv( self, - path: PathOrPaths, + path: Union[str, List[str], "RDD[str]", "DataFrame"], schema: Optional[Union[StructType, str]] = None, sep: Optional[str] = None, encoding: Optional[str] = None, @@ -814,11 +814,15 @@ def csv( .. versionchanged:: 3.4.0 Supports Spark Connect. + .. versionchanged:: 4.2.0 + Supports DataFrame input. + Parameters ---------- - path : str or list + path : str, list, :class:`RDD`, or :class:`DataFrame` string, or list of strings, for input path(s), - or RDD of Strings storing CSV rows. + or RDD of Strings storing CSV rows, + or a DataFrame with a single string column containing CSV rows. schema : :class:`pyspark.sql.types.StructType` or str, optional an optional :class:`pyspark.sql.types.StructType` for the input schema or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``). 
@@ -915,12 +919,20 @@ def func(iterator): jrdd.rdd(), self._spark._jvm.Encoders.STRING() ) return self._df(self._jreader.csv(jdataset)) + + from pyspark.sql.dataframe import DataFrame + + if isinstance(path, DataFrame): + assert self._spark._jvm is not None + return self._df( + self._spark._jvm.PythonSQLUtils.csvFromDataFrame(self._jreader, path._jdf) + ) else: raise PySparkTypeError( errorClass="NOT_EXPECTED_TYPE", messageParameters={ "arg_name": "path", - "expected_type": "str or list[RDD]", + "expected_type": "str, list, RDD, or DataFrame", "arg_type": type(path).__name__, }, ) diff --git a/python/pyspark/sql/tests/connect/test_connect_readwriter.py b/python/pyspark/sql/tests/connect/test_connect_readwriter.py index 161f6305be8c..a526fa741b90 100644 --- a/python/pyspark/sql/tests/connect/test_connect_readwriter.py +++ b/python/pyspark/sql/tests/connect/test_connect_readwriter.py @@ -213,6 +213,43 @@ def test_json_with_dataframe_input_zero_columns(self): with self.assertRaisesRegex(Exception, "DATAFRAME_INPUT_NOT_SINGLE_COLUMN"): self.connect.read.json(empty_schema_df).collect() + def test_csv_with_dataframe_input(self): + csv_df = self.connect.createDataFrame( + [("Alice,25",), ("Bob,30",)], + schema="value STRING", + ) + result = self.connect.read.csv(csv_df) + expected = [Row(_c0="Alice", _c1="25"), Row(_c0="Bob", _c1="30")] + self.assertEqual(sorted(result.collect(), key=lambda r: r._c0), expected) + + def test_csv_with_dataframe_input_and_schema(self): + csv_df = self.connect.createDataFrame( + [("Alice,25",), ("Bob,30",)], + schema="value STRING", + ) + result = self.connect.read.csv(csv_df, schema="name STRING, age INT") + expected = [Row(name="Alice", age=25), Row(name="Bob", age=30)] + self.assertEqual(sorted(result.collect(), key=lambda r: r.name), expected) + + def test_csv_with_dataframe_input_non_string_column(self): + int_df = self.connect.createDataFrame([(1,), (2,)], schema="value INT") + with self.assertRaisesRegex(Exception, "PARSE_INPUT_NOT_STRING_TYPE"): + self.connect.read.csv(int_df).collect() + + def test_csv_with_dataframe_input_multiple_columns(self): + multi_df = self.connect.createDataFrame( + [("Alice,25", "extra"), ("Bob,30", "extra")], + schema="value STRING, other STRING", + ) + result = self.connect.read.csv(multi_df) + expected = [Row(_c0="Alice", _c1="25"), Row(_c0="Bob", _c1="30")] + self.assertEqual(sorted(result.collect(), key=lambda r: r._c0), expected) + + def test_csv_with_dataframe_input_zero_columns(self): + empty_schema_df = self.connect.range(1).select() + with self.assertRaisesRegex(Exception, "PARSE_INPUT_NOT_STRING_TYPE"): + self.connect.read.csv(empty_schema_df).collect() + def test_multi_paths(self): # SPARK-42041: DataFrameReader should support list of paths diff --git a/python/pyspark/sql/tests/test_datasources.py b/python/pyspark/sql/tests/test_datasources.py index e4eb5900cc85..8990587c308c 100644 --- a/python/pyspark/sql/tests/test_datasources.py +++ b/python/pyspark/sql/tests/test_datasources.py @@ -152,6 +152,43 @@ def test_ignorewhitespace_csv(self): self.assertEqual(readback.collect(), expected) shutil.rmtree(tmpPath) + def test_csv_with_dataframe_input(self): + csv_df = self.spark.createDataFrame( + [("Alice,25",), ("Bob,30",)], + schema="value STRING", + ) + result = self.spark.read.csv(csv_df) + expected = [Row(_c0="Alice", _c1="25"), Row(_c0="Bob", _c1="30")] + self.assertEqual(sorted(result.collect(), key=lambda r: r._c0), expected) + + def test_csv_with_dataframe_input_and_schema(self): + csv_df = self.spark.createDataFrame( 
+ [("Alice,25",), ("Bob,30",)], + schema="value STRING", + ) + result = self.spark.read.csv(csv_df, schema="name STRING, age INT") + expected = [Row(name="Alice", age=25), Row(name="Bob", age=30)] + self.assertEqual(sorted(result.collect(), key=lambda r: r.name), expected) + + def test_csv_with_dataframe_input_non_string_column(self): + int_df = self.spark.createDataFrame([(1,), (2,)], schema="value INT") + with self.assertRaisesRegex(Exception, "PARSE_INPUT_NOT_STRING_TYPE"): + self.spark.read.csv(int_df).collect() + + def test_csv_with_dataframe_input_multiple_columns(self): + multi_df = self.spark.createDataFrame( + [("Alice,25", "extra"), ("Bob,30", "extra")], + schema="value STRING, other STRING", + ) + result = self.spark.read.csv(multi_df) + expected = [Row(_c0="Alice", _c1="25"), Row(_c0="Bob", _c1="30")] + self.assertEqual(sorted(result.collect(), key=lambda r: r._c0), expected) + + def test_csv_with_dataframe_input_zero_columns(self): + empty_schema_df = self.spark.range(1).select() + with self.assertRaisesRegex(Exception, "PARSE_INPUT_NOT_STRING_TYPE"): + self.spark.read.csv(empty_schema_df).collect() + def test_xml(self): tmpPath = tempfile.mkdtemp() shutil.rmtree(tmpPath) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index bc7678c291a2..c4b49628fe63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -214,6 +214,26 @@ private[sql] object PythonSQLUtils extends Logging { classicReader.json(df.as(Encoders.STRING)) } + /** + * Parses a [[DataFrame]] containing CSV strings into a structured [[DataFrame]]. + * The input DataFrame must have exactly one column of StringType. + * This is used by PySpark to avoid manual Dataset[String] conversion on the Python side. 
+ */ + def csvFromDataFrame( + reader: DataFrameReader, + df: DataFrame): DataFrame = { + val classicReader = reader.asInstanceOf[ClassicDataFrameReader] + val fields = df.schema.fields + if (fields.isEmpty) { + throw QueryCompilationErrors.parseInputNotStringTypeError( + org.apache.spark.sql.types.NullType) + } + if (fields.head.dataType != org.apache.spark.sql.types.StringType) { + throw QueryCompilationErrors.parseInputNotStringTypeError(fields.head.dataType) + } + classicReader.csv(df.select(df.columns.head).as(Encoders.STRING)) + } + def cleanupPythonWorkerLogs(sessionUUID: String, sparkContext: SparkContext): Unit = { if (!sparkContext.isStopped) { try { From 5a39963c267e7e36f152714f4353ec748c25c342 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Thu, 9 Apr 2026 18:34:12 +0000 Subject: [PATCH 02/10] fix: align csv signature and mypy annotations with json --- python/pyspark/sql/connect/readwriter.py | 2 +- python/pyspark/sql/readwriter.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py index 725b2d491be6..ab53cfb04915 100644 --- a/python/pyspark/sql/connect/readwriter.py +++ b/python/pyspark/sql/connect/readwriter.py @@ -300,7 +300,7 @@ def text( def csv( self, - path: PathOrPaths, + path: Union[PathOrPaths, "DataFrame"], schema: Optional[Union[StructType, str]] = None, sep: Optional[str] = None, encoding: Optional[str] = None, diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 735ecfeb9d26..9d4da9f272b7 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -900,7 +900,7 @@ def csv( if not is_remote_only() and isinstance(path, RDD): - def func(iterator): + def func(iterator: Iterable) -> Iterable: for x in iterator: if not isinstance(x, str): x = str(x) @@ -909,7 +909,8 @@ def func(iterator): yield x keyed = path.mapPartitions(func) - keyed._bypass_serializer = True + keyed._bypass_serializer = True # type: ignore[attr-defined] + assert self._spark._jvm is not None jrdd = keyed._jrdd.map(self._spark._jvm.BytesToString()) # see SPARK-22112 # There aren't any jvm api for creating a dataframe from rdd storing csv. 
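Taken together, PATCH 01 and PATCH 02 let both the classic and the Spark Connect DataFrameReader accept a single string-column DataFrame of raw CSV rows, alongside the existing path, list-of-paths, and (classic only) RDD-of-strings inputs. A minimal usage sketch of that new path follows, assuming an active SparkSession bound to spark with these patches applied; the data mirrors the added tests and the variable names are illustrative only:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # One string column holding raw CSV rows, as in the added tests.
    csv_df = spark.createDataFrame([("Alice,25",), ("Bob,30",)], schema="value STRING")

    # Without a schema the parsed columns come back as _c0, _c1, ...
    spark.read.csv(csv_df).show()

    # A DDL schema names and types the parsed columns instead.
    spark.read.csv(csv_df, schema="name STRING, age INT").show()
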
From 08c85c91fa6347fd966aca058644690141525140 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Thu, 9 Apr 2026 21:49:39 +0000 Subject: [PATCH 03/10] fix: add csv to RDD argument skip list in connect compatibility test --- python/pyspark/sql/tests/test_connect_compatibility.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py b/python/pyspark/sql/tests/test_connect_compatibility.py index 94ebd958843b..bd49b1f46548 100644 --- a/python/pyspark/sql/tests/test_connect_compatibility.py +++ b/python/pyspark/sql/tests/test_connect_compatibility.py @@ -100,7 +100,7 @@ def compare_method_signatures(self, classic_cls, connect_cls, cls_name): connect_signature = inspect.signature(connect_methods[method]) # Cannot support RDD arguments from Spark Connect - has_rdd_arguments = ("createDataFrame", "xml", "json", "toJSON") + has_rdd_arguments = ("createDataFrame", "xml", "json", "csv", "toJSON") if method not in has_rdd_arguments: self.assertEqual( classic_signature, From 25c91f781b7cbb684c73b2c1d7952aa65f56d9e3 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Fri, 10 Apr 2026 07:01:06 +0000 Subject: [PATCH 04/10] chore: retrigger CI From 46374413e283a42dd715c8a162ce0b917a0b3a8d Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Sun, 12 Apr 2026 00:13:02 +0000 Subject: [PATCH 05/10] chore: retrigger CI From c8868ea49ef3a22cd4348bf9541b79fecca2757b Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Mon, 13 Apr 2026 06:00:53 +0000 Subject: [PATCH 06/10] fix: address review feedback - reorder DataFrame check and reject multi-column input --- python/pyspark/sql/connect/readwriter.py | 11 +---------- .../sql/tests/connect/test_connect_readwriter.py | 9 ++++----- python/pyspark/sql/tests/test_datasources.py | 9 ++++----- .../apache/spark/sql/api/python/PythonSQLUtils.scala | 9 ++++----- 4 files changed, 13 insertions(+), 25 deletions(-) diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py index ab53cfb04915..c112f2d1da30 100644 --- a/python/pyspark/sql/connect/readwriter.py +++ b/python/pyspark/sql/connect/readwriter.py @@ -371,8 +371,6 @@ def csv( ) if isinstance(path, str): path = [path] - if isinstance(path, list): - return self.load(path=path, format="csv", schema=schema) from pyspark.sql.connect.dataframe import DataFrame @@ -389,14 +387,7 @@ def csv( options=self._options, ) ) - raise PySparkTypeError( - errorClass="NOT_EXPECTED_TYPE", - messageParameters={ - "arg_name": "path", - "expected_type": "str, list, or DataFrame", - "arg_type": type(path).__name__, - }, - ) + return self.load(path=path, format="csv", schema=schema) csv.__doc__ = PySparkDataFrameReader.csv.__doc__ diff --git a/python/pyspark/sql/tests/connect/test_connect_readwriter.py b/python/pyspark/sql/tests/connect/test_connect_readwriter.py index a526fa741b90..58a92fe9f017 100644 --- a/python/pyspark/sql/tests/connect/test_connect_readwriter.py +++ b/python/pyspark/sql/tests/connect/test_connect_readwriter.py @@ -233,7 +233,7 @@ def test_csv_with_dataframe_input_and_schema(self): def test_csv_with_dataframe_input_non_string_column(self): int_df = self.connect.createDataFrame([(1,), (2,)], schema="value INT") - with self.assertRaisesRegex(Exception, "PARSE_INPUT_NOT_STRING_TYPE"): + with self.assertRaisesRegex(Exception, 
"DATAFRAME_INPUT_NOT_STRING_TYPE"): self.connect.read.csv(int_df).collect() def test_csv_with_dataframe_input_multiple_columns(self): @@ -241,13 +241,12 @@ def test_csv_with_dataframe_input_multiple_columns(self): [("Alice,25", "extra"), ("Bob,30", "extra")], schema="value STRING, other STRING", ) - result = self.connect.read.csv(multi_df) - expected = [Row(_c0="Alice", _c1="25"), Row(_c0="Bob", _c1="30")] - self.assertEqual(sorted(result.collect(), key=lambda r: r._c0), expected) + with self.assertRaisesRegex(Exception, "DATAFRAME_INPUT_NOT_SINGLE_COLUMN"): + self.connect.read.csv(multi_df).collect() def test_csv_with_dataframe_input_zero_columns(self): empty_schema_df = self.connect.range(1).select() - with self.assertRaisesRegex(Exception, "PARSE_INPUT_NOT_STRING_TYPE"): + with self.assertRaisesRegex(Exception, "DATAFRAME_INPUT_NOT_SINGLE_COLUMN"): self.connect.read.csv(empty_schema_df).collect() def test_multi_paths(self): diff --git a/python/pyspark/sql/tests/test_datasources.py b/python/pyspark/sql/tests/test_datasources.py index 8990587c308c..720d02c6bac6 100644 --- a/python/pyspark/sql/tests/test_datasources.py +++ b/python/pyspark/sql/tests/test_datasources.py @@ -172,7 +172,7 @@ def test_csv_with_dataframe_input_and_schema(self): def test_csv_with_dataframe_input_non_string_column(self): int_df = self.spark.createDataFrame([(1,), (2,)], schema="value INT") - with self.assertRaisesRegex(Exception, "PARSE_INPUT_NOT_STRING_TYPE"): + with self.assertRaisesRegex(Exception, "DATAFRAME_INPUT_NOT_STRING_TYPE"): self.spark.read.csv(int_df).collect() def test_csv_with_dataframe_input_multiple_columns(self): @@ -180,13 +180,12 @@ def test_csv_with_dataframe_input_multiple_columns(self): [("Alice,25", "extra"), ("Bob,30", "extra")], schema="value STRING, other STRING", ) - result = self.spark.read.csv(multi_df) - expected = [Row(_c0="Alice", _c1="25"), Row(_c0="Bob", _c1="30")] - self.assertEqual(sorted(result.collect(), key=lambda r: r._c0), expected) + with self.assertRaisesRegex(Exception, "DATAFRAME_INPUT_NOT_SINGLE_COLUMN"): + self.spark.read.csv(multi_df).collect() def test_csv_with_dataframe_input_zero_columns(self): empty_schema_df = self.spark.range(1).select() - with self.assertRaisesRegex(Exception, "PARSE_INPUT_NOT_STRING_TYPE"): + with self.assertRaisesRegex(Exception, "DATAFRAME_INPUT_NOT_SINGLE_COLUMN"): self.spark.read.csv(empty_schema_df).collect() def test_xml(self): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index c4b49628fe63..7e9d3d38a4e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -224,14 +224,13 @@ private[sql] object PythonSQLUtils extends Logging { df: DataFrame): DataFrame = { val classicReader = reader.asInstanceOf[ClassicDataFrameReader] val fields = df.schema.fields - if (fields.isEmpty) { - throw QueryCompilationErrors.parseInputNotStringTypeError( - org.apache.spark.sql.types.NullType) + if (fields.length != 1) { + throw QueryCompilationErrors.dataframeInputNotSingleColumnError(fields.length) } if (fields.head.dataType != org.apache.spark.sql.types.StringType) { - throw QueryCompilationErrors.parseInputNotStringTypeError(fields.head.dataType) + throw QueryCompilationErrors.dataframeInputNotStringTypeError(fields.head.dataType) } - classicReader.csv(df.select(df.columns.head).as(Encoders.STRING)) + 
classicReader.csv(df.as(Encoders.STRING)) } def cleanupPythonWorkerLogs(sessionUUID: String, sparkContext: SparkContext): Unit = { From a9c7e749adf846e92b8a3c5fffcdfe4f67f95d7e Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Mon, 13 Apr 2026 19:07:49 +0000 Subject: [PATCH 07/10] chore: retrigger CI From 95930c0f369a04d6de6e62f745e42f222e979965 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Mon, 13 Apr 2026 20:27:58 +0000 Subject: [PATCH 08/10] refactor: extract toStringDataset util for json/csv DataFrame input validation --- .../spark/sql/api/python/PythonSQLUtils.scala | 35 ++++++------------- 1 file changed, 10 insertions(+), 25 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index 7e9d3d38a4e5..cb739761b8b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -27,7 +27,7 @@ import org.apache.spark.api.python.DechunkedInputStream import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys.CLASS_LOADER import org.apache.spark.security.SocketAuthServer -import org.apache.spark.sql.{internal, Column, DataFrame, DataFrameReader, Encoders, Row, SparkSession, TableArg} +import org.apache.spark.sql.{internal, Column, DataFrame, DataFrameReader, Dataset, Encoders, Row, SparkSession, TableArg} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TableFunctionRegistry} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -195,15 +195,7 @@ private[sql] object PythonSQLUtils extends Logging { @scala.annotation.varargs def internalFn(name: String, inputs: Column*): Column = Column.internalFn(name, inputs: _*) - /** - * Parses a [[DataFrame]] containing JSON strings into a structured [[DataFrame]]. - * The input DataFrame must have exactly one column of StringType. - * This is used by PySpark to avoid manual Dataset[String] conversion on the Python side. - */ - def jsonFromDataFrame( - reader: DataFrameReader, - df: DataFrame): DataFrame = { - val classicReader = reader.asInstanceOf[ClassicDataFrameReader] + private def toStringDataset(df: DataFrame): Dataset[String] = { val fields = df.schema.fields if (fields.length != 1) { throw QueryCompilationErrors.dataframeInputNotSingleColumnError(fields.length) @@ -211,26 +203,19 @@ private[sql] object PythonSQLUtils extends Logging { if (fields.head.dataType != org.apache.spark.sql.types.StringType) { throw QueryCompilationErrors.dataframeInputNotStringTypeError(fields.head.dataType) } - classicReader.json(df.as(Encoders.STRING)) + df.as(Encoders.STRING) + } + + def jsonFromDataFrame( + reader: DataFrameReader, + df: DataFrame): DataFrame = { + reader.asInstanceOf[ClassicDataFrameReader].json(toStringDataset(df)) } - /** - * Parses a [[DataFrame]] containing CSV strings into a structured [[DataFrame]]. - * The input DataFrame must have exactly one column of StringType. - * This is used by PySpark to avoid manual Dataset[String] conversion on the Python side. 
- */ def csvFromDataFrame( reader: DataFrameReader, df: DataFrame): DataFrame = { - val classicReader = reader.asInstanceOf[ClassicDataFrameReader] - val fields = df.schema.fields - if (fields.length != 1) { - throw QueryCompilationErrors.dataframeInputNotSingleColumnError(fields.length) - } - if (fields.head.dataType != org.apache.spark.sql.types.StringType) { - throw QueryCompilationErrors.dataframeInputNotStringTypeError(fields.head.dataType) - } - classicReader.csv(df.as(Encoders.STRING)) + reader.asInstanceOf[ClassicDataFrameReader].csv(toStringDataset(df)) } def cleanupPythonWorkerLogs(sessionUUID: String, sparkContext: SparkContext): Unit = { From 94c5ced85a857b8c79a4c75482833fe2ab949763 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Mon, 13 Apr 2026 20:58:26 +0000 Subject: [PATCH 09/10] docs: add Scaladoc to toStringDataset --- .../org/apache/spark/sql/api/python/PythonSQLUtils.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index cb739761b8b1..cfb450db649f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -195,6 +195,11 @@ private[sql] object PythonSQLUtils extends Logging { @scala.annotation.varargs def internalFn(name: String, inputs: Column*): Column = Column.internalFn(name, inputs: _*) + /** + * Validates that the input [[DataFrame]] has exactly one column of StringType + * and converts it to a Dataset[String]. + * This is used by PySpark to avoid manual Dataset[String] conversion on the Python side. + */ private def toStringDataset(df: DataFrame): Dataset[String] = { val fields = df.schema.fields if (fields.length != 1) { From 110a16338ca60a7e43618bad6c9dae91af6e4772 Mon Sep 17 00:00:00 2001 From: Yicong-Huang <17627829+Yicong-Huang@users.noreply.github.com> Date: Mon, 13 Apr 2026 20:58:42 +0000 Subject: [PATCH 10/10] docs: trim Scaladoc --- .../scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index cfb450db649f..93847f11f61a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -198,7 +198,6 @@ private[sql] object PythonSQLUtils extends Logging { /** * Validates that the input [[DataFrame]] has exactly one column of StringType * and converts it to a Dataset[String]. - * This is used by PySpark to avoid manual Dataset[String] conversion on the Python side. */ private def toStringDataset(df: DataFrame): Dataset[String] = { val fields = df.schema.fields
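
After the review-feedback changes in PATCH 06 and the toStringDataset refactor in PATCH 08, the shared helper performs the input validation for both jsonFromDataFrame and csvFromDataFrame, so anything other than exactly one StringType column is rejected up front. A short sketch of that behavior, mirroring the updated tests; the session setup and variable names are illustrative assumptions, and the exception is caught broadly here just as the tests do:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # Two string columns: rejected with DATAFRAME_INPUT_NOT_SINGLE_COLUMN.
    multi_df = spark.createDataFrame(
        [("Alice,25", "extra"), ("Bob,30", "extra")],
        schema="value STRING, other STRING",
    )
    try:
        spark.read.csv(multi_df).collect()
    except Exception as e:
        print(e)

    # A single non-string column: rejected with DATAFRAME_INPUT_NOT_STRING_TYPE.
    int_df = spark.createDataFrame([(1,), (2,)], schema="value INT")
    try:
        spark.read.csv(int_df).collect()
    except Exception as e:
        print(e)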