diff --git a/src/ast/data_type.rs b/src/ast/data_type.rs index 52919de8a..b731057d0 100644 --- a/src/ast/data_type.rs +++ b/src/ast/data_type.rs @@ -48,7 +48,15 @@ pub enum DataType { /// Table type in [PostgreSQL], e.g. CREATE FUNCTION RETURNS TABLE(...). /// /// [PostgreSQL]: https://www.postgresql.org/docs/15/sql-createfunction.html - Table(Vec), + /// [MsSQL]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql?view=sql-server-ver16#c-create-a-multi-statement-table-valued-function + Table(Option>), + /// Table type with a name, e.g. CREATE FUNCTION RETURNS @result TABLE(...). + NamedTable( + /// Table name. + ObjectName, + /// Table columns. + Vec, + ), /// Fixed-length character type, e.g. CHARACTER(10). Character(Option), /// Fixed-length char type, e.g. CHAR(10). @@ -716,7 +724,17 @@ impl fmt::Display for DataType { DataType::Unspecified => Ok(()), DataType::Trigger => write!(f, "TRIGGER"), DataType::AnyType => write!(f, "ANY TYPE"), - DataType::Table(fields) => write!(f, "TABLE({})", display_comma_separated(fields)), + DataType::Table(fields) => match fields { + Some(fields) => { + write!(f, "TABLE({})", display_comma_separated(fields)) + } + None => { + write!(f, "TABLE") + } + }, + DataType::NamedTable(name, fields) => { + write!(f, "{} TABLE ({})", name, display_comma_separated(fields)) + } DataType::GeometricType(kind) => write!(f, "{}", kind), } } diff --git a/src/ast/ddl.rs b/src/ast/ddl.rs index a457a0655..14f909a87 100644 --- a/src/ast/ddl.rs +++ b/src/ast/ddl.rs @@ -2313,6 +2313,12 @@ impl fmt::Display for CreateFunction { if let Some(CreateFunctionBody::Return(function_body)) = &self.function_body { write!(f, " RETURN {function_body}")?; } + if let Some(CreateFunctionBody::AsReturnSubquery(function_body)) = &self.function_body { + write!(f, " AS RETURN {function_body}")?; + } + if let Some(CreateFunctionBody::AsReturnSelect(function_body)) = &self.function_body { + write!(f, " AS RETURN {function_body}")?; + } if let Some(using) = &self.using { write!(f, " {using}")?; } diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 6b7ba12d9..b58075372 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -8660,6 +8660,28 @@ pub enum CreateFunctionBody { /// /// [PostgreSQL]: https://www.postgresql.org/docs/current/sql-createfunction.html Return(Expr), + + /// Function body expression using the 'AS RETURN' keywords + /// + /// Example: + /// ```sql + /// CREATE FUNCTION myfunc(a INT, b INT) + /// RETURNS TABLE + /// AS RETURN (SELECT a + b AS sum); + /// ``` + /// + /// [MsSql]: https://learn.microsoft.com/en-us/sql/t-sql/statements/create-function-transact-sql + AsReturnSubquery(Expr), + + /// Function body expression using the 'AS RETURN' keywords, with an un-parenthesized SELECT query + /// + /// Example: + /// ```sql + /// CREATE FUNCTION myfunc(a INT, b INT) + /// RETURNS TABLE + /// AS RETURN SELECT a + b AS sum; + /// ``` + AsReturnSelect(Select), } #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] diff --git a/src/parser/mod.rs b/src/parser/mod.rs index d18c7f694..5effa1943 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -5203,19 +5203,79 @@ impl<'a> Parser<'a> { let (name, args) = self.parse_create_function_name_and_params()?; self.expect_keyword(Keyword::RETURNS)?; - let return_type = Some(self.parse_data_type()?); - self.expect_keyword_is(Keyword::AS)?; + let return_table = self.maybe_parse(|p| { + let return_table_name = p.parse_identifier()?; + let table_column_defs = if p.peek_keyword(Keyword::TABLE) { + match p.parse_data_type()? { + DataType::Table(t) => t, + _ => parser_err!( + "Expected table data type after TABLE keyword", + p.peek_token().span.start + )?, + } + } else { + parser_err!( + "Expected TABLE keyword after return type", + p.peek_token().span.start + )? + }; - let begin_token = self.expect_keyword(Keyword::BEGIN)?; - let statements = self.parse_statement_list(&[Keyword::END])?; - let end_token = self.expect_keyword(Keyword::END)?; + if table_column_defs.is_none() + || table_column_defs.clone().is_some_and(|tcd| tcd.is_empty()) + { + parser_err!( + "Expected table column definitions after TABLE keyword", + p.peek_token().span.start + )? + } - let function_body = Some(CreateFunctionBody::AsBeginEnd(BeginEndStatements { - begin_token: AttachedToken(begin_token), - statements, - end_token: AttachedToken(end_token), - })); + Ok(DataType::NamedTable( + ObjectName(vec![ObjectNamePart::Identifier(return_table_name)]), + table_column_defs.clone().unwrap(), + )) + })?; + + let return_type = if return_table.is_some() { + return_table + } else { + Some(self.parse_data_type()?) + }; + + let _ = self.parse_keyword(Keyword::AS); + + let function_body = if self.peek_keyword(Keyword::BEGIN) { + let begin_token = self.expect_keyword(Keyword::BEGIN)?; + let statements = self.parse_statement_list(&[Keyword::END])?; + let end_token = self.expect_keyword(Keyword::END)?; + + Some(CreateFunctionBody::AsBeginEnd(BeginEndStatements { + begin_token: AttachedToken(begin_token), + statements, + end_token: AttachedToken(end_token), + })) + } else if self.parse_keyword(Keyword::RETURN) { + if self.peek_token() == Token::LParen { + let expr = self.parse_expr()?; + if !matches!(expr, Expr::Subquery(_)) { + parser_err!( + "Expected a subquery after RETURN", + self.peek_token().span.start + )? + } + Some(CreateFunctionBody::AsReturnSubquery(expr)) + } else if self.peek_keyword(Keyword::SELECT) { + let select = self.parse_select()?; + Some(CreateFunctionBody::AsReturnSelect(select)) + } else { + parser_err!( + "Expected a subquery (or bare SELECT statement) after RETURN", + self.peek_token().span.start + )? + } + } else { + parser_err!("Unparsable function body", self.peek_token().span.start)? + }; Ok(Statement::CreateFunction(CreateFunction { or_alter, @@ -9766,8 +9826,12 @@ impl<'a> Parser<'a> { Ok(DataType::AnyType) } Keyword::TABLE => { - let columns = self.parse_returns_table_columns()?; - Ok(DataType::Table(columns)) + if self.peek_token() != Token::LParen { + Ok(DataType::Table(None)) + } else { + let columns = self.parse_returns_table_columns()?; + Ok(DataType::Table(Some(columns))) + } } Keyword::SIGNED => { if self.parse_keyword(Keyword::INTEGER) { @@ -9808,13 +9872,7 @@ impl<'a> Parser<'a> { } fn parse_returns_table_column(&mut self) -> Result { - let name = self.parse_identifier()?; - let data_type = self.parse_data_type()?; - Ok(ColumnDef { - name, - data_type, - options: Vec::new(), // No constraints expected here - }) + self.parse_column_def() } fn parse_returns_table_columns(&mut self) -> Result, ParserError> { diff --git a/tests/sqlparser_mssql.rs b/tests/sqlparser_mssql.rs index 1c0a00b16..c1628b165 100644 --- a/tests/sqlparser_mssql.rs +++ b/tests/sqlparser_mssql.rs @@ -254,6 +254,12 @@ fn parse_create_function() { "; let _ = ms().verified_stmt(multi_statement_function); + let multi_statement_function_without_as = multi_statement_function.replace(" AS", ""); + let _ = ms().one_statement_parses_to( + &multi_statement_function_without_as, + multi_statement_function, + ); + let create_function_with_conditional = "\ CREATE FUNCTION some_scalar_udf() \ RETURNS INT \ @@ -288,6 +294,58 @@ fn parse_create_function() { END\ "; let _ = ms().verified_stmt(create_function_with_return_expression); + + let create_inline_table_value_function = "\ + CREATE FUNCTION some_inline_tvf(@foo INT, @bar VARCHAR(256)) \ + RETURNS TABLE \ + AS \ + RETURN (SELECT 1 AS col_1)\ + "; + let _ = ms().verified_stmt(create_inline_table_value_function); + + let create_inline_table_value_function_without_parentheses = "\ + CREATE FUNCTION some_inline_tvf(@foo INT, @bar VARCHAR(256)) \ + RETURNS TABLE \ + AS \ + RETURN SELECT 1 AS col_1\ + "; + let _ = ms().verified_stmt(create_inline_table_value_function_without_parentheses); + + let create_inline_table_value_function_without_as = + create_inline_table_value_function.replace(" AS", ""); + let _ = ms().one_statement_parses_to( + &create_inline_table_value_function_without_as, + create_inline_table_value_function, + ); + + let create_multi_statement_table_value_function = "\ + CREATE FUNCTION some_multi_statement_tvf(@foo INT, @bar VARCHAR(256)) \ + RETURNS @t TABLE (col_1 INT) \ + AS \ + BEGIN \ + INSERT INTO @t SELECT 1; \ + RETURN; \ + END\ + "; + let _ = ms().verified_stmt(create_multi_statement_table_value_function); + + let create_multi_statement_table_value_function_without_as = + create_multi_statement_table_value_function.replace(" AS", ""); + let _ = ms().one_statement_parses_to( + &create_multi_statement_table_value_function_without_as, + create_multi_statement_table_value_function, + ); + + let create_multi_statement_table_value_function_with_constraints = "\ + CREATE FUNCTION some_multi_statement_tvf(@foo INT, @bar VARCHAR(256)) \ + RETURNS @t TABLE (col_1 INT NOT NULL) \ + AS \ + BEGIN \ + INSERT INTO @t SELECT 1; \ + RETURN @t; \ + END\ + "; + let _ = ms().verified_stmt(create_multi_statement_table_value_function_with_constraints); } #[test]