diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 index e402067926f2a..fa0ad28651115 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 @@ -518,7 +518,7 @@ ZONE: 'ZONE'; EQ : '=' | '=='; NSEQ: '<=>'; -NEQ : '<>'; +NEQ : '<>' {complex_type_level_counter == 0}?; NEQJ: '!='; LT : '<'; LTE : '<=' | '!>'; diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 698afa4860027..4cfd60816f20a 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -1377,7 +1377,7 @@ primitiveType dataType : complex=ARRAY (LT dataType GT)? #complexDataType | complex=MAP (LT dataType COMMA dataType GT)? #complexDataType - | complex=STRUCT ((LT complexColTypeList? GT) | NEQ)? #complexDataType + | complex=STRUCT (LT complexColTypeList? GT)? #complexDataType | primitiveType #primitiveDataType ; diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala index beb7061a841a8..b2b1bced39cad 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala @@ -180,7 +180,7 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] { * Create a complex DataType. Arrays, Maps and Structures are supported. 
*/ override def visitComplexDataType(ctx: ComplexDataTypeContext): DataType = withOrigin(ctx) { - if (ctx.LT() == null && ctx.NEQ() == null) { + if (ctx.LT() == null) { throw QueryParsingErrors.nestedTypeMissingElementTypeError(ctx.getText, ctx) } ctx.complex.getType match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index 94e60db67ac75..e4019b0a723ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.{SparkConf, SparkThrowable} import org.apache.spark.internal.config.ConfigEntry import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedHaving, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Concat, GreaterThan, Literal, NullsFirst, SortOrder, UnresolvedWindowExpression, UnspecifiedFrame, WindowSpecDefinition, WindowSpecReference} +import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Cast, Concat, GreaterThan, Literal, NamedExpression, NullsFirst, ShiftRight, SortOrder, UnresolvedWindowExpression, UnspecifiedFrame, WindowSpecDefinition, WindowSpecReference} import org.apache.spark.sql.catalyst.parser.{AbstractParser, ParseException} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.TreePattern._ @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.{CreateTempViewUsing, RefreshResource} import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} import org.apache.spark.sql.test.SharedSparkSession 
-import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.{DataType, IntegerType, NullType, StringType} import org.apache.spark.util.ArrayImplicits._ /** @@ -1164,4 +1164,75 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { } } } + + test("SPARK-52709: Parsing STRUCT (empty,nested,within complex types) followed by shiftRight") { + + // Test valid complex data types, and their combinations. + val typeStringsToTest = Seq( + "STRUCT<>", // Empty struct + "STRUCT<a: STRUCT<>>", // Nested struct + "STRUCT<a: ARRAY<INT>>", // Struct containing an array + "MAP<INT, STRUCT<>>", // Map containing a struct + "ARRAY<STRUCT<>>", // Array containing empty structs + "ARRAY<STRUCT<a: INT>>" // Array containing non-empty structs + ) + + /** + * Helper function to generate a SQL CAST fragment and its corresponding + * expected expression for a given type string. + */ + def createCastNullAsTypeExpression(typeString: String): (String, NamedExpression) = { + // Use the suite's 'parser' instance to parse the DataType + val dataType: DataType = parser.parseDataType(typeString) + val castExpr = Cast(Literal(null, NullType), dataType) + val expectedExpr = UnresolvedAlias(castExpr) // SparkSqlParserSuite expects UnresolvedAlias + val sqlFragment = s"CAST(null AS $typeString)" + (sqlFragment, expectedExpr) + } + + // Generate the SQL fragments and their corresponding expected expressions for all CASTs + val castExpressionsData = typeStringsToTest.map(createCastNullAsTypeExpression) + + // Extract just the SQL fragments for the SELECT statement + val selectClauses = castExpressionsData.map(_._1) + + val sql = + s""" + |SELECT + | ${selectClauses.mkString(",\n ")}, + | 4 >> 1 + """.stripMargin + + // Construct the list of ALL expected expressions for the Project node. + // This includes all the CAST expressions generated above, plus the ShiftRight expression. 
+ val allExpectedExprs = castExpressionsData.map(_._2) :+ + UnresolvedAlias(ShiftRight(Literal(4, IntegerType), Literal(1, IntegerType))) + + // Define the expected logical plan + val expectedPlan = Project( + allExpectedExprs, + OneRowRelation() + ) + + assertEqual(sql, expectedPlan) + } + + test("SPARK-52709-Invalid: Parsing should fail for empty ARRAY<> type") { + val sql = "SELECT CAST(null AS ARRAY<>)" + checkError( + exception = parseException(sql), + condition = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'<'", "hint" -> ": missing ')'") + ) + } + + test("SPARK-52709-Invalid: Parsing should fail for empty MAP<> type") { + val sql = "SELECT CAST(null AS MAP<>)" + checkError( + exception = parseException(sql), + condition = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'<'", "hint" -> ": missing ')'") + ) + } } +