18 changes: 17 additions & 1 deletion python/pyspark/sql/connect/readwriter.py
@@ -300,7 +300,7 @@ def text(

def csv(
self,
path: PathOrPaths,
path: Union[PathOrPaths, "DataFrame"],
schema: Optional[Union[StructType, str]] = None,
sep: Optional[str] = None,
encoding: Optional[str] = None,
@@ -371,6 +371,22 @@ def csv(
)
if isinstance(path, str):
path = [path]

from pyspark.sql.connect.dataframe import DataFrame

if isinstance(path, DataFrame):
Member:
I think it's safer to add this DataFrame check first, with a return, to avoid any behaviour changes; e.g., we might support something other than a list.

Contributor Author:
Makes sense; moved the check up.
# Schema must be set explicitly here because the DataFrame path
# bypasses load(), which normally calls self.schema(schema).
if schema is not None:
self.schema(schema)
return self._df(
Parse(
child=path._plan,
format=proto.Parse.ParseFormat.PARSE_FORMAT_CSV,
schema=self._schema,
options=self._options,
)
)
return self.load(path=path, format="csv", schema=schema)

csv.__doc__ = PySparkDataFrameReader.csv.__doc__
25 changes: 19 additions & 6 deletions python/pyspark/sql/readwriter.py
@@ -768,7 +768,7 @@ def text(

def csv(
self,
path: PathOrPaths,
path: Union[str, List[str], "RDD[str]", "DataFrame"],
schema: Optional[Union[StructType, str]] = None,
sep: Optional[str] = None,
encoding: Optional[str] = None,
@@ -814,11 +814,15 @@ def csv(
.. versionchanged:: 3.4.0
Supports Spark Connect.

.. versionchanged:: 4.2.0
Supports DataFrame input.

Parameters
----------
path : str or list
path : str, list, :class:`RDD`, or :class:`DataFrame`
string, or list of strings, for input path(s),
or RDD of Strings storing CSV rows.
or RDD of Strings storing CSV rows,
or a DataFrame with a single string column containing CSV rows.
schema : :class:`pyspark.sql.types.StructType` or str, optional
an optional :class:`pyspark.sql.types.StructType` for the input schema
or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``).
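For illustration, a minimal usage sketch of the new DataFrame-input path described above, mirroring the tests added in this PR (the column name "value" and the session variable "spark" are assumptions for the example):

    # A single string-column DataFrame whose rows are raw CSV lines.
    csv_df = spark.createDataFrame([("Alice,25",), ("Bob,30",)], schema="value STRING")

    # Without an explicit schema, parsed columns default to _c0, _c1, ...
    parsed = spark.read.csv(csv_df)

    # With an explicit schema, names and types are applied while parsing.
    typed = spark.read.csv(csv_df, schema="name STRING, age INT")
    # typed.collect() -> [Row(name="Alice", age=25), Row(name="Bob", age=30)]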
@@ -896,7 +900,7 @@ def csv(

if not is_remote_only() and isinstance(path, RDD):

def func(iterator):
def func(iterator: Iterable) -> Iterable:
for x in iterator:
if not isinstance(x, str):
x = str(x)
@@ -905,7 +909,8 @@ def func(iterator):
yield x

keyed = path.mapPartitions(func)
keyed._bypass_serializer = True
keyed._bypass_serializer = True # type: ignore[attr-defined]
assert self._spark._jvm is not None
jrdd = keyed._jrdd.map(self._spark._jvm.BytesToString())
# See SPARK-22112.
# There is no JVM API for creating a DataFrame from an RDD of CSV strings.
@@ -915,12 +920,20 @@ def func(iterator):
jrdd.rdd(), self._spark._jvm.Encoders.STRING()
)
return self._df(self._jreader.csv(jdataset))

from pyspark.sql.dataframe import DataFrame

if isinstance(path, DataFrame):
assert self._spark._jvm is not None
return self._df(
self._spark._jvm.PythonSQLUtils.csvFromDataFrame(self._jreader, path._jdf)
)
else:
raise PySparkTypeError(
errorClass="NOT_EXPECTED_TYPE",
messageParameters={
"arg_name": "path",
"expected_type": "str or list[RDD]",
"expected_type": "str, list, RDD, or DataFrame",
"arg_type": type(path).__name__,
},
)
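For contrast, a hedged sketch of the pre-existing workaround that this branch supersedes in classic (non-Connect) PySpark: a DataFrame of CSV lines had to be routed through an RDD of strings first, which the new csvFromDataFrame call avoids (variable names here are illustrative):

    # Hypothetical pre-existing workaround: convert the single-column DataFrame
    # to an RDD of strings, then hand that RDD to the reader.
    rdd_of_lines = csv_df.rdd.map(lambda row: row[0])
    parsed = spark.read.csv(rdd_of_lines, schema="name STRING, age INT")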
36 changes: 36 additions & 0 deletions python/pyspark/sql/tests/connect/test_connect_readwriter.py
@@ -213,6 +213,42 @@ def test_json_with_dataframe_input_zero_columns(self):
with self.assertRaisesRegex(Exception, "DATAFRAME_INPUT_NOT_SINGLE_COLUMN"):
self.connect.read.json(empty_schema_df).collect()

def test_csv_with_dataframe_input(self):
csv_df = self.connect.createDataFrame(
[("Alice,25",), ("Bob,30",)],
schema="value STRING",
)
result = self.connect.read.csv(csv_df)
expected = [Row(_c0="Alice", _c1="25"), Row(_c0="Bob", _c1="30")]
self.assertEqual(sorted(result.collect(), key=lambda r: r._c0), expected)

def test_csv_with_dataframe_input_and_schema(self):
csv_df = self.connect.createDataFrame(
[("Alice,25",), ("Bob,30",)],
schema="value STRING",
)
result = self.connect.read.csv(csv_df, schema="name STRING, age INT")
expected = [Row(name="Alice", age=25), Row(name="Bob", age=30)]
self.assertEqual(sorted(result.collect(), key=lambda r: r.name), expected)

def test_csv_with_dataframe_input_non_string_column(self):
int_df = self.connect.createDataFrame([(1,), (2,)], schema="value INT")
with self.assertRaisesRegex(Exception, "DATAFRAME_INPUT_NOT_STRING_TYPE"):
self.connect.read.csv(int_df).collect()

def test_csv_with_dataframe_input_multiple_columns(self):
multi_df = self.connect.createDataFrame(
[("Alice,25", "extra"), ("Bob,30", "extra")],
schema="value STRING, other STRING",
)
with self.assertRaisesRegex(Exception, "DATAFRAME_INPUT_NOT_SINGLE_COLUMN"):
self.connect.read.csv(multi_df).collect()

def test_csv_with_dataframe_input_zero_columns(self):
empty_schema_df = self.connect.range(1).select()
with self.assertRaisesRegex(Exception, "DATAFRAME_INPUT_NOT_SINGLE_COLUMN"):
self.connect.read.csv(empty_schema_df).collect()

def test_multi_paths(self):
# SPARK-42041: DataFrameReader should support list of paths

2 changes: 1 addition & 1 deletion python/pyspark/sql/tests/test_connect_compatibility.py
@@ -100,7 +100,7 @@ def compare_method_signatures(self, classic_cls, connect_cls, cls_name):
connect_signature = inspect.signature(connect_methods[method])

# Cannot support RDD arguments from Spark Connect
has_rdd_arguments = ("createDataFrame", "xml", "json", "toJSON")
has_rdd_arguments = ("createDataFrame", "xml", "json", "csv", "toJSON")
if method not in has_rdd_arguments:
self.assertEqual(
classic_signature,
36 changes: 36 additions & 0 deletions python/pyspark/sql/tests/test_datasources.py
@@ -152,6 +152,42 @@ def test_ignorewhitespace_csv(self):
self.assertEqual(readback.collect(), expected)
shutil.rmtree(tmpPath)

def test_csv_with_dataframe_input(self):
csv_df = self.spark.createDataFrame(
[("Alice,25",), ("Bob,30",)],
schema="value STRING",
)
result = self.spark.read.csv(csv_df)
expected = [Row(_c0="Alice", _c1="25"), Row(_c0="Bob", _c1="30")]
self.assertEqual(sorted(result.collect(), key=lambda r: r._c0), expected)

def test_csv_with_dataframe_input_and_schema(self):
csv_df = self.spark.createDataFrame(
[("Alice,25",), ("Bob,30",)],
schema="value STRING",
)
result = self.spark.read.csv(csv_df, schema="name STRING, age INT")
expected = [Row(name="Alice", age=25), Row(name="Bob", age=30)]
self.assertEqual(sorted(result.collect(), key=lambda r: r.name), expected)

def test_csv_with_dataframe_input_non_string_column(self):
int_df = self.spark.createDataFrame([(1,), (2,)], schema="value INT")
with self.assertRaisesRegex(Exception, "DATAFRAME_INPUT_NOT_STRING_TYPE"):
self.spark.read.csv(int_df).collect()

def test_csv_with_dataframe_input_multiple_columns(self):
multi_df = self.spark.createDataFrame(
[("Alice,25", "extra"), ("Bob,30", "extra")],
schema="value STRING, other STRING",
)
with self.assertRaisesRegex(Exception, "DATAFRAME_INPUT_NOT_SINGLE_COLUMN"):
self.spark.read.csv(multi_df).collect()

def test_csv_with_dataframe_input_zero_columns(self):
empty_schema_df = self.spark.range(1).select()
with self.assertRaisesRegex(Exception, "DATAFRAME_INPUT_NOT_SINGLE_COLUMN"):
self.spark.read.csv(empty_schema_df).collect()

def test_xml(self):
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
@@ -27,7 +27,7 @@ import org.apache.spark.api.python.DechunkedInputStream
import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys.CLASS_LOADER
import org.apache.spark.security.SocketAuthServer
import org.apache.spark.sql.{internal, Column, DataFrame, DataFrameReader, Encoders, Row, SparkSession, TableArg}
import org.apache.spark.sql.{internal, Column, DataFrame, DataFrameReader, Dataset, Encoders, Row, SparkSession, TableArg}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TableFunctionRegistry}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -196,22 +196,30 @@ private[sql] object PythonSQLUtils extends Logging {
def internalFn(name: String, inputs: Column*): Column = Column.internalFn(name, inputs: _*)

/**
* Parses a [[DataFrame]] containing JSON strings into a structured [[DataFrame]].
* The input DataFrame must have exactly one column of StringType.
* This is used by PySpark to avoid manual Dataset[String] conversion on the Python side.
* Validates that the input [[DataFrame]] has exactly one column of StringType
* and converts it to a Dataset[String].
*/
def jsonFromDataFrame(
reader: DataFrameReader,
df: DataFrame): DataFrame = {
val classicReader = reader.asInstanceOf[ClassicDataFrameReader]
private def toStringDataset(df: DataFrame): Dataset[String] = {
val fields = df.schema.fields
if (fields.length != 1) {
throw QueryCompilationErrors.dataframeInputNotSingleColumnError(fields.length)
}
if (fields.head.dataType != org.apache.spark.sql.types.StringType) {
throw QueryCompilationErrors.dataframeInputNotStringTypeError(fields.head.dataType)
}
classicReader.json(df.as(Encoders.STRING))
df.as(Encoders.STRING)
}

def jsonFromDataFrame(
reader: DataFrameReader,
df: DataFrame): DataFrame = {
reader.asInstanceOf[ClassicDataFrameReader].json(toStringDataset(df))
}

def csvFromDataFrame(
reader: DataFrameReader,
df: DataFrame): DataFrame = {
reader.asInstanceOf[ClassicDataFrameReader].csv(toStringDataset(df))
}

def cleanupPythonWorkerLogs(sessionUUID: String, sparkContext: SparkContext): Unit = {