Skip to content

Commit

Permalink
LET (from FROM clauses) implementation (#303)
Browse files Browse the repository at this point in the history
  • Loading branch information
alancai98 authored Oct 6, 2020
1 parent 24b2cf6 commit 36fe55b
Show file tree
Hide file tree
Showing 20 changed files with 539 additions and 18 deletions.
6 changes: 5 additions & 1 deletion lang/src/org/partiql/lang/ast/AstSerialization.kt
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,11 @@ private class AstSerializerImpl(val astVersion: AstVersion, val ion: IonSystem):
}

private fun IonWriterContext.writeSelect(expr: Select) {
val (setQuantifier, projection, from, where, groupBy, having, limit, _: MetaContainer) = expr
val (setQuantifier, projection, from, fromLet, where, groupBy, having, limit, _: MetaContainer) = expr

if (fromLet != null) {
throw UnsupportedOperationException("LET clause is not supported by the V0 AST")
}

writeSelectProjection(projection, setQuantifier)

Expand Down
12 changes: 12 additions & 0 deletions lang/src/org/partiql/lang/ast/ExprNodeToStatement.kt
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ fun ExprNode.toAstExpr(): PartiqlAst.Expr {
},
project = node.projection.toAstSelectProject(),
from = node.from.toAstFromSource(),
fromLet = node.fromLet?.toAstLetSource(),
where = node.where?.toAstExpr(),
group = node.groupBy?.toAstGroupSpec(),
having = node.having?.toAstExpr(),
Expand Down Expand Up @@ -258,6 +259,17 @@ private fun FromSource.toAstFromSource(): PartiqlAst.FromSource {
}
}

private fun LetSource.toAstLetSource(): PartiqlAst.Let {
val thiz = this
return PartiqlAst.build {
let(
thiz.bindings.map {
letBinding(it.expr.toAstExpr(), it.name.name)
}
)
}
}

private fun PathComponent.toAstPathStep(): PartiqlAst.PathStep {
val thiz = this
return PartiqlAst.build {
Expand Down
15 changes: 14 additions & 1 deletion lang/src/org/partiql/lang/ast/StatementToExprNode.kt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import org.partiql.lang.domains.PartiqlAst.FromSource
import org.partiql.lang.domains.PartiqlAst.GroupBy
import org.partiql.lang.domains.PartiqlAst.GroupingStrategy
import org.partiql.lang.domains.PartiqlAst.JoinType
import org.partiql.lang.domains.PartiqlAst.Let
import org.partiql.lang.domains.PartiqlAst.PathStep
import org.partiql.lang.domains.PartiqlAst.ProjectItem
import org.partiql.lang.domains.PartiqlAst.Projection
Expand Down Expand Up @@ -167,6 +168,7 @@ private class StatementTransformer(val ion: IonSystem) {
setQuantifier = setq?.toSetQuantifier() ?: org.partiql.lang.ast.SetQuantifier.ALL,
projection = project.toSelectProjection(),
from = from.toFromSource(),
fromLet = fromLet?.toLetSource(),
where = where?.toExprNode(),
groupBy = group?.toGroupBy(),
having = having?.toExprNode(),
Expand Down Expand Up @@ -232,7 +234,18 @@ private class StatementTransformer(val ion: IonSystem) {
is JoinType.Full -> JoinOp.OUTER
}

private fun SymbolPrimitive?.toSymbolicName() = this?.let { SymbolicName(it.text, it.metas.toPartiQlMetaContainer()) }
private fun Let.toLetSource(): LetSource {
return LetSource(
this.letBindings.map {
LetBinding(
it.expr.toExprNode(),
it.name.toSymbolicName()
)
}
)
}

private fun SymbolPrimitive.toSymbolicName() = SymbolicName(this.text, this.metas.toPartiQlMetaContainer())

private fun GroupBy.toGroupBy(): org.partiql.lang.ast.GroupBy =
GroupBy(
Expand Down
22 changes: 21 additions & 1 deletion lang/src/org/partiql/lang/ast/ast.kt
Original file line number Diff line number Diff line change
Expand Up @@ -333,13 +333,14 @@ data class Select(
val setQuantifier: SetQuantifier = SetQuantifier.ALL,
val projection: SelectProjection,
val from: FromSource,
val fromLet: LetSource? = null,
val where: ExprNode? = null,
val groupBy: GroupBy? = null,
val having: ExprNode? = null,
val limit: ExprNode? = null,
override val metas: MetaContainer
) : ExprNode() {
override val children: List<AstNode> = listOfNotNull(projection, from, where, groupBy, having, limit)
override val children: List<AstNode> = listOfNotNull(projection, from, fromLet, where, groupBy, having, limit)
}

//********************************
Expand Down Expand Up @@ -633,6 +634,25 @@ data class FromSourceUnpivot(
override val children: List<AstNode> = listOf(expr)
}

//********************************
// LET clause
//********************************

/** Represents a list of LetBindings */
data class LetSource(
val bindings: List<LetBinding>
) : AstNode() {
override val children: List<AstNode> = bindings
}

/** Represents `<expr> AS <name>` */
data class LetBinding(
val expr: ExprNode,
val name: SymbolicName
) : AstNode() {
override val children: List<AstNode> = listOf(expr)
}

/** For `GROUP [ PARTIAL ] BY <item>... [ GROUP AS <gropuName> ]`. */
data class GroupBy(
val grouping: GroupingStrategy,
Expand Down
21 changes: 15 additions & 6 deletions lang/src/org/partiql/lang/ast/passes/AstRewriterBase.kt
Original file line number Diff line number Diff line change
Expand Up @@ -140,15 +140,17 @@ open class AstRewriterBase : AstRewriter {
* The traversal order is in the SQL semantic order--that is:
*
* 1. `FROM`
* 2. `WHERE`
* 3. `GROUP BY`
* 4. `HAVING`
* 5. *projection*
* 6. `ORDER BY` (to be implemented)
* 7. `LIMIT`
* 2. `LET`
* 3. `WHERE`
* 4. `GROUP BY`
* 5. `HAVING`
* 6. *projection*
* 7. `ORDER BY` (to be implemented)
* 8. `LIMIT`
*/
protected open fun innerRewriteSelect(selectExpr: Select): Select {
val from = rewriteFromSource(selectExpr.from)
val fromLet = selectExpr.fromLet?.let { rewriteLetSource(it) }
val where = selectExpr.where?.let { rewriteSelectWhere(it) }
val groupBy = selectExpr.groupBy?.let { rewriteGroupBy(it) }
val having = selectExpr.having?.let { rewriteSelectHaving(it) }
Expand All @@ -160,6 +162,7 @@ open class AstRewriterBase : AstRewriter {
setQuantifier = selectExpr.setQuantifier,
projection = projection,
from = from,
fromLet = fromLet,
where = where,
groupBy = groupBy,
having = having,
Expand Down Expand Up @@ -248,6 +251,12 @@ open class AstRewriterBase : AstRewriter {
variables.atName?.let { rewriteSymbolicName(it) },
variables.byName?.let { rewriteSymbolicName(it) })

open fun rewriteLetSource(letSource: LetSource) =
LetSource(letSource.bindings.map { rewriteLetBinding(it) })

open fun rewriteLetBinding(letBinding: LetBinding): LetBinding =
LetBinding(rewriteExprNode(letBinding.expr), rewriteSymbolicName(letBinding.name))

/**
* This is called by the methods responsible for rewriting instances of the [FromSourceLet]
* to rewrite their expression. This exists to provide a place for derived rewriters to
Expand Down
2 changes: 1 addition & 1 deletion lang/src/org/partiql/lang/ast/passes/AstSanityValidator.kt
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ object AstSanityValidator {
}
}
is Select -> {
val (_, projection, _, _, groupBy, having, _, metas) = node
val (_, projection, _, _, _, groupBy, having, _, metas) = node

if(groupBy != null) {
if (groupBy.grouping == GroupingStrategy.PARTIAL) {
Expand Down
2 changes: 1 addition & 1 deletion lang/src/org/partiql/lang/ast/passes/AstWalker.kt
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ open class AstWalker(private val visitor: AstVisitor) {
}
}
is Select -> case {
val (_, projection, from, where, groupBy, having, limit, _: MetaContainer) = expr
val (_, projection, from, fromLet, where, groupBy, having, limit, _: MetaContainer) = expr
walkSelectProjection(projection)
walkFromSource(from)
walkExprNode(where)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ class GroupByPathExpressionRewriter(
// The scope of the expressions in the FROM clause is the same as that of the parent scope.
val from = this.rewriteFromSource(selectExpr.from)

val fromLet = selectExpr.fromLet?.let { unshadowedRewriter.rewriteLetSource(it) }

val where = selectExpr.where?.let { unshadowedRewriter.rewriteSelectWhere(it) }

val groupBy = selectExpr.groupBy?.let { unshadowedRewriter.rewriteGroupBy(it) }
Expand All @@ -109,6 +111,7 @@ class GroupByPathExpressionRewriter(
setQuantifier = selectExpr.setQuantifier,
projection = projection,
from = from,
fromLet = fromLet,
where = where,
groupBy = groupBy,
having = having,
Expand Down
5 changes: 5 additions & 0 deletions lang/src/org/partiql/lang/errors/ErrorCode.kt
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,11 @@ enum class ErrorCode(private val category: ErrorCategory,
LOC_TOKEN,
"expected identifier for alias"),

PARSE_EXPECTED_AS_FOR_LET(
ErrorCategory.PARSER,
LOC_TOKEN,
"expected AS for LET clause"),

PARSE_UNSUPPORTED_CALL_WITH_STAR(
ErrorCategory.PARSER,
LOC_TOKEN,
Expand Down
46 changes: 42 additions & 4 deletions lang/src/org/partiql/lang/eval/EvaluatingCompiler.kt
Original file line number Diff line number Diff line change
Expand Up @@ -890,27 +890,36 @@ internal class EvaluatingCompiler(
}

private fun compileSelect(selectExpr: Select): ThunkEnv {
// Get all the FROM source aliases for binding error checks
// Get all the FROM source aliases and LET bindings for binding error checks
val fold = object : PartiqlAst.VisitorFold<Set<String>>() {
/** Store all the visited FROM source aliases in the accumulator */
override fun visitFromSourceScan(node: PartiqlAst.FromSource.Scan, accumulator: Set<String>): Set<String> {
val aliases = listOfNotNull(node.asAlias?.text, node.atAlias?.text, node.byAlias?.text)
return accumulator + aliases.toSet()
}

override fun visitLetBinding(node: PartiqlAst.LetBinding, accumulator: Set<String>): Set<String> {
val aliases = listOfNotNull(node.name.text)
return accumulator + aliases
}

/** Prevents visitor from recursing into nested select statements */
override fun walkExprSelect(node: PartiqlAst.Expr.Select, accumulator: Set<String>): Set<String> {
return accumulator
}
}
val pigGeneratedAst = selectExpr.toAstExpr() as PartiqlAst.Expr.Select
val allFromSourceAliases = fold.walkFromSource(pigGeneratedAst.from, emptySet())
.union(pigGeneratedAst.fromLet?.let { fold.walkLet(pigGeneratedAst.fromLet, emptySet()) } ?: emptySet())

return nestCompilationContext(ExpressionContext.NORMAL, emptySet()) {
val (setQuantifier, projection, from, _, groupBy, having, _, metas: MetaContainer) = selectExpr
val (setQuantifier, projection, from, fromLet, _, groupBy, having, _, metas: MetaContainer) = selectExpr

val fromSourceThunks = compileFromSources(from)
val sourceThunks = compileQueryWithoutProjection(selectExpr, fromSourceThunks)

val letSourceThunks = fromLet?.let { compileLetSources(it) }

val sourceThunks = compileQueryWithoutProjection(selectExpr, fromSourceThunks, letSourceThunks)

// Returns a thunk that invokes [sourceThunks], and invokes [projectionThunk] to perform the projection.
fun getQueryThunk(selectProjectionThunk: ThunkEnvValue<List<ExprValue>>): ThunkEnv {
Expand Down Expand Up @@ -1358,13 +1367,20 @@ internal class EvaluatingCompiler(

return sources
}

private fun compileLetSources(letSource: LetSource): List<CompiledLetSource> =
letSource.bindings.map {
CompiledLetSource(name = it.name.name, thunk = compileExprNode(it.expr))
}

/**
* Compiles the clauses of the SELECT or PIVOT into a thunk that does not generate
* the final projection.
*/
private fun compileQueryWithoutProjection(
ast: Select,
compiledSources: List<CompiledFromSource>
compiledSources: List<CompiledFromSource>,
compiledLetSources: List<CompiledLetSource>?
): (Environment) -> Sequence<FromProduction> {

val localsBinder = compiledSources.map { it.alias }.localsBinder(valueFactory.missingValue)
Expand Down Expand Up @@ -1438,6 +1454,24 @@ internal class EvaluatingCompiler(
// bind the joined value to the bindings for the filter/project
FromProduction(joinedValues, fromEnv.nest(localsBinder.bindLocals(joinedValues)))
}
// Nest LET bindings in the FROM environment
if (compiledLetSources != null) {
seq = seq.map { fromProduction ->
val parentEnv = fromProduction.env

val letEnv: Environment = compiledLetSources.fold(parentEnv) { accEnvironment, curLetSource ->
val letValue = curLetSource.thunk(accEnvironment)
val binding = Bindings.over { bindingName ->
when {
bindingName.isEquivalentTo(curLetSource.name) -> letValue
else -> null
}
}
accEnvironment.nest(newLocals = binding)
}
fromProduction.copy(env = letEnv)
}
}
if (whereThunk != null) {
seq = seq.filter { (_, env) ->
val whereClauseResult = whereThunk(env)
Expand Down Expand Up @@ -1931,6 +1965,10 @@ private enum class JoinExpansion {
OUTER
}

private data class CompiledLetSource(
val name: String,
val thunk: ThunkEnv)

private enum class ExpressionContext {
/**
* Indicates that the compiler is compiling a normal expression (i.e. not one of the other
Expand Down
1 change: 1 addition & 0 deletions lang/src/org/partiql/lang/syntax/LexerConstants.kt
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ internal val DATE_PART_KEYWORDS: Set<String> = DatePart.values()
"tuple",
"remove",
"index",
"let",

// Ion type names

Expand Down
Loading

0 comments on commit 36fe55b

Please sign in to comment.