From 36fe55b6457f0dc82c02da6975358c01dd13aa3b Mon Sep 17 00:00:00 2001 From: Alan Cai Date: Tue, 6 Oct 2020 12:14:00 -0400 Subject: [PATCH] LET (from `FROM` clauses) implementation (#303) --- .../org/partiql/lang/ast/AstSerialization.kt | 6 +- .../partiql/lang/ast/ExprNodeToStatement.kt | 12 ++ .../partiql/lang/ast/StatementToExprNode.kt | 15 +- lang/src/org/partiql/lang/ast/ast.kt | 22 ++- .../lang/ast/passes/AstRewriterBase.kt | 21 ++- .../lang/ast/passes/AstSanityValidator.kt | 2 +- .../org/partiql/lang/ast/passes/AstWalker.kt | 2 +- .../passes/GroupByPathExpressionRewriter.kt | 3 + lang/src/org/partiql/lang/errors/ErrorCode.kt | 5 + .../partiql/lang/eval/EvaluatingCompiler.kt | 46 +++++- .../org/partiql/lang/syntax/LexerConstants.kt | 1 + lang/src/org/partiql/lang/syntax/SqlParser.kt | 70 ++++++++ lang/test/org/partiql/lang/ast/AstNodeTest.kt | 7 +- .../partiql/lang/errors/ParserErrorsTest.kt | 22 +++ .../eval/EvaluatingCompilerFromLetTests.kt | 149 ++++++++++++++++++ .../lang/eval/EvaluatorErrorTestCase.kt | 49 ++++++ .../partiql/lang/eval/EvaluatorTestBase.kt | 17 ++ .../org/partiql/lang/syntax/SqlParserTest.kt | 66 ++++++++ .../partiql/lang/syntax/SqlParserTestBase.kt | 16 ++ .../partiql/lang/util/ArgumentsProvider.kt | 26 +++ 20 files changed, 539 insertions(+), 18 deletions(-) create mode 100644 lang/test/org/partiql/lang/eval/EvaluatingCompilerFromLetTests.kt create mode 100644 lang/test/org/partiql/lang/eval/EvaluatorErrorTestCase.kt create mode 100644 lang/test/org/partiql/lang/util/ArgumentsProvider.kt diff --git a/lang/src/org/partiql/lang/ast/AstSerialization.kt b/lang/src/org/partiql/lang/ast/AstSerialization.kt index bd85d71a18..1052eb9baf 100644 --- a/lang/src/org/partiql/lang/ast/AstSerialization.kt +++ b/lang/src/org/partiql/lang/ast/AstSerialization.kt @@ -202,7 +202,11 @@ private class AstSerializerImpl(val astVersion: AstVersion, val ion: IonSystem): } private fun IonWriterContext.writeSelect(expr: Select) { - val (setQuantifier, projection, from, where, groupBy, having, limit, _: MetaContainer) = expr + val (setQuantifier, projection, from, fromLet, where, groupBy, having, limit, _: MetaContainer) = expr + + if (fromLet != null) { + throw UnsupportedOperationException("LET clause is not supported by the V0 AST") + } writeSelectProjection(projection, setQuantifier) diff --git a/lang/src/org/partiql/lang/ast/ExprNodeToStatement.kt b/lang/src/org/partiql/lang/ast/ExprNodeToStatement.kt index 0d2ef0a304..902a51a19b 100644 --- a/lang/src/org/partiql/lang/ast/ExprNodeToStatement.kt +++ b/lang/src/org/partiql/lang/ast/ExprNodeToStatement.kt @@ -133,6 +133,7 @@ fun ExprNode.toAstExpr(): PartiqlAst.Expr { }, project = node.projection.toAstSelectProject(), from = node.from.toAstFromSource(), + fromLet = node.fromLet?.toAstLetSource(), where = node.where?.toAstExpr(), group = node.groupBy?.toAstGroupSpec(), having = node.having?.toAstExpr(), @@ -258,6 +259,17 @@ private fun FromSource.toAstFromSource(): PartiqlAst.FromSource { } } +private fun LetSource.toAstLetSource(): PartiqlAst.Let { + val thiz = this + return PartiqlAst.build { + let( + thiz.bindings.map { + letBinding(it.expr.toAstExpr(), it.name.name) + } + ) + } +} + private fun PathComponent.toAstPathStep(): PartiqlAst.PathStep { val thiz = this return PartiqlAst.build { diff --git a/lang/src/org/partiql/lang/ast/StatementToExprNode.kt b/lang/src/org/partiql/lang/ast/StatementToExprNode.kt index 83c076f96e..d11bc77c8b 100644 --- a/lang/src/org/partiql/lang/ast/StatementToExprNode.kt +++ b/lang/src/org/partiql/lang/ast/StatementToExprNode.kt @@ -10,6 +10,7 @@ import org.partiql.lang.domains.PartiqlAst.FromSource import org.partiql.lang.domains.PartiqlAst.GroupBy import org.partiql.lang.domains.PartiqlAst.GroupingStrategy import org.partiql.lang.domains.PartiqlAst.JoinType +import org.partiql.lang.domains.PartiqlAst.Let import org.partiql.lang.domains.PartiqlAst.PathStep import org.partiql.lang.domains.PartiqlAst.ProjectItem import org.partiql.lang.domains.PartiqlAst.Projection @@ -167,6 +168,7 @@ private class StatementTransformer(val ion: IonSystem) { setQuantifier = setq?.toSetQuantifier() ?: org.partiql.lang.ast.SetQuantifier.ALL, projection = project.toSelectProjection(), from = from.toFromSource(), + fromLet = fromLet?.toLetSource(), where = where?.toExprNode(), groupBy = group?.toGroupBy(), having = having?.toExprNode(), @@ -232,7 +234,18 @@ private class StatementTransformer(val ion: IonSystem) { is JoinType.Full -> JoinOp.OUTER } - private fun SymbolPrimitive?.toSymbolicName() = this?.let { SymbolicName(it.text, it.metas.toPartiQlMetaContainer()) } + private fun Let.toLetSource(): LetSource { + return LetSource( + this.letBindings.map { + LetBinding( + it.expr.toExprNode(), + it.name.toSymbolicName() + ) + } + ) + } + + private fun SymbolPrimitive.toSymbolicName() = SymbolicName(this.text, this.metas.toPartiQlMetaContainer()) private fun GroupBy.toGroupBy(): org.partiql.lang.ast.GroupBy = GroupBy( diff --git a/lang/src/org/partiql/lang/ast/ast.kt b/lang/src/org/partiql/lang/ast/ast.kt index fd4ae53df8..8be4f19948 100644 --- a/lang/src/org/partiql/lang/ast/ast.kt +++ b/lang/src/org/partiql/lang/ast/ast.kt @@ -333,13 +333,14 @@ data class Select( val setQuantifier: SetQuantifier = SetQuantifier.ALL, val projection: SelectProjection, val from: FromSource, + val fromLet: LetSource? = null, val where: ExprNode? = null, val groupBy: GroupBy? = null, val having: ExprNode? = null, val limit: ExprNode? = null, override val metas: MetaContainer ) : ExprNode() { - override val children: List = listOfNotNull(projection, from, where, groupBy, having, limit) + override val children: List = listOfNotNull(projection, from, fromLet, where, groupBy, having, limit) } //******************************** @@ -633,6 +634,25 @@ data class FromSourceUnpivot( override val children: List = listOf(expr) } +//******************************** +// LET clause +//******************************** + +/** Represents a list of LetBindings */ +data class LetSource( + val bindings: List +) : AstNode() { + override val children: List = bindings +} + +/** Represents ` AS ` */ +data class LetBinding( + val expr: ExprNode, + val name: SymbolicName +) : AstNode() { + override val children: List = listOf(expr) +} + /** For `GROUP [ PARTIAL ] BY ... [ GROUP AS ]`. */ data class GroupBy( val grouping: GroupingStrategy, diff --git a/lang/src/org/partiql/lang/ast/passes/AstRewriterBase.kt b/lang/src/org/partiql/lang/ast/passes/AstRewriterBase.kt index 03ffb27c8d..f322988b72 100644 --- a/lang/src/org/partiql/lang/ast/passes/AstRewriterBase.kt +++ b/lang/src/org/partiql/lang/ast/passes/AstRewriterBase.kt @@ -140,15 +140,17 @@ open class AstRewriterBase : AstRewriter { * The traversal order is in the SQL semantic order--that is: * * 1. `FROM` - * 2. `WHERE` - * 3. `GROUP BY` - * 4. `HAVING` - * 5. *projection* - * 6. `ORDER BY` (to be implemented) - * 7. `LIMIT` + * 2. `LET` + * 3. `WHERE` + * 4. `GROUP BY` + * 5. `HAVING` + * 6. *projection* + * 7. `ORDER BY` (to be implemented) + * 8. `LIMIT` */ protected open fun innerRewriteSelect(selectExpr: Select): Select { val from = rewriteFromSource(selectExpr.from) + val fromLet = selectExpr.fromLet?.let { rewriteLetSource(it) } val where = selectExpr.where?.let { rewriteSelectWhere(it) } val groupBy = selectExpr.groupBy?.let { rewriteGroupBy(it) } val having = selectExpr.having?.let { rewriteSelectHaving(it) } @@ -160,6 +162,7 @@ open class AstRewriterBase : AstRewriter { setQuantifier = selectExpr.setQuantifier, projection = projection, from = from, + fromLet = fromLet, where = where, groupBy = groupBy, having = having, @@ -248,6 +251,12 @@ open class AstRewriterBase : AstRewriter { variables.atName?.let { rewriteSymbolicName(it) }, variables.byName?.let { rewriteSymbolicName(it) }) + open fun rewriteLetSource(letSource: LetSource) = + LetSource(letSource.bindings.map { rewriteLetBinding(it) }) + + open fun rewriteLetBinding(letBinding: LetBinding): LetBinding = + LetBinding(rewriteExprNode(letBinding.expr), rewriteSymbolicName(letBinding.name)) + /** * This is called by the methods responsible for rewriting instances of the [FromSourceLet] * to rewrite their expression. This exists to provide a place for derived rewriters to diff --git a/lang/src/org/partiql/lang/ast/passes/AstSanityValidator.kt b/lang/src/org/partiql/lang/ast/passes/AstSanityValidator.kt index 4d2979c97d..0332276fc3 100644 --- a/lang/src/org/partiql/lang/ast/passes/AstSanityValidator.kt +++ b/lang/src/org/partiql/lang/ast/passes/AstSanityValidator.kt @@ -62,7 +62,7 @@ object AstSanityValidator { } } is Select -> { - val (_, projection, _, _, groupBy, having, _, metas) = node + val (_, projection, _, _, _, groupBy, having, _, metas) = node if(groupBy != null) { if (groupBy.grouping == GroupingStrategy.PARTIAL) { diff --git a/lang/src/org/partiql/lang/ast/passes/AstWalker.kt b/lang/src/org/partiql/lang/ast/passes/AstWalker.kt index e5fa180f0d..d57efd20d0 100644 --- a/lang/src/org/partiql/lang/ast/passes/AstWalker.kt +++ b/lang/src/org/partiql/lang/ast/passes/AstWalker.kt @@ -87,7 +87,7 @@ open class AstWalker(private val visitor: AstVisitor) { } } is Select -> case { - val (_, projection, from, where, groupBy, having, limit, _: MetaContainer) = expr + val (_, projection, from, fromLet, where, groupBy, having, limit, _: MetaContainer) = expr walkSelectProjection(projection) walkFromSource(from) walkExprNode(where) diff --git a/lang/src/org/partiql/lang/ast/passes/GroupByPathExpressionRewriter.kt b/lang/src/org/partiql/lang/ast/passes/GroupByPathExpressionRewriter.kt index bf665efcb7..e0fe76a702 100644 --- a/lang/src/org/partiql/lang/ast/passes/GroupByPathExpressionRewriter.kt +++ b/lang/src/org/partiql/lang/ast/passes/GroupByPathExpressionRewriter.kt @@ -95,6 +95,8 @@ class GroupByPathExpressionRewriter( // The scope of the expressions in the FROM clause is the same as that of the parent scope. val from = this.rewriteFromSource(selectExpr.from) + val fromLet = selectExpr.fromLet?.let { unshadowedRewriter.rewriteLetSource(it) } + val where = selectExpr.where?.let { unshadowedRewriter.rewriteSelectWhere(it) } val groupBy = selectExpr.groupBy?.let { unshadowedRewriter.rewriteGroupBy(it) } @@ -109,6 +111,7 @@ class GroupByPathExpressionRewriter( setQuantifier = selectExpr.setQuantifier, projection = projection, from = from, + fromLet = fromLet, where = where, groupBy = groupBy, having = having, diff --git a/lang/src/org/partiql/lang/errors/ErrorCode.kt b/lang/src/org/partiql/lang/errors/ErrorCode.kt index bb6bf33bd4..5ae6a600f7 100644 --- a/lang/src/org/partiql/lang/errors/ErrorCode.kt +++ b/lang/src/org/partiql/lang/errors/ErrorCode.kt @@ -282,6 +282,11 @@ enum class ErrorCode(private val category: ErrorCategory, LOC_TOKEN, "expected identifier for alias"), + PARSE_EXPECTED_AS_FOR_LET( + ErrorCategory.PARSER, + LOC_TOKEN, + "expected AS for LET clause"), + PARSE_UNSUPPORTED_CALL_WITH_STAR( ErrorCategory.PARSER, LOC_TOKEN, diff --git a/lang/src/org/partiql/lang/eval/EvaluatingCompiler.kt b/lang/src/org/partiql/lang/eval/EvaluatingCompiler.kt index 88c8614b4a..5580d6b346 100644 --- a/lang/src/org/partiql/lang/eval/EvaluatingCompiler.kt +++ b/lang/src/org/partiql/lang/eval/EvaluatingCompiler.kt @@ -890,7 +890,7 @@ internal class EvaluatingCompiler( } private fun compileSelect(selectExpr: Select): ThunkEnv { - // Get all the FROM source aliases for binding error checks + // Get all the FROM source aliases and LET bindings for binding error checks val fold = object : PartiqlAst.VisitorFold>() { /** Store all the visited FROM source aliases in the accumulator */ override fun visitFromSourceScan(node: PartiqlAst.FromSource.Scan, accumulator: Set): Set { @@ -898,6 +898,11 @@ internal class EvaluatingCompiler( return accumulator + aliases.toSet() } + override fun visitLetBinding(node: PartiqlAst.LetBinding, accumulator: Set): Set { + val aliases = listOfNotNull(node.name.text) + return accumulator + aliases + } + /** Prevents visitor from recursing into nested select statements */ override fun walkExprSelect(node: PartiqlAst.Expr.Select, accumulator: Set): Set { return accumulator @@ -905,12 +910,16 @@ internal class EvaluatingCompiler( } val pigGeneratedAst = selectExpr.toAstExpr() as PartiqlAst.Expr.Select val allFromSourceAliases = fold.walkFromSource(pigGeneratedAst.from, emptySet()) + .union(pigGeneratedAst.fromLet?.let { fold.walkLet(pigGeneratedAst.fromLet, emptySet()) } ?: emptySet()) return nestCompilationContext(ExpressionContext.NORMAL, emptySet()) { - val (setQuantifier, projection, from, _, groupBy, having, _, metas: MetaContainer) = selectExpr + val (setQuantifier, projection, from, fromLet, _, groupBy, having, _, metas: MetaContainer) = selectExpr val fromSourceThunks = compileFromSources(from) - val sourceThunks = compileQueryWithoutProjection(selectExpr, fromSourceThunks) + + val letSourceThunks = fromLet?.let { compileLetSources(it) } + + val sourceThunks = compileQueryWithoutProjection(selectExpr, fromSourceThunks, letSourceThunks) // Returns a thunk that invokes [sourceThunks], and invokes [projectionThunk] to perform the projection. fun getQueryThunk(selectProjectionThunk: ThunkEnvValue>): ThunkEnv { @@ -1358,13 +1367,20 @@ internal class EvaluatingCompiler( return sources } + + private fun compileLetSources(letSource: LetSource): List = + letSource.bindings.map { + CompiledLetSource(name = it.name.name, thunk = compileExprNode(it.expr)) + } + /** * Compiles the clauses of the SELECT or PIVOT into a thunk that does not generate * the final projection. */ private fun compileQueryWithoutProjection( ast: Select, - compiledSources: List + compiledSources: List, + compiledLetSources: List? ): (Environment) -> Sequence { val localsBinder = compiledSources.map { it.alias }.localsBinder(valueFactory.missingValue) @@ -1438,6 +1454,24 @@ internal class EvaluatingCompiler( // bind the joined value to the bindings for the filter/project FromProduction(joinedValues, fromEnv.nest(localsBinder.bindLocals(joinedValues))) } + // Nest LET bindings in the FROM environment + if (compiledLetSources != null) { + seq = seq.map { fromProduction -> + val parentEnv = fromProduction.env + + val letEnv: Environment = compiledLetSources.fold(parentEnv) { accEnvironment, curLetSource -> + val letValue = curLetSource.thunk(accEnvironment) + val binding = Bindings.over { bindingName -> + when { + bindingName.isEquivalentTo(curLetSource.name) -> letValue + else -> null + } + } + accEnvironment.nest(newLocals = binding) + } + fromProduction.copy(env = letEnv) + } + } if (whereThunk != null) { seq = seq.filter { (_, env) -> val whereClauseResult = whereThunk(env) @@ -1931,6 +1965,10 @@ private enum class JoinExpansion { OUTER } +private data class CompiledLetSource( + val name: String, + val thunk: ThunkEnv) + private enum class ExpressionContext { /** * Indicates that the compiler is compiling a normal expression (i.e. not one of the other diff --git a/lang/src/org/partiql/lang/syntax/LexerConstants.kt b/lang/src/org/partiql/lang/syntax/LexerConstants.kt index 09d4220314..7c6198aaac 100644 --- a/lang/src/org/partiql/lang/syntax/LexerConstants.kt +++ b/lang/src/org/partiql/lang/syntax/LexerConstants.kt @@ -258,6 +258,7 @@ internal val DATE_PART_KEYWORDS: Set = DatePart.values() "tuple", "remove", "index", + "let", // Ion type names diff --git a/lang/src/org/partiql/lang/syntax/SqlParser.kt b/lang/src/org/partiql/lang/syntax/SqlParser.kt index 045a465123..66d7e9dccb 100644 --- a/lang/src/org/partiql/lang/syntax/SqlParser.kt +++ b/lang/src/org/partiql/lang/syntax/SqlParser.kt @@ -62,6 +62,7 @@ class SqlParser(private val ion: IonSystem) : Parser { PROJECT_ALL, // Wildcard, i.e. the * in `SELECT * FROM f` and a.b.c.* in `SELECT a.b.c.* FROM f` PATH_WILDCARD, PATH_UNPIVOT, + LET, SELECT_LIST, SELECT_VALUE, DISTINCT, @@ -553,6 +554,11 @@ class SqlParser(private val ion: IonSystem) : Parser { val fromSource = fromList.children[0].toFromSource() + val fromLet = unconsumedChildren.firstOrNull { it.type == LET }?.let { + unconsumedChildren.remove(it) + it.toLetSource() + } + val whereExpr = unconsumedChildren.firstOrNull { it.type == WHERE }?.let { unconsumedChildren.remove(it) it.children[0].toExprNode() @@ -600,6 +606,7 @@ class SqlParser(private val ion: IonSystem) : Parser { setQuantifier = setQuantifier, projection = projection, from = fromSource, + fromLet = fromLet, where = whereExpr, groupBy = groupBy, having = havingExpr, @@ -724,6 +731,21 @@ class SqlParser(private val ion: IonSystem) : Parser { return head.unwrapAliasesAndUnpivot() } + private fun ParseNode.toLetSource(): LetSource { + val letBindings = this.children.map { it.toLetBinding() } + return LetSource(letBindings) + } + + private fun ParseNode.toLetBinding(): LetBinding { + val (asAliasSymbol, parseNode) = unwrapAsAlias() + if (asAliasSymbol == null) { + this.errMalformedParseTree("Unsupported syntax for ${this.type}") + } + else { + return LetBinding(parseNode.toExprNode(), asAliasSymbol) + } + } + private fun ParseNode.unwrapAliasesAndUnpivot(): FromSource { val (aliases, unwrappedParseNode) = unwrapAliases() return when(unwrappedParseNode.type) { @@ -1540,6 +1562,12 @@ class SqlParser(private val ion: IonSystem) : Parser { } } + if (rem.head?.keywordText == "let") { + val letParseNode = rem.parseLet() + rem = letParseNode.remaining + children.add(letParseNode) + } + parseOptionalSingleExpressionClause(WHERE) if (rem.head?.keywordText == "group") { @@ -1815,6 +1843,48 @@ class SqlParser(private val ion: IonSystem) : Parser { } } + private fun List.parseLet(): ParseNode { + val letClauses = ArrayList() + var rem = this.tail + var child = rem.parseExpression() + rem = child.remaining + + if (rem.head?.type != AS) { + rem.head.err("Expected $AS following $LET expr", PARSE_EXPECTED_AS_FOR_LET) + } + + rem = rem.tail + + if (rem.head?.type?.isIdentifier() != true) { + rem.head.err("Expected identifier for $AS-alias", PARSE_EXPECTED_IDENT_FOR_ALIAS) + } + + var name = rem.head + rem = rem.tail + letClauses.add(ParseNode(AS_ALIAS, name, listOf(child), rem)) + + while (rem.head?.type == COMMA) { + rem = rem.tail + child = rem.parseExpression() + rem = child.remaining + if (rem.head?.type != AS) { + rem.head.err("Expected $AS following $LET expr", PARSE_EXPECTED_AS_FOR_LET) + } + + rem = rem.tail + + if (rem.head?.type?.isIdentifier() != true) { + rem.head.err("Expected identifier for $AS-alias", PARSE_EXPECTED_IDENT_FOR_ALIAS) + } + + name = rem.head + + rem = rem.tail + letClauses.add(ParseNode(AS_ALIAS, name, listOf(child), rem)) + } + return ParseNode(LET, null, letClauses, rem) + } + private fun List.parseListLiteral(): ParseNode = parseArgList( aliasSupportType = NONE, diff --git a/lang/test/org/partiql/lang/ast/AstNodeTest.kt b/lang/test/org/partiql/lang/ast/AstNodeTest.kt index 33ea540c83..e4f1b3fdaa 100644 --- a/lang/test/org/partiql/lang/ast/AstNodeTest.kt +++ b/lang/test/org/partiql/lang/ast/AstNodeTest.kt @@ -275,20 +275,21 @@ class AstNodeTest { val from = FromSourceExpr(literal("2"), LetVariables()) assertEquals(listOf(projection, from), - Select(SetQuantifier.ALL, projection, from, null, null, null, null, emptyMeta).children) + Select(SetQuantifier.ALL, projection, from, null, null, null, null, null, emptyMeta).children) } @Test fun selectWithAllChildren() { val projection = SelectProjectionValue(literal("1")) val from = FromSourceExpr(literal("2"), LetVariables()) + val fromLet = LetSource(emptyList()) val where = literal("3") val groupBy = GroupBy(GroupingStrategy.FULL, listOf()) val having = literal("4") val limit = literal("5") - assertEquals(listOf(projection, from, where, groupBy, having, limit), - Select(SetQuantifier.ALL, projection, from, where, groupBy, having, limit, emptyMeta).children) + assertEquals(listOf(projection, from, fromLet, where, groupBy, having, limit), + Select(SetQuantifier.ALL, projection, from, fromLet, where, groupBy, having, limit, emptyMeta).children) } @Test diff --git a/lang/test/org/partiql/lang/errors/ParserErrorsTest.kt b/lang/test/org/partiql/lang/errors/ParserErrorsTest.kt index 37b643d97f..c4f083ab09 100644 --- a/lang/test/org/partiql/lang/errors/ParserErrorsTest.kt +++ b/lang/test/org/partiql/lang/errors/ParserErrorsTest.kt @@ -286,6 +286,17 @@ class ParserErrorsTest : TestBase() { Property.TOKEN_VALUE to ion.newInt(1))) } + @Test + fun expectedAsForLet() { + checkInputThrowingParserException("SELECT a FROM foo LET bar b", + ErrorCode.PARSE_EXPECTED_AS_FOR_LET, + mapOf( + Property.LINE_NUMBER to 1L, + Property.COLUMN_NUMBER to 27L, + Property.TOKEN_TYPE to TokenType.IDENTIFIER, + Property.TOKEN_VALUE to ion.newSymbol("b"))) + } + @Test fun expectedIdentForAlias() { checkInputThrowingParserException("select a as true from data", @@ -310,6 +321,17 @@ class ParserErrorsTest : TestBase() { } + @Test + fun expectedIdentForAliasLet() { + checkInputThrowingParserException("SELECT a FROM foo LET bar AS", + ErrorCode.PARSE_EXPECTED_IDENT_FOR_ALIAS, + mapOf( + Property.LINE_NUMBER to 1L, + Property.COLUMN_NUMBER to 29L, + Property.TOKEN_TYPE to TokenType.EOF, + Property.TOKEN_VALUE to ion.newSymbol("EOF"))) + } + @Test fun substringMissingLeftParen() { //12345678901234567890123456789 diff --git a/lang/test/org/partiql/lang/eval/EvaluatingCompilerFromLetTests.kt b/lang/test/org/partiql/lang/eval/EvaluatingCompilerFromLetTests.kt new file mode 100644 index 0000000000..cf99bde394 --- /dev/null +++ b/lang/test/org/partiql/lang/eval/EvaluatingCompilerFromLetTests.kt @@ -0,0 +1,149 @@ +package org.partiql.lang.eval + +import org.junit.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ArgumentsSource +import org.partiql.lang.errors.ErrorCode +import org.partiql.lang.errors.Property +import org.partiql.lang.util.ArgumentsProviderBase +import org.partiql.lang.util.to + +class EvaluatingCompilerFromLetTests : EvaluatorTestBase() { + + private val session = mapOf("A" to "[ { id : 1 } ]", + "B" to "[ { id : 100 }, { id : 200 } ]", + "C" to """[ { name: 'foo', region: 'NA' }, + { name: 'foobar', region: 'EU' }, + { name: 'foobarbaz', region: 'NA' } ]""").toSession() + + class ArgsProviderValid : ArgumentsProviderBase() { + override fun getParameters(): List = listOf( + // LET used in WHERE + EvaluatorTestCase( + "SELECT * FROM A LET 1 AS X WHERE X = 1", + """<< {'id': 1} >>"""), + // LET used in SELECT + EvaluatorTestCase( + "SELECT X FROM A LET 1 AS X", + """<< {'X': 1} >>"""), + // LET used in GROUP BY + EvaluatorTestCase( + "SELECT * FROM C LET region AS X GROUP BY X", + """<< {'X': `EU`}, {'X': `NA`} >>"""), + // LET used in projection after GROUP BY + EvaluatorTestCase( + "SELECT foo FROM B LET 100 AS foo GROUP BY B.id, foo", + """<< {'foo': 100}, {'foo': 100} >>"""), + // LET used in HAVING after GROUP BY + EvaluatorTestCase( + "SELECT B.id FROM B LET 100 AS foo GROUP BY B.id, foo HAVING B.id > foo", + """<< {'id': 200} >>"""), + // LET shadowed binding + EvaluatorTestCase( + "SELECT X FROM A LET 1 AS X, 2 AS X", + """<< {'X': 2} >>"""), + // LET shadowing FROM binding + EvaluatorTestCase( + "SELECT * FROM A LET 100 AS A", + """<< {'_1': 100} >>"""), + // LET using other variables + EvaluatorTestCase( + "SELECT X, Y FROM A LET 1 AS X, X + 1 AS Y", + """<< {'X': 1, 'Y': 2} >>"""), + // LET recursive binding + EvaluatorTestCase( + "SELECT X FROM A LET 1 AS X, X AS X", + """<< {'X': 1} >>"""), + // LET calling function + EvaluatorTestCase( + "SELECT X FROM A LET upper('foo') AS X", + """<< {'X': 'FOO'} >>"""), + // LET calling function on each row + EvaluatorTestCase( + "SELECT nameLength FROM C LET char_length(C.name) AS nameLength", + """<< {'nameLength': 3}, {'nameLength': 6}, {'nameLength': 9} >>"""), + // LET calling function with GROUP BY and aggregation + EvaluatorTestCase( + "SELECT C.region, MAX(nameLength) AS maxLen FROM C LET char_length(C.name) AS nameLength GROUP BY C.region", + """<< {'region': `EU`, 'maxLen': 6}, {'region': `NA`, 'maxLen': 9} >>"""), + // LET outer query has correct value + EvaluatorTestCase( + "SELECT X FROM (SELECT VALUE X FROM A LET 1 AS X) LET 2 AS X", + """<< {'X': 2} >>""") + ) + } + + @ParameterizedTest + @ArgumentsSource(ArgsProviderValid::class) + fun validTests(tc: EvaluatorTestCase) = runTestCase(tc, session) + + class ArgsProviderError : ArgumentsProviderBase() { + override fun getParameters(): List = listOf( + // LET unbound variable + EvaluatorErrorTestCase( + "SELECT X FROM A LET Y AS X", + ErrorCode.EVALUATOR_BINDING_DOES_NOT_EXIST, + mapOf( + Property.LINE_NUMBER to 1L, + Property.COLUMN_NUMBER to 21L, + Property.BINDING_NAME to "Y" + ) + ), + // LET binding definition dependent on later binding + EvaluatorErrorTestCase( + "SELECT X FROM A LET 1 AS X, Y AS Z, 3 AS Y", + ErrorCode.EVALUATOR_BINDING_DOES_NOT_EXIST, + mapOf( + Property.LINE_NUMBER to 1L, + Property.COLUMN_NUMBER to 29L, + Property.BINDING_NAME to "Y" + ) + ), + // LET inner query binding not available in outer query + EvaluatorErrorTestCase( + "SELECT X FROM A LET Y AS X", + "SELECT X FROM (SELECT VALUE X FROM A LET 1 AS X)", + ErrorCode.EVALUATOR_BINDING_DOES_NOT_EXIST, + mapOf( + Property.LINE_NUMBER to 1L, + Property.COLUMN_NUMBER to 8L, + Property.BINDING_NAME to "X" + ) + ), + // LET binding in subquery not in outer LET query + EvaluatorErrorTestCase( + "SELECT Z FROM A LET (SELECT 1 FROM A LET 1 AS X) AS Y, X AS Z", + ErrorCode.EVALUATOR_BINDING_DOES_NOT_EXIST, + mapOf( + Property.LINE_NUMBER to 1L, + Property.COLUMN_NUMBER to 56L, + Property.BINDING_NAME to "X" + ) + ), + // LET binding referenced in HAVING not in GROUP BY + EvaluatorErrorTestCase( + "SELECT B.id FROM B LET 100 AS foo GROUP BY B.id HAVING B.id > foo", + ErrorCode.EVALUATOR_VARIABLE_NOT_INCLUDED_IN_GROUP_BY, + mapOf( + Property.LINE_NUMBER to 1L, + Property.COLUMN_NUMBER to 63L, + Property.BINDING_NAME to "foo" + ) + ), + // LET binding referenced in projection not in GROUP BY + EvaluatorErrorTestCase( + "SELECT foo FROM B LET 100 AS foo GROUP BY B.id", + ErrorCode.EVALUATOR_VARIABLE_NOT_INCLUDED_IN_GROUP_BY, + mapOf( + Property.LINE_NUMBER to 1L, + Property.COLUMN_NUMBER to 8L, + Property.BINDING_NAME to "foo" + ) + ) + ) + } + + @ParameterizedTest + @ArgumentsSource(ArgsProviderError::class) + fun errorTests(tc: EvaluatorErrorTestCase) = checkInputThrowingEvaluationException(tc, session) +} diff --git a/lang/test/org/partiql/lang/eval/EvaluatorErrorTestCase.kt b/lang/test/org/partiql/lang/eval/EvaluatorErrorTestCase.kt new file mode 100644 index 0000000000..365380a881 --- /dev/null +++ b/lang/test/org/partiql/lang/eval/EvaluatorErrorTestCase.kt @@ -0,0 +1,49 @@ +package org.partiql.lang.eval + +import org.partiql.lang.errors.ErrorCode +import org.partiql.lang.errors.Property +import kotlin.reflect.KClass + +/** + * Defines a error test case for query evaluation. + */ +data class EvaluatorErrorTestCase( + /** The "group" of the tests--this only appears in the IDE's test runner and can be used to identify where in the + * source code the test is defined. + */ + val groupName: String?, + + /** + * The query to be evaluated. + */ + val sqlUnderTest: String, + + /** + * The [ErrorCode] the query is to throw. + */ + val errorCode: ErrorCode? = null, + + /** + * The error context the query throws is to match this mapping. + */ + val expectErrorContextValues: Map, + + /** + * The Java exception that is equivalent to the thrown Kotlin exception + */ + val cause: KClass? = null) { + + constructor( + input: String, + errorCode: ErrorCode? = null, + expectErrorContextValues: Map, + cause: KClass? = null + ) : this(null, input, errorCode, expectErrorContextValues, cause) + + /** This will show up in the IDE's test runner. */ + override fun toString() : String { + val groupNameString = if (groupName == null) "" else "$groupName" + val causeString = if (cause == null) "" else ": $cause" + return "$groupNameString $sqlUnderTest : $errorCode : $expectErrorContextValues $causeString" + } +} diff --git a/lang/test/org/partiql/lang/eval/EvaluatorTestBase.kt b/lang/test/org/partiql/lang/eval/EvaluatorTestBase.kt index 2c570305fd..173216b62f 100644 --- a/lang/test/org/partiql/lang/eval/EvaluatorTestBase.kt +++ b/lang/test/org/partiql/lang/eval/EvaluatorTestBase.kt @@ -274,6 +274,23 @@ abstract class EvaluatorTestBase : TestBase() { } } + protected fun checkInputThrowingEvaluationException(tc: EvaluatorErrorTestCase, session: EvaluationSession) { + softAssert { + try { + val result = eval(tc.sqlUnderTest, session = session).ionValue; + fail("Expected EvaluationException but there was no Exception. " + + "The unepxected result was: \n${result.toPrettyString()}") + } + catch (e: EvaluationException) { + if (tc.cause != null) assertThat(e).hasRootCauseExactlyInstanceOf(tc.cause.java) + checkErrorAndErrorContext(tc.errorCode, e, tc.expectErrorContextValues) + } + catch (e: Exception) { + fail("Expected EvaluationException but a different exception was thrown:\n\t $e") + } + } + } + protected fun runTestCase(tc: EvaluatorTestCase, session: EvaluationSession) { fun showTestCase() { println("Query under test : ${tc.sqlUnderTest}") diff --git a/lang/test/org/partiql/lang/syntax/SqlParserTest.kt b/lang/test/org/partiql/lang/syntax/SqlParserTest.kt index cc082a77df..8432083240 100644 --- a/lang/test/org/partiql/lang/syntax/SqlParserTest.kt +++ b/lang/test/org/partiql/lang/syntax/SqlParserTest.kt @@ -14,10 +14,14 @@ package org.partiql.lang.syntax +import com.amazon.ionelement.api.ionInt +import com.amazon.ionelement.api.ionString import org.junit.Test import org.partiql.lang.ast.ExprNode import org.partiql.lang.ast.SourceLocationMeta import org.partiql.lang.ast.sourceLocation +import org.partiql.lang.domains.PartiqlAst +import org.partiql.lang.domains.id /** * Originally just meant to test the parser, this class now tests several different things because @@ -3092,4 +3096,66 @@ class SqlParserTest : SqlParserTestBase() { assertEquals(withoutSemicolon, withSemicolon) } + + //**************************************** + // LET clause parsing + //**************************************** + + private val projectX = PartiqlAst.build { projectList(projectExpr(id("x"))) } + + @Test + fun selectFromLetTest() = assertExpression("SELECT x FROM table1 LET 1 AS A") { + select( + project = projectX, + from = scan(id("table1")), + fromLet = let(letBinding(lit(ionInt(1)), "A")) + ) + } + + @Test + fun selectFromLetTwoBindingsTest() = assertExpression("SELECT x FROM table1 LET 1 AS A, 2 AS B") { + select( + project = projectX, + from = scan(id("table1")), + fromLet = let(letBinding(lit(ionInt(1)), "A"), letBinding(lit(ionInt(2)), "B")) + ) + } + + @Test + fun selectFromLetTableBindingTest() = assertExpression("SELECT x FROM table1 LET table1 AS A") { + select( + project = projectX, + from = scan(id("table1")), + fromLet = let(letBinding(id("table1"), "A")) + ) + } + + @Test + fun selectFromLetFunctionBindingTest() = assertExpression("SELECT x FROM table1 LET foo() AS A") { + select( + project = projectX, + from = scan(id("table1")), + fromLet = let(letBinding(call("foo", emptyList()), "A")) + ) + } + + @Test + fun selectFromLetFunctionWithLiteralsTest() = assertExpression( + "SELECT x FROM table1 LET foo(42, 'bar') AS A") { + select( + project = projectX, + from = scan(id("table1")), + fromLet = let(letBinding(call("foo", listOf(lit(ionInt(42)), lit(ionString("bar")))), "A")) + ) + } + + @Test + fun selectFromLetFunctionWithVariablesTest() = assertExpression( + "SELECT x FROM table1 LET foo(table1) AS A") { + select( + project = projectX, + from = scan(id("table1")), + fromLet = let(letBinding(call("foo", listOf(id("table1"))), "A")) + ) + } } diff --git a/lang/test/org/partiql/lang/syntax/SqlParserTestBase.kt b/lang/test/org/partiql/lang/syntax/SqlParserTestBase.kt index a88a27d4bf..d1a0fe9862 100644 --- a/lang/test/org/partiql/lang/syntax/SqlParserTestBase.kt +++ b/lang/test/org/partiql/lang/syntax/SqlParserTestBase.kt @@ -38,6 +38,22 @@ abstract class SqlParserTestBase : TestBase() { protected fun parse(source: String): ExprNode = parser.parseExprNode(source) protected fun parseToAst(source: String): PartiqlAst.Statement = parser.parseAstStatement(source) + protected fun assertExpression( + source: String, + pigBuilder: PartiqlAst.Builder.() -> PartiqlAst.PartiqlAstNode + ) { + val parsedExprNode = parse(source) + + val expectedPartiQlAst = PartiqlAst.build { pigBuilder() }.toIonElement().toString() + // Convert the query to ExprNode + + val partiqlAst = loadIonSexp(expectedPartiQlAst) + partiqlAssert(parsedExprNode, partiqlAst, source) + + pigDomainAssert(parsedExprNode, partiqlAst.toIonElement().asSexp()) + pigExprNodeTransformAsserts(parsedExprNode) + } + protected fun assertExpression( source: String, expectedSexpAstV0String: String, diff --git a/lang/test/org/partiql/lang/util/ArgumentsProvider.kt b/lang/test/org/partiql/lang/util/ArgumentsProvider.kt new file mode 100644 index 0000000000..63d6e91ecc --- /dev/null +++ b/lang/test/org/partiql/lang/util/ArgumentsProvider.kt @@ -0,0 +1,26 @@ +package org.partiql.lang.util + +import org.junit.jupiter.api.extension.ExtensionContext +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.ArgumentsProvider +import java.util.stream.Stream + +/** + * Reduces some of the boilerplate associated with the style of parameterized testing frequently + * utilized in this package. + * + * Since JUnit5 requires `@JvmStatic` on its `@MethodSource` argument factory methods, this requires all + * of the argument lists to reside in the companion object of a test class. This can be annoying since it + * forces the test to be separated from its tests cases. + * + * Classes that derive from this class can be defined near the `@ParameterizedTest` functions instead. + */ +abstract class ArgumentsProviderBase : ArgumentsProvider { + + abstract fun getParameters(): List + + @Throws(Exception::class) + override fun provideArguments(extensionContext: ExtensionContext): Stream? { + return getParameters().map { Arguments.of(it) }.stream() + } +}