Skip to content

Commit 36fe55b

Browse files
authored
LET (from FROM clauses) implementation (#303)
1 parent 24b2cf6 commit 36fe55b

20 files changed

+539
-18
lines changed

lang/src/org/partiql/lang/ast/AstSerialization.kt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,11 @@ private class AstSerializerImpl(val astVersion: AstVersion, val ion: IonSystem):
202202
}
203203

204204
private fun IonWriterContext.writeSelect(expr: Select) {
205-
val (setQuantifier, projection, from, where, groupBy, having, limit, _: MetaContainer) = expr
205+
val (setQuantifier, projection, from, fromLet, where, groupBy, having, limit, _: MetaContainer) = expr
206+
207+
if (fromLet != null) {
208+
throw UnsupportedOperationException("LET clause is not supported by the V0 AST")
209+
}
206210

207211
writeSelectProjection(projection, setQuantifier)
208212

lang/src/org/partiql/lang/ast/ExprNodeToStatement.kt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ fun ExprNode.toAstExpr(): PartiqlAst.Expr {
133133
},
134134
project = node.projection.toAstSelectProject(),
135135
from = node.from.toAstFromSource(),
136+
fromLet = node.fromLet?.toAstLetSource(),
136137
where = node.where?.toAstExpr(),
137138
group = node.groupBy?.toAstGroupSpec(),
138139
having = node.having?.toAstExpr(),
@@ -258,6 +259,17 @@ private fun FromSource.toAstFromSource(): PartiqlAst.FromSource {
258259
}
259260
}
260261

262+
private fun LetSource.toAstLetSource(): PartiqlAst.Let {
263+
val thiz = this
264+
return PartiqlAst.build {
265+
let(
266+
thiz.bindings.map {
267+
letBinding(it.expr.toAstExpr(), it.name.name)
268+
}
269+
)
270+
}
271+
}
272+
261273
private fun PathComponent.toAstPathStep(): PartiqlAst.PathStep {
262274
val thiz = this
263275
return PartiqlAst.build {

lang/src/org/partiql/lang/ast/StatementToExprNode.kt

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import org.partiql.lang.domains.PartiqlAst.FromSource
1010
import org.partiql.lang.domains.PartiqlAst.GroupBy
1111
import org.partiql.lang.domains.PartiqlAst.GroupingStrategy
1212
import org.partiql.lang.domains.PartiqlAst.JoinType
13+
import org.partiql.lang.domains.PartiqlAst.Let
1314
import org.partiql.lang.domains.PartiqlAst.PathStep
1415
import org.partiql.lang.domains.PartiqlAst.ProjectItem
1516
import org.partiql.lang.domains.PartiqlAst.Projection
@@ -167,6 +168,7 @@ private class StatementTransformer(val ion: IonSystem) {
167168
setQuantifier = setq?.toSetQuantifier() ?: org.partiql.lang.ast.SetQuantifier.ALL,
168169
projection = project.toSelectProjection(),
169170
from = from.toFromSource(),
171+
fromLet = fromLet?.toLetSource(),
170172
where = where?.toExprNode(),
171173
groupBy = group?.toGroupBy(),
172174
having = having?.toExprNode(),
@@ -232,7 +234,18 @@ private class StatementTransformer(val ion: IonSystem) {
232234
is JoinType.Full -> JoinOp.OUTER
233235
}
234236

235-
private fun SymbolPrimitive?.toSymbolicName() = this?.let { SymbolicName(it.text, it.metas.toPartiQlMetaContainer()) }
237+
private fun Let.toLetSource(): LetSource {
238+
return LetSource(
239+
this.letBindings.map {
240+
LetBinding(
241+
it.expr.toExprNode(),
242+
it.name.toSymbolicName()
243+
)
244+
}
245+
)
246+
}
247+
248+
private fun SymbolPrimitive.toSymbolicName() = SymbolicName(this.text, this.metas.toPartiQlMetaContainer())
236249

237250
private fun GroupBy.toGroupBy(): org.partiql.lang.ast.GroupBy =
238251
GroupBy(

lang/src/org/partiql/lang/ast/ast.kt

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,13 +333,14 @@ data class Select(
333333
val setQuantifier: SetQuantifier = SetQuantifier.ALL,
334334
val projection: SelectProjection,
335335
val from: FromSource,
336+
val fromLet: LetSource? = null,
336337
val where: ExprNode? = null,
337338
val groupBy: GroupBy? = null,
338339
val having: ExprNode? = null,
339340
val limit: ExprNode? = null,
340341
override val metas: MetaContainer
341342
) : ExprNode() {
342-
override val children: List<AstNode> = listOfNotNull(projection, from, where, groupBy, having, limit)
343+
override val children: List<AstNode> = listOfNotNull(projection, from, fromLet, where, groupBy, having, limit)
343344
}
344345

345346
//********************************
@@ -633,6 +634,25 @@ data class FromSourceUnpivot(
633634
override val children: List<AstNode> = listOf(expr)
634635
}
635636

637+
//********************************
638+
// LET clause
639+
//********************************
640+
641+
/** Represents a list of LetBindings */
642+
data class LetSource(
643+
val bindings: List<LetBinding>
644+
) : AstNode() {
645+
override val children: List<AstNode> = bindings
646+
}
647+
648+
/** Represents `<expr> AS <name>` */
649+
data class LetBinding(
650+
val expr: ExprNode,
651+
val name: SymbolicName
652+
) : AstNode() {
653+
override val children: List<AstNode> = listOf(expr)
654+
}
655+
636656
/** For `GROUP [ PARTIAL ] BY <item>... [ GROUP AS <gropuName> ]`. */
637657
data class GroupBy(
638658
val grouping: GroupingStrategy,

lang/src/org/partiql/lang/ast/passes/AstRewriterBase.kt

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -140,15 +140,17 @@ open class AstRewriterBase : AstRewriter {
140140
* The traversal order is in the SQL semantic order--that is:
141141
*
142142
* 1. `FROM`
143-
* 2. `WHERE`
144-
* 3. `GROUP BY`
145-
* 4. `HAVING`
146-
* 5. *projection*
147-
* 6. `ORDER BY` (to be implemented)
148-
* 7. `LIMIT`
143+
* 2. `LET`
144+
* 3. `WHERE`
145+
* 4. `GROUP BY`
146+
* 5. `HAVING`
147+
* 6. *projection*
148+
* 7. `ORDER BY` (to be implemented)
149+
* 8. `LIMIT`
149150
*/
150151
protected open fun innerRewriteSelect(selectExpr: Select): Select {
151152
val from = rewriteFromSource(selectExpr.from)
153+
val fromLet = selectExpr.fromLet?.let { rewriteLetSource(it) }
152154
val where = selectExpr.where?.let { rewriteSelectWhere(it) }
153155
val groupBy = selectExpr.groupBy?.let { rewriteGroupBy(it) }
154156
val having = selectExpr.having?.let { rewriteSelectHaving(it) }
@@ -160,6 +162,7 @@ open class AstRewriterBase : AstRewriter {
160162
setQuantifier = selectExpr.setQuantifier,
161163
projection = projection,
162164
from = from,
165+
fromLet = fromLet,
163166
where = where,
164167
groupBy = groupBy,
165168
having = having,
@@ -248,6 +251,12 @@ open class AstRewriterBase : AstRewriter {
248251
variables.atName?.let { rewriteSymbolicName(it) },
249252
variables.byName?.let { rewriteSymbolicName(it) })
250253

254+
open fun rewriteLetSource(letSource: LetSource) =
255+
LetSource(letSource.bindings.map { rewriteLetBinding(it) })
256+
257+
open fun rewriteLetBinding(letBinding: LetBinding): LetBinding =
258+
LetBinding(rewriteExprNode(letBinding.expr), rewriteSymbolicName(letBinding.name))
259+
251260
/**
252261
* This is called by the methods responsible for rewriting instances of the [FromSourceLet]
253262
* to rewrite their expression. This exists to provide a place for derived rewriters to

lang/src/org/partiql/lang/ast/passes/AstSanityValidator.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ object AstSanityValidator {
6262
}
6363
}
6464
is Select -> {
65-
val (_, projection, _, _, groupBy, having, _, metas) = node
65+
val (_, projection, _, _, _, groupBy, having, _, metas) = node
6666

6767
if(groupBy != null) {
6868
if (groupBy.grouping == GroupingStrategy.PARTIAL) {

lang/src/org/partiql/lang/ast/passes/AstWalker.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ open class AstWalker(private val visitor: AstVisitor) {
8787
}
8888
}
8989
is Select -> case {
90-
val (_, projection, from, where, groupBy, having, limit, _: MetaContainer) = expr
90+
val (_, projection, from, fromLet, where, groupBy, having, limit, _: MetaContainer) = expr
9191
walkSelectProjection(projection)
9292
walkFromSource(from)
9393
walkExprNode(where)

lang/src/org/partiql/lang/ast/passes/GroupByPathExpressionRewriter.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ class GroupByPathExpressionRewriter(
9595
// The scope of the expressions in the FROM clause is the same as that of the parent scope.
9696
val from = this.rewriteFromSource(selectExpr.from)
9797

98+
val fromLet = selectExpr.fromLet?.let { unshadowedRewriter.rewriteLetSource(it) }
99+
98100
val where = selectExpr.where?.let { unshadowedRewriter.rewriteSelectWhere(it) }
99101

100102
val groupBy = selectExpr.groupBy?.let { unshadowedRewriter.rewriteGroupBy(it) }
@@ -109,6 +111,7 @@ class GroupByPathExpressionRewriter(
109111
setQuantifier = selectExpr.setQuantifier,
110112
projection = projection,
111113
from = from,
114+
fromLet = fromLet,
112115
where = where,
113116
groupBy = groupBy,
114117
having = having,

lang/src/org/partiql/lang/errors/ErrorCode.kt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,11 @@ enum class ErrorCode(private val category: ErrorCategory,
282282
LOC_TOKEN,
283283
"expected identifier for alias"),
284284

285+
PARSE_EXPECTED_AS_FOR_LET(
286+
ErrorCategory.PARSER,
287+
LOC_TOKEN,
288+
"expected AS for LET clause"),
289+
285290
PARSE_UNSUPPORTED_CALL_WITH_STAR(
286291
ErrorCategory.PARSER,
287292
LOC_TOKEN,

lang/src/org/partiql/lang/eval/EvaluatingCompiler.kt

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -890,27 +890,36 @@ internal class EvaluatingCompiler(
890890
}
891891

892892
private fun compileSelect(selectExpr: Select): ThunkEnv {
893-
// Get all the FROM source aliases for binding error checks
893+
// Get all the FROM source aliases and LET bindings for binding error checks
894894
val fold = object : PartiqlAst.VisitorFold<Set<String>>() {
895895
/** Store all the visited FROM source aliases in the accumulator */
896896
override fun visitFromSourceScan(node: PartiqlAst.FromSource.Scan, accumulator: Set<String>): Set<String> {
897897
val aliases = listOfNotNull(node.asAlias?.text, node.atAlias?.text, node.byAlias?.text)
898898
return accumulator + aliases.toSet()
899899
}
900900

901+
override fun visitLetBinding(node: PartiqlAst.LetBinding, accumulator: Set<String>): Set<String> {
902+
val aliases = listOfNotNull(node.name.text)
903+
return accumulator + aliases
904+
}
905+
901906
/** Prevents visitor from recursing into nested select statements */
902907
override fun walkExprSelect(node: PartiqlAst.Expr.Select, accumulator: Set<String>): Set<String> {
903908
return accumulator
904909
}
905910
}
906911
val pigGeneratedAst = selectExpr.toAstExpr() as PartiqlAst.Expr.Select
907912
val allFromSourceAliases = fold.walkFromSource(pigGeneratedAst.from, emptySet())
913+
.union(pigGeneratedAst.fromLet?.let { fold.walkLet(pigGeneratedAst.fromLet, emptySet()) } ?: emptySet())
908914

909915
return nestCompilationContext(ExpressionContext.NORMAL, emptySet()) {
910-
val (setQuantifier, projection, from, _, groupBy, having, _, metas: MetaContainer) = selectExpr
916+
val (setQuantifier, projection, from, fromLet, _, groupBy, having, _, metas: MetaContainer) = selectExpr
911917

912918
val fromSourceThunks = compileFromSources(from)
913-
val sourceThunks = compileQueryWithoutProjection(selectExpr, fromSourceThunks)
919+
920+
val letSourceThunks = fromLet?.let { compileLetSources(it) }
921+
922+
val sourceThunks = compileQueryWithoutProjection(selectExpr, fromSourceThunks, letSourceThunks)
914923

915924
// Returns a thunk that invokes [sourceThunks], and invokes [projectionThunk] to perform the projection.
916925
fun getQueryThunk(selectProjectionThunk: ThunkEnvValue<List<ExprValue>>): ThunkEnv {
@@ -1358,13 +1367,20 @@ internal class EvaluatingCompiler(
13581367

13591368
return sources
13601369
}
1370+
1371+
private fun compileLetSources(letSource: LetSource): List<CompiledLetSource> =
1372+
letSource.bindings.map {
1373+
CompiledLetSource(name = it.name.name, thunk = compileExprNode(it.expr))
1374+
}
1375+
13611376
/**
13621377
* Compiles the clauses of the SELECT or PIVOT into a thunk that does not generate
13631378
* the final projection.
13641379
*/
13651380
private fun compileQueryWithoutProjection(
13661381
ast: Select,
1367-
compiledSources: List<CompiledFromSource>
1382+
compiledSources: List<CompiledFromSource>,
1383+
compiledLetSources: List<CompiledLetSource>?
13681384
): (Environment) -> Sequence<FromProduction> {
13691385

13701386
val localsBinder = compiledSources.map { it.alias }.localsBinder(valueFactory.missingValue)
@@ -1438,6 +1454,24 @@ internal class EvaluatingCompiler(
14381454
// bind the joined value to the bindings for the filter/project
14391455
FromProduction(joinedValues, fromEnv.nest(localsBinder.bindLocals(joinedValues)))
14401456
}
1457+
// Nest LET bindings in the FROM environment
1458+
if (compiledLetSources != null) {
1459+
seq = seq.map { fromProduction ->
1460+
val parentEnv = fromProduction.env
1461+
1462+
val letEnv: Environment = compiledLetSources.fold(parentEnv) { accEnvironment, curLetSource ->
1463+
val letValue = curLetSource.thunk(accEnvironment)
1464+
val binding = Bindings.over { bindingName ->
1465+
when {
1466+
bindingName.isEquivalentTo(curLetSource.name) -> letValue
1467+
else -> null
1468+
}
1469+
}
1470+
accEnvironment.nest(newLocals = binding)
1471+
}
1472+
fromProduction.copy(env = letEnv)
1473+
}
1474+
}
14411475
if (whereThunk != null) {
14421476
seq = seq.filter { (_, env) ->
14431477
val whereClauseResult = whereThunk(env)
@@ -1931,6 +1965,10 @@ private enum class JoinExpansion {
19311965
OUTER
19321966
}
19331967

1968+
private data class CompiledLetSource(
1969+
val name: String,
1970+
val thunk: ThunkEnv)
1971+
19341972
private enum class ExpressionContext {
19351973
/**
19361974
* Indicates that the compiler is compiling a normal expression (i.e. not one of the other

0 commit comments

Comments
 (0)