diff --git a/src/ast/ddl.rs b/src/ast/ddl.rs index 0c4f93e64..316cb7b41 100644 --- a/src/ast/ddl.rs +++ b/src/ast/ddl.rs @@ -3466,6 +3466,28 @@ impl fmt::Display for CreateDomain { } } +/// The return type of a `CREATE FUNCTION` statement. +/// +/// [PostgreSQL](https://www.postgresql.org/docs/current/sql-createfunction.html) +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum FunctionReturnType { + /// `RETURNS ` + DataType(DataType), + /// `RETURNS SETOF ` + SetOf(DataType), +} + +impl fmt::Display for FunctionReturnType { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + FunctionReturnType::DataType(data_type) => write!(f, "{data_type}"), + FunctionReturnType::SetOf(data_type) => write!(f, "SETOF {data_type}"), + } + } +} + #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] @@ -3486,7 +3508,7 @@ pub struct CreateFunction { /// List of arguments for the function. pub args: Option>, /// The return type of the function. - pub return_type: Option, + pub return_type: Option, /// The expression that defines the function. /// /// Examples: diff --git a/src/ast/mod.rs b/src/ast/mod.rs index d534b300b..8691cecdd 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -71,15 +71,16 @@ pub use self::ddl::{ CreateIndex, CreateOperator, CreateOperatorClass, CreateOperatorFamily, CreatePolicy, CreatePolicyCommand, CreatePolicyType, CreateTable, CreateTrigger, CreateView, Deduplicate, DeferrableInitial, DropBehavior, DropExtension, DropFunction, DropOperator, DropOperatorClass, - DropOperatorFamily, DropOperatorSignature, DropPolicy, DropTrigger, ForValues, GeneratedAs, - GeneratedExpressionMode, IdentityParameters, IdentityProperty, IdentityPropertyFormatKind, - IdentityPropertyKind, IdentityPropertyOrder, IndexColumn, IndexOption, IndexType, - KeyOrIndexDisplay, Msck, NullsDistinctOption, OperatorArgTypes, OperatorClassItem, - OperatorFamilyDropItem, OperatorFamilyItem, OperatorOption, OperatorPurpose, Owner, Partition, - PartitionBoundValue, ProcedureParam, ReferentialAction, RenameTableNameKind, ReplicaIdentity, - TagsColumnOption, TriggerObjectKind, Truncate, UserDefinedTypeCompositeAttributeDef, - UserDefinedTypeInternalLength, UserDefinedTypeRangeOption, UserDefinedTypeRepresentation, - UserDefinedTypeSqlDefinitionOption, UserDefinedTypeStorage, ViewColumnDef, + DropOperatorFamily, DropOperatorSignature, DropPolicy, DropTrigger, ForValues, + FunctionReturnType, GeneratedAs, GeneratedExpressionMode, IdentityParameters, IdentityProperty, + IdentityPropertyFormatKind, IdentityPropertyKind, IdentityPropertyOrder, IndexColumn, + IndexOption, IndexType, KeyOrIndexDisplay, Msck, NullsDistinctOption, OperatorArgTypes, + OperatorClassItem, OperatorFamilyDropItem, OperatorFamilyItem, OperatorOption, OperatorPurpose, + Owner, Partition, PartitionBoundValue, ProcedureParam, ReferentialAction, RenameTableNameKind, + ReplicaIdentity, TagsColumnOption, TriggerObjectKind, Truncate, + UserDefinedTypeCompositeAttributeDef, UserDefinedTypeInternalLength, + UserDefinedTypeRangeOption, UserDefinedTypeRepresentation, UserDefinedTypeSqlDefinitionOption, + UserDefinedTypeStorage, ViewColumnDef, }; pub use self::dml::{ Delete, Insert, Merge, MergeAction, MergeClause, MergeClauseKind, MergeInsertExpr, diff --git a/src/ast/spans.rs b/src/ast/spans.rs index 74f19e831..07698fa19 100644 --- a/src/ast/spans.rs +++ b/src/ast/spans.rs @@ -597,6 +597,31 @@ impl Spanned for CreateTable { } } +impl Spanned for PartitionBoundValue { + fn span(&self) -> Span { + match self { + PartitionBoundValue::Expr(expr) => expr.span(), + PartitionBoundValue::MinValue => Span::empty(), + PartitionBoundValue::MaxValue => Span::empty(), + } + } +} + +impl Spanned for ForValues { + fn span(&self) -> Span { + match self { + ForValues::In(exprs) => union_spans(exprs.iter().map(|e| e.span())), + ForValues::From { from, to } => union_spans( + from.iter() + .map(|v| v.span()) + .chain(to.iter().map(|v| v.span())), + ), + ForValues::With { .. } => Span::empty(), + ForValues::Default => Span::empty(), + } + } +} + impl Spanned for ColumnDef { fn span(&self) -> Span { let ColumnDef { @@ -632,33 +657,6 @@ impl Spanned for TableConstraint { } } -impl Spanned for PartitionBoundValue { - fn span(&self) -> Span { - match self { - PartitionBoundValue::Expr(expr) => expr.span(), - // MINVALUE and MAXVALUE are keywords without tracked spans - PartitionBoundValue::MinValue => Span::empty(), - PartitionBoundValue::MaxValue => Span::empty(), - } - } -} - -impl Spanned for ForValues { - fn span(&self) -> Span { - match self { - ForValues::In(exprs) => union_spans(exprs.iter().map(|e| e.span())), - ForValues::From { from, to } => union_spans( - from.iter() - .map(|v| v.span()) - .chain(to.iter().map(|v| v.span())), - ), - // WITH (MODULUS n, REMAINDER r) - u64 values have no spans - ForValues::With { .. } => Span::empty(), - ForValues::Default => Span::empty(), - } - } -} - impl Spanned for CreateIndex { fn span(&self) -> Span { let CreateIndex { diff --git a/src/keywords.rs b/src/keywords.rs index cc2b9e9dd..07d02f3b2 100644 --- a/src/keywords.rs +++ b/src/keywords.rs @@ -933,6 +933,7 @@ define_keywords!( SESSION_USER, SET, SETERROR, + SETOF, SETS, SETTINGS, SHARE, diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 6c9314d95..675c1de9e 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -5545,7 +5545,7 @@ impl<'a> Parser<'a> { self.expect_token(&Token::RParen)?; let return_type = if self.parse_keyword(Keyword::RETURNS) { - Some(self.parse_data_type()?) + Some(self.parse_function_return_type()?) } else { None }; @@ -5724,7 +5724,7 @@ impl<'a> Parser<'a> { let (name, args) = self.parse_create_function_name_and_params()?; let return_type = if self.parse_keyword(Keyword::RETURNS) { - Some(self.parse_data_type()?) + Some(self.parse_function_return_type()?) } else { None }; @@ -5827,11 +5827,11 @@ impl<'a> Parser<'a> { }) })?; - let return_type = if return_table.is_some() { - return_table - } else { - Some(self.parse_data_type()?) + let data_type = match return_table { + Some(table_type) => table_type, + None => self.parse_data_type()?, }; + let return_type = Some(FunctionReturnType::DataType(data_type)); let _ = self.parse_keyword(Keyword::AS); @@ -5883,6 +5883,17 @@ impl<'a> Parser<'a> { }) } + /// Parse a [`FunctionReturnType`] after the `RETURNS` keyword. + /// + /// Handles `RETURNS SETOF ` and plain `RETURNS `. + fn parse_function_return_type(&mut self) -> Result { + if self.parse_keyword(Keyword::SETOF) { + Ok(FunctionReturnType::SetOf(self.parse_data_type()?)) + } else { + Ok(FunctionReturnType::DataType(self.parse_data_type()?)) + } + } + fn parse_create_function_name_and_params( &mut self, ) -> Result<(ObjectName, Vec), ParserError> { diff --git a/tests/sqlparser_bigquery.rs b/tests/sqlparser_bigquery.rs index cf843ea2b..28f306f35 100644 --- a/tests/sqlparser_bigquery.rs +++ b/tests/sqlparser_bigquery.rs @@ -2279,7 +2279,7 @@ fn test_bigquery_create_function() { Ident::new("myfunction"), ]), args: Some(vec![OperateFunctionArg::with_name("x", DataType::Float64),]), - return_type: Some(DataType::Float64), + return_type: Some(FunctionReturnType::DataType(DataType::Float64)), function_body: Some(CreateFunctionBody::AsAfterOptions(Expr::Value( number("42").with_empty_span() ))), diff --git a/tests/sqlparser_mssql.rs b/tests/sqlparser_mssql.rs index d7d11ba66..72b60b511 100644 --- a/tests/sqlparser_mssql.rs +++ b/tests/sqlparser_mssql.rs @@ -255,7 +255,7 @@ fn parse_create_function() { default_expr: None, }, ]), - return_type: Some(DataType::Int(None)), + return_type: Some(FunctionReturnType::DataType(DataType::Int(None))), function_body: Some(CreateFunctionBody::AsBeginEnd(BeginEndStatements { begin_token: AttachedToken::empty(), statements: vec![Statement::Return(ReturnStatement { @@ -430,7 +430,7 @@ fn parse_create_function_parameter_default_values() { data_type: DataType::Int(None), default_expr: Some(Expr::Value((number("42")).with_empty_span())), },]), - return_type: Some(DataType::Int(None)), + return_type: Some(FunctionReturnType::DataType(DataType::Int(None))), function_body: Some(CreateFunctionBody::AsBeginEnd(BeginEndStatements { begin_token: AttachedToken::empty(), statements: vec![Statement::Return(ReturnStatement { diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index d79e2b833..338eb451e 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -4346,7 +4346,7 @@ $$"#; DataType::Varchar(None), ), ]), - return_type: Some(DataType::Boolean), + return_type: Some(FunctionReturnType::DataType(DataType::Boolean)), language: Some("plpgsql".into()), behavior: None, called_on_null: None, @@ -4389,7 +4389,7 @@ $$"#; DataType::Int(None) ) ]), - return_type: Some(DataType::Boolean), + return_type: Some(FunctionReturnType::DataType(DataType::Boolean)), language: Some("plpgsql".into()), behavior: None, called_on_null: None, @@ -4436,7 +4436,7 @@ $$"#; DataType::Int(None) ), ]), - return_type: Some(DataType::Boolean), + return_type: Some(FunctionReturnType::DataType(DataType::Boolean)), language: Some("plpgsql".into()), behavior: None, called_on_null: None, @@ -4483,7 +4483,7 @@ $$"#; DataType::Int(None) ), ]), - return_type: Some(DataType::Boolean), + return_type: Some(FunctionReturnType::DataType(DataType::Boolean)), language: Some("plpgsql".into()), behavior: None, called_on_null: None, @@ -4523,7 +4523,7 @@ $$"#; ), OperateFunctionArg::with_name("b", DataType::Varchar(None)), ]), - return_type: Some(DataType::Boolean), + return_type: Some(FunctionReturnType::DataType(DataType::Boolean)), language: Some("plpgsql".into()), behavior: None, called_on_null: None, @@ -4566,7 +4566,7 @@ fn parse_create_function() { OperateFunctionArg::unnamed(DataType::Integer(None)), OperateFunctionArg::unnamed(DataType::Integer(None)), ]), - return_type: Some(DataType::Integer(None)), + return_type: Some(FunctionReturnType::DataType(DataType::Integer(None))), language: Some("SQL".into()), behavior: Some(FunctionBehavior::Immutable), called_on_null: Some(FunctionCalledOnNull::Strict), @@ -4603,6 +4603,30 @@ fn parse_create_function_detailed() { ); } +#[test] +fn parse_create_function_returns_setof() { + pg_and_generic().verified_stmt( + "CREATE FUNCTION get_users() RETURNS SETOF TEXT LANGUAGE sql AS 'SELECT name FROM users'", + ); + pg_and_generic().verified_stmt( + "CREATE FUNCTION get_ids() RETURNS SETOF INTEGER LANGUAGE sql AS 'SELECT id FROM users'", + ); + pg_and_generic().verified_stmt( + r#"CREATE FUNCTION get_all() RETURNS SETOF my_schema."MyType" LANGUAGE sql AS 'SELECT * FROM t'"#, + ); + pg_and_generic().verified_stmt( + "CREATE FUNCTION get_rows() RETURNS SETOF RECORD LANGUAGE sql AS 'SELECT * FROM t'", + ); + + let sql = "CREATE FUNCTION get_names() RETURNS SETOF TEXT LANGUAGE sql AS 'SELECT name FROM t'"; + match pg_and_generic().verified_stmt(sql) { + Statement::CreateFunction(CreateFunction { return_type, .. }) => { + assert_eq!(return_type, Some(FunctionReturnType::SetOf(DataType::Text))); + } + _ => panic!("Expected CreateFunction"), + } +} + #[test] fn parse_create_function_with_security() { let sql = @@ -4678,10 +4702,10 @@ fn parse_create_function_c_with_module_pathname() { "input", DataType::Custom(ObjectName::from(vec![Ident::new("cstring")]), vec![]), ),]), - return_type: Some(DataType::Custom( + return_type: Some(FunctionReturnType::DataType(DataType::Custom( ObjectName::from(vec![Ident::new("cas")]), vec![] - )), + ))), language: Some("c".into()), behavior: Some(FunctionBehavior::Immutable), called_on_null: None, @@ -6375,7 +6399,7 @@ fn parse_trigger_related_functions() { if_not_exists: false, name: ObjectName::from(vec![Ident::new("emp_stamp")]), args: Some(vec![]), - return_type: Some(DataType::Trigger), + return_type: Some(FunctionReturnType::DataType(DataType::Trigger)), function_body: Some( CreateFunctionBody::AsBeforeOptions { body: Expr::Value((