diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
index 027da28a31cb..c112f2d1da30 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -300,7 +300,7 @@ def text(
 
     def csv(
         self,
-        path: PathOrPaths,
+        path: Union[PathOrPaths, "DataFrame"],
         schema: Optional[Union[StructType, str]] = None,
         sep: Optional[str] = None,
         encoding: Optional[str] = None,
@@ -371,6 +371,22 @@ def csv(
         )
         if isinstance(path, str):
             path = [path]
+
+        from pyspark.sql.connect.dataframe import DataFrame
+
+        if isinstance(path, DataFrame):
+            # Schema must be set explicitly here because the DataFrame path
+            # bypasses load(), which normally calls self.schema(schema).
+            if schema is not None:
+                self.schema(schema)
+            return self._df(
+                Parse(
+                    child=path._plan,
+                    format=proto.Parse.ParseFormat.PARSE_FORMAT_CSV,
+                    schema=self._schema,
+                    options=self._options,
+                )
+            )
         return self.load(path=path, format="csv", schema=schema)
 
     csv.__doc__ = PySparkDataFrameReader.csv.__doc__
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 7ada41a71655..9d4da9f272b7 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -768,7 +768,7 @@ def text(
 
     def csv(
         self,
-        path: PathOrPaths,
+        path: Union[str, List[str], "RDD[str]", "DataFrame"],
        schema: Optional[Union[StructType, str]] = None,
         sep: Optional[str] = None,
         encoding: Optional[str] = None,
@@ -814,11 +814,15 @@ def csv(
         .. versionchanged:: 3.4.0
             Supports Spark Connect.
 
+        .. versionchanged:: 4.2.0
+            Supports DataFrame input.
+
         Parameters
         ----------
-        path : str or list
+        path : str, list, :class:`RDD`, or :class:`DataFrame`
             string, or list of strings, for input path(s),
-            or RDD of Strings storing CSV rows.
+            or RDD of Strings storing CSV rows,
+            or a DataFrame with a single string column containing CSV rows.
         schema : :class:`pyspark.sql.types.StructType` or str, optional
             an optional :class:`pyspark.sql.types.StructType` for the input schema
             or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
@@ -896,7 +900,7 @@ def csv(
         if not is_remote_only() and isinstance(path, RDD):
 
-            def func(iterator):
+            def func(iterator: Iterable) -> Iterable:
                 for x in iterator:
                     if not isinstance(x, str):
                         x = str(x)
@@ -905,7 +909,8 @@ def func(iterator):
                     yield x
 
             keyed = path.mapPartitions(func)
-            keyed._bypass_serializer = True
+            keyed._bypass_serializer = True  # type: ignore[attr-defined]
+            assert self._spark._jvm is not None
             jrdd = keyed._jrdd.map(self._spark._jvm.BytesToString())
             # see SPARK-22112
             # There aren't any jvm api for creating a dataframe from rdd storing csv.
@@ -915,12 +920,20 @@ def func(iterator):
                 jrdd.rdd(), self._spark._jvm.Encoders.STRING()
             )
             return self._df(self._jreader.csv(jdataset))
+
+        from pyspark.sql.dataframe import DataFrame
+
+        if isinstance(path, DataFrame):
+            assert self._spark._jvm is not None
+            return self._df(
+                self._spark._jvm.PythonSQLUtils.csvFromDataFrame(self._jreader, path._jdf)
+            )
         else:
             raise PySparkTypeError(
                 errorClass="NOT_EXPECTED_TYPE",
                 messageParameters={
                     "arg_name": "path",
-                    "expected_type": "str or list[RDD]",
+                    "expected_type": "str, list, RDD, or DataFrame",
                     "arg_type": type(path).__name__,
                 },
             )
diff --git a/python/pyspark/sql/tests/connect/test_connect_readwriter.py b/python/pyspark/sql/tests/connect/test_connect_readwriter.py
index 161f6305be8c..58a92fe9f017 100644
--- a/python/pyspark/sql/tests/connect/test_connect_readwriter.py
+++ b/python/pyspark/sql/tests/connect/test_connect_readwriter.py
@@ -213,6 +213,42 @@ def test_json_with_dataframe_input_zero_columns(self):
         with self.assertRaisesRegex(Exception, "DATAFRAME_INPUT_NOT_SINGLE_COLUMN"):
             self.connect.read.json(empty_schema_df).collect()
 
+    def test_csv_with_dataframe_input(self):
+        csv_df = self.connect.createDataFrame(
+            [("Alice,25",), ("Bob,30",)],
+            schema="value STRING",
+        )
+        result = self.connect.read.csv(csv_df)
+        expected = [Row(_c0="Alice", _c1="25"), Row(_c0="Bob", _c1="30")]
+        self.assertEqual(sorted(result.collect(), key=lambda r: r._c0), expected)
+
+    def test_csv_with_dataframe_input_and_schema(self):
+        csv_df = self.connect.createDataFrame(
+            [("Alice,25",), ("Bob,30",)],
+            schema="value STRING",
+        )
+        result = self.connect.read.csv(csv_df, schema="name STRING, age INT")
+        expected = [Row(name="Alice", age=25), Row(name="Bob", age=30)]
+        self.assertEqual(sorted(result.collect(), key=lambda r: r.name), expected)
+
+    def test_csv_with_dataframe_input_non_string_column(self):
+        int_df = self.connect.createDataFrame([(1,), (2,)], schema="value INT")
+        with self.assertRaisesRegex(Exception, "DATAFRAME_INPUT_NOT_STRING_TYPE"):
+            self.connect.read.csv(int_df).collect()
+
+    def test_csv_with_dataframe_input_multiple_columns(self):
+        multi_df = self.connect.createDataFrame(
+            [("Alice,25", "extra"), ("Bob,30", "extra")],
+            schema="value STRING, other STRING",
+        )
+        with self.assertRaisesRegex(Exception, "DATAFRAME_INPUT_NOT_SINGLE_COLUMN"):
+            self.connect.read.csv(multi_df).collect()
+
+    def test_csv_with_dataframe_input_zero_columns(self):
+        empty_schema_df = self.connect.range(1).select()
+        with self.assertRaisesRegex(Exception, "DATAFRAME_INPUT_NOT_SINGLE_COLUMN"):
+            self.connect.read.csv(empty_schema_df).collect()
+
     def test_multi_paths(self):
         # SPARK-42041: DataFrameReader should support list of paths
diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py b/python/pyspark/sql/tests/test_connect_compatibility.py
index 94ebd958843b..bd49b1f46548 100644
--- a/python/pyspark/sql/tests/test_connect_compatibility.py
+++ b/python/pyspark/sql/tests/test_connect_compatibility.py
@@ -100,7 +100,7 @@ def compare_method_signatures(self, classic_cls, connect_cls, cls_name):
             connect_signature = inspect.signature(connect_methods[method])
 
             # Cannot support RDD arguments from Spark Connect
-            has_rdd_arguments = ("createDataFrame", "xml", "json", "toJSON")
+            has_rdd_arguments = ("createDataFrame", "xml", "json", "csv", "toJSON")
             if method not in has_rdd_arguments:
                 self.assertEqual(
                     classic_signature,
diff --git a/python/pyspark/sql/tests/test_datasources.py b/python/pyspark/sql/tests/test_datasources.py
index e4eb5900cc85..720d02c6bac6 100644
--- a/python/pyspark/sql/tests/test_datasources.py
+++ b/python/pyspark/sql/tests/test_datasources.py
@@ -152,6 +152,42 @@ def test_ignorewhitespace_csv(self):
         self.assertEqual(readback.collect(), expected)
         shutil.rmtree(tmpPath)
 
+    def test_csv_with_dataframe_input(self):
+        csv_df = self.spark.createDataFrame(
+            [("Alice,25",), ("Bob,30",)],
+            schema="value STRING",
+        )
+        result = self.spark.read.csv(csv_df)
+        expected = [Row(_c0="Alice", _c1="25"), Row(_c0="Bob", _c1="30")]
+        self.assertEqual(sorted(result.collect(), key=lambda r: r._c0), expected)
+
+    def test_csv_with_dataframe_input_and_schema(self):
+        csv_df = self.spark.createDataFrame(
+            [("Alice,25",), ("Bob,30",)],
+            schema="value STRING",
+        )
+        result = self.spark.read.csv(csv_df, schema="name STRING, age INT")
+        expected = [Row(name="Alice", age=25), Row(name="Bob", age=30)]
+        self.assertEqual(sorted(result.collect(), key=lambda r: r.name), expected)
+
+    def test_csv_with_dataframe_input_non_string_column(self):
+        int_df = self.spark.createDataFrame([(1,), (2,)], schema="value INT")
+        with self.assertRaisesRegex(Exception, "DATAFRAME_INPUT_NOT_STRING_TYPE"):
+            self.spark.read.csv(int_df).collect()
+
+    def test_csv_with_dataframe_input_multiple_columns(self):
+        multi_df = self.spark.createDataFrame(
+            [("Alice,25", "extra"), ("Bob,30", "extra")],
+            schema="value STRING, other STRING",
+        )
+        with self.assertRaisesRegex(Exception, "DATAFRAME_INPUT_NOT_SINGLE_COLUMN"):
+            self.spark.read.csv(multi_df).collect()
+
+    def test_csv_with_dataframe_input_zero_columns(self):
+        empty_schema_df = self.spark.range(1).select()
+        with self.assertRaisesRegex(Exception, "DATAFRAME_INPUT_NOT_SINGLE_COLUMN"):
+            self.spark.read.csv(empty_schema_df).collect()
+
     def test_xml(self):
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index bc7678c291a2..93847f11f61a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -27,7 +27,7 @@ import org.apache.spark.api.python.DechunkedInputStream
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.LogKeys.CLASS_LOADER
 import org.apache.spark.security.SocketAuthServer
-import org.apache.spark.sql.{internal, Column, DataFrame, DataFrameReader, Encoders, Row, SparkSession, TableArg}
+import org.apache.spark.sql.{internal, Column, DataFrame, DataFrameReader, Dataset, Encoders, Row, SparkSession, TableArg}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TableFunctionRegistry}
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -196,14 +196,10 @@ private[sql] object PythonSQLUtils extends Logging {
   def internalFn(name: String, inputs: Column*): Column = Column.internalFn(name, inputs: _*)
 
   /**
-   * Parses a [[DataFrame]] containing JSON strings into a structured [[DataFrame]].
-   * The input DataFrame must have exactly one column of StringType.
-   * This is used by PySpark to avoid manual Dataset[String] conversion on the Python side.
+   * Validates that the input [[DataFrame]] has exactly one column of StringType
+   * and converts it to a Dataset[String].
    */
-  def jsonFromDataFrame(
-      reader: DataFrameReader,
-      df: DataFrame): DataFrame = {
-    val classicReader = reader.asInstanceOf[ClassicDataFrameReader]
+  private def toStringDataset(df: DataFrame): Dataset[String] = {
     val fields = df.schema.fields
     if (fields.length != 1) {
       throw QueryCompilationErrors.dataframeInputNotSingleColumnError(fields.length)
@@ -211,7 +207,19 @@ private[sql] object PythonSQLUtils extends Logging {
     if (fields.head.dataType != org.apache.spark.sql.types.StringType) {
       throw QueryCompilationErrors.dataframeInputNotStringTypeError(fields.head.dataType)
     }
-    classicReader.json(df.as(Encoders.STRING))
+    df.as(Encoders.STRING)
+  }
+
+  def jsonFromDataFrame(
+      reader: DataFrameReader,
+      df: DataFrame): DataFrame = {
+    reader.asInstanceOf[ClassicDataFrameReader].json(toStringDataset(df))
+  }
+
+  def csvFromDataFrame(
+      reader: DataFrameReader,
+      df: DataFrame): DataFrame = {
+    reader.asInstanceOf[ClassicDataFrameReader].csv(toStringDataset(df))
   }
 
   def cleanupPythonWorkerLogs(sessionUUID: String, sparkContext: SparkContext): Unit = {
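
Usage sketch (illustrative, not part of the patch): the snippet below shows how the new DataFrame-input overload is exercised, mirroring the tests added above. It assumes an active session; the `_c0`/`_c1` names are Spark's defaults when no schema is supplied.

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # A DataFrame with a single STRING column holding raw CSV rows.
    csv_df = spark.createDataFrame([("Alice,25",), ("Bob,30",)], schema="value STRING")

    # Without a schema, parsed columns are auto-named _c0, _c1, ...
    spark.read.csv(csv_df).show()

    # With an explicit DDL schema, rows are parsed into typed columns.
    spark.read.csv(csv_df, schema="name STRING, age INT").show()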