Skip to content

Commit 6ac85a0

Browse files
authored
Enable filter pushdown in spicepod defined UDTFs (spiceai#11004)
* make string literals for HTTP pushdown * fix subqueries * clippy * lint again
1 parent 635c641 commit 6ac85a0

3 files changed

Lines changed: 401 additions & 2 deletions

File tree

Lines changed: 393 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,393 @@
1+
/*
2+
Copyright 2026 The Spice.ai OSS Authors
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
//! Inline literal table-function arguments into a [`LogicalPlan`].
18+
//!
19+
//! SQL table functions expose their scalar arguments via a one-row `args`
20+
//! `MemTable`. Because the `MemTable` is only populated at execution time,
21+
//! the planner sees column references rather than literals, which prevents
22+
//! filter pushdown for connectors that need concrete values (e.g. the HTTP
23+
//! connector's `request_path`).
24+
//!
25+
//! Since all scalar args are guaranteed to be literals (enforced by
26+
//! [`super::sql::literal_arg`]), this module walks the *unoptimized*
27+
//! [`LogicalPlan`] produced by `ctx.sql(body)` and replaces every
28+
//! `Expr::Column` that references the `args` table with the corresponding
29+
//! `Expr::Literal`. The optimizer then sees constants and can fold /
30+
//! push down as usual.
31+
//!
32+
//! This follows the same pattern as `DataFusion`'s own
33+
//! [`LogicalPlan::replace_params_with_values`], which replaces
34+
//! `Expr::Placeholder` with `Expr::Literal`.
35+
36+
use std::collections::HashMap;
37+
38+
use arrow::datatypes::Schema;
39+
use datafusion::{
40+
common::{
41+
Column, Result as DataFusionResult,
42+
tree_node::{Transformed, TreeNode},
43+
},
44+
logical_expr::{LogicalPlan, Projection, TableScan, expr::Alias},
45+
prelude::Expr,
46+
scalar::ScalarValue,
47+
};
48+
49+
use super::sql::SQL_TABLE_ARGS_TABLE_NAME;
50+
51+
/// Returns `true` if `table_name` refers to the `args` table.
52+
fn is_args_table_ref(table_name: &datafusion::sql::TableReference) -> bool {
53+
table_name
54+
.table()
55+
.eq_ignore_ascii_case(SQL_TABLE_ARGS_TABLE_NAME)
56+
}
57+
58+
/// If `plan` is `Projection([single_expr], TableScan("args"))` and
59+
/// `single_expr` is a literal, return that literal. This detects the
60+
/// pattern left behind after column→literal replacement inside scalar
61+
/// subqueries.
62+
fn try_extract_literal_from_args_subquery(plan: &LogicalPlan) -> Option<Expr> {
63+
if let LogicalPlan::Projection(Projection { expr, input, .. }) = plan
64+
// Single projected expression over a TableScan on `args`.
65+
&& let [expr] = expr.as_slice()
66+
&& let LogicalPlan::TableScan(TableScan { table_name, .. }) = input.as_ref()
67+
&& is_args_table_ref(table_name)
68+
{
69+
let inner = match expr {
70+
Expr::Alias(Alias { expr, .. }) => expr,
71+
other => other,
72+
};
73+
if matches!(inner, Expr::Literal(..)) {
74+
return Some(inner.clone());
75+
}
76+
}
77+
None
78+
}
79+
80+
/// Recursively collapse `Expr::ScalarSubquery` nodes whose inner plan
81+
/// is `Projection([literal], TableScan("args"))`.
82+
///
83+
/// `DataFusion`'s `Expr::transform_up` explicitly skips `ScalarSubquery`
84+
/// children, so we must walk the expression tree manually to find and
85+
/// replace them.
86+
fn collapse_args_subqueries(expr: Expr) -> Expr {
87+
match expr {
88+
Expr::ScalarSubquery(ref subquery) => {
89+
if let Some(literal) = try_extract_literal_from_args_subquery(&subquery.subquery) {
90+
literal
91+
} else {
92+
expr
93+
}
94+
}
95+
// Recurse into expression types that can contain ScalarSubquery.
96+
Expr::BinaryExpr(mut bin) => {
97+
*bin.left = collapse_args_subqueries(*bin.left);
98+
*bin.right = collapse_args_subqueries(*bin.right);
99+
Expr::BinaryExpr(bin)
100+
}
101+
Expr::Not(inner) => Expr::Not(Box::new(collapse_args_subqueries(*inner))),
102+
Expr::IsNotNull(inner) => Expr::IsNotNull(Box::new(collapse_args_subqueries(*inner))),
103+
Expr::IsNull(inner) => Expr::IsNull(Box::new(collapse_args_subqueries(*inner))),
104+
Expr::IsTrue(inner) => Expr::IsTrue(Box::new(collapse_args_subqueries(*inner))),
105+
Expr::IsFalse(inner) => Expr::IsFalse(Box::new(collapse_args_subqueries(*inner))),
106+
Expr::Negative(inner) => Expr::Negative(Box::new(collapse_args_subqueries(*inner))),
107+
Expr::Cast(mut cast) => {
108+
*cast.expr = collapse_args_subqueries(*cast.expr);
109+
Expr::Cast(cast)
110+
}
111+
Expr::TryCast(mut cast) => {
112+
*cast.expr = collapse_args_subqueries(*cast.expr);
113+
Expr::TryCast(cast)
114+
}
115+
Expr::Alias(mut alias) => {
116+
*alias.expr = collapse_args_subqueries(*alias.expr);
117+
Expr::Alias(alias)
118+
}
119+
Expr::ScalarFunction(mut func) => {
120+
func.args = func
121+
.args
122+
.into_iter()
123+
.map(collapse_args_subqueries)
124+
.collect();
125+
Expr::ScalarFunction(func)
126+
}
127+
Expr::Case(mut case) => {
128+
case.expr = case.expr.map(|o| Box::new(collapse_args_subqueries(*o)));
129+
case.when_then_expr = case
130+
.when_then_expr
131+
.into_iter()
132+
.map(|(w, t)| {
133+
(
134+
Box::new(collapse_args_subqueries(*w)),
135+
Box::new(collapse_args_subqueries(*t)),
136+
)
137+
})
138+
.collect();
139+
case.else_expr = case
140+
.else_expr
141+
.map(|e| Box::new(collapse_args_subqueries(*e)));
142+
Expr::Case(case)
143+
}
144+
// For any other expression type, return as-is.
145+
other => other,
146+
}
147+
}
148+
149+
/// Walk an unoptimized [`LogicalPlan`] and replace every `Expr::Column`
150+
/// referencing the `args` table with the corresponding `Expr::Literal`.
151+
///
152+
/// Also collapses `Expr::ScalarSubquery` nodes that, after column
153+
/// replacement, reduce to a single literal projected from `args`.
154+
///
155+
/// This operates on the plan produced *before* optimization, so the
156+
/// optimizer's filter-pushdown passes see concrete literal values instead
157+
/// of column references to a `MemTable`.
158+
pub(super) fn inline_args_into_plan(
159+
plan: LogicalPlan,
160+
schema: &Schema,
161+
values: &[ScalarValue],
162+
) -> DataFusionResult<LogicalPlan> {
163+
if schema.fields().is_empty() {
164+
return Ok(plan);
165+
}
166+
167+
let arg_map: HashMap<String, ScalarValue> = schema
168+
.fields()
169+
.iter()
170+
.zip(values)
171+
.map(|(field, value)| (field.name().to_ascii_lowercase(), value.clone()))
172+
.collect();
173+
174+
// Pass 1: replace `Expr::Column` refs to `args` with literals inside
175+
// all plans (including subquery plans).
176+
let plan = plan
177+
.transform_up_with_subqueries(|plan| {
178+
plan.map_expressions(|expr| {
179+
expr.transform_up(|e| {
180+
if let Expr::Column(Column {
181+
ref relation,
182+
ref name,
183+
..
184+
}) = e
185+
{
186+
let key = name.to_ascii_lowercase();
187+
let should_replace = match relation {
188+
Some(r) => is_args_table_ref(r) && arg_map.contains_key(&key),
189+
None => arg_map.contains_key(&key),
190+
};
191+
if should_replace && let Some(value) = arg_map.get(&key) {
192+
return Ok(Transformed::yes(Expr::Literal(value.clone(), None)));
193+
}
194+
}
195+
Ok(Transformed::no(e))
196+
})
197+
})
198+
})?
199+
.data;
200+
201+
// Pass 2: collapse `Expr::ScalarSubquery` nodes whose inner plan is
202+
// now `Projection([literal], TableScan("args"))`. This turns the
203+
// subquery into a bare literal so the optimizer can push it down.
204+
//
205+
// We use a manual expression walk because DataFusion's
206+
// `Expr::transform_up` explicitly skips `ScalarSubquery` children.
207+
plan.transform_up_with_subqueries(|plan| {
208+
plan.map_expressions(|expr| {
209+
let collapsed = collapse_args_subqueries(expr.clone());
210+
if collapsed == expr {
211+
Ok(Transformed::no(expr))
212+
} else {
213+
Ok(Transformed::yes(collapsed))
214+
}
215+
})
216+
})
217+
.map(|res| res.data)
218+
}
219+
220+
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{DataType, Field as ArrowField};
    use datafusion::datasource::MemTable;
    use datafusion::prelude::SessionContext;
    use std::sync::Arc;

    /// Register an empty `MemTable` called `name` whose columns are all
    /// nullable and typed as listed in `columns`.
    fn register_empty_table(ctx: &SessionContext, name: &str, columns: &[(&str, DataType)]) {
        let schema = Arc::new(Schema::new(
            columns
                .iter()
                .map(|(column, data_type)| ArrowField::new(*column, data_type.clone(), true))
                .collect::<Vec<_>>(),
        ));
        let table = MemTable::try_new(schema, vec![vec![]]).expect("empty memtable");
        ctx.register_table(name, Arc::new(table))
            .expect("register empty table");
    }

    /// Register a one-row `args` `MemTable` and plan the body SQL, returning
    /// the unoptimized plan.
    async fn plan_body(body: &str, schema: &Schema, values: &[ScalarValue]) -> LogicalPlan {
        let ctx = SessionContext::new();
        let schema_ref = Arc::new(schema.clone());

        // Build the one-row args MemTable.
        let arrays: Vec<_> = values
            .iter()
            .map(|v| v.to_array().expect("to_array"))
            .collect();
        let batch = arrow::record_batch::RecordBatch::try_new(Arc::clone(&schema_ref), arrays)
            .expect("batch");
        let table = MemTable::try_new(schema_ref, vec![vec![batch]]).expect("memtable");
        ctx.register_table("args", Arc::new(table))
            .expect("register");

        // Dummy tables the body SQL can reference.
        register_empty_table(
            &ctx,
            "raw_users",
            &[("content", DataType::Utf8), ("request_path", DataType::Utf8)],
        );
        register_empty_table(
            &ctx,
            "t",
            &[
                ("id", DataType::Int64),
                ("name", DataType::Utf8),
                ("active", DataType::Boolean),
                ("col", DataType::Utf8),
            ],
        );

        ctx.sql(body)
            .await
            .expect("plan body")
            .into_unoptimized_plan()
    }

    /// Build a schema of nullable `Utf8` fields with the given names.
    fn utf8_schema(names: &[&str]) -> Schema {
        Schema::new(
            names
                .iter()
                .map(|n| ArrowField::new(*n, DataType::Utf8, true))
                .collect::<Vec<_>>(),
        )
    }

    /// Format the plan to a string for assertion.
    fn plan_str(plan: &LogicalPlan) -> String {
        format!("{plan}")
    }

    /// Plan `body`, inline the arguments, and return the rewritten plan as
    /// a string.
    async fn rewritten_plan(body: &str, schema: &Schema, values: &[ScalarValue]) -> String {
        let plan = plan_body(body, schema, values).await;
        let rewritten = inline_args_into_plan(plan, schema, values).expect("rewrite");
        plan_str(&rewritten)
    }

    #[tokio::test]
    async fn inline_replaces_scalar_subquery() {
        let schema = utf8_schema(&["username"]);
        let values = vec![ScalarValue::Utf8(Some("pg".into()))];
        let body = "SELECT content FROM raw_users WHERE request_path = (SELECT username FROM args)";
        let s = rewritten_plan(body, &schema, &values).await;
        assert!(
            s.contains("Utf8(\"pg\")"),
            "Expected inlined literal in plan: {s}"
        );
        // The literal also shows up inside the subquery's own projection, so
        // additionally require that the subquery itself was collapsed.
        assert!(
            !s.contains("<subquery>"),
            "Expected scalar subquery to be collapsed: {s}"
        );
    }

    #[tokio::test]
    async fn inline_replaces_expression_over_args() {
        let schema = utf8_schema(&["username"]);
        let values = vec![ScalarValue::Utf8(Some("pg".into()))];
        let body = "SELECT content FROM raw_users WHERE request_path = (SELECT concat('/users/', username) FROM args)";
        let s = rewritten_plan(body, &schema, &values).await;
        assert!(
            s.contains("Utf8(\"pg\")"),
            "Expected inlined literal in plan: {s}"
        );
        assert!(s.contains("concat"), "Should still contain concat: {s}");
    }

    #[tokio::test]
    async fn inline_replaces_from_args_direct() {
        // Body uses `FROM args` directly — columns should still be replaced.
        let schema = Schema::new(vec![ArrowField::new("x", DataType::Int64, true)]);
        let values = vec![ScalarValue::Int64(Some(42))];
        let body = "SELECT x AS value, x * 2 AS doubled FROM args";
        let s = rewritten_plan(body, &schema, &values).await;
        assert!(
            s.contains("Int64(42)"),
            "Expected inlined literal 42 in plan: {s}"
        );
    }

    #[tokio::test]
    async fn inline_handles_multiple_args() {
        let schema = Schema::new(vec![
            ArrowField::new("a", DataType::Utf8, true),
            ArrowField::new("b", DataType::Int64, true),
        ]);
        let values = vec![
            ScalarValue::Utf8(Some("hello".into())),
            ScalarValue::Int64(Some(99)),
        ];
        let body =
            "SELECT * FROM t WHERE name = (SELECT a FROM args) AND id = (SELECT b FROM args)";
        let s = rewritten_plan(body, &schema, &values).await;
        assert!(
            s.contains("Utf8(\"hello\")"),
            "Expected inlined 'hello': {s}"
        );
        assert!(s.contains("Int64(99)"), "Expected inlined 99: {s}");
    }

    #[tokio::test]
    async fn inline_empty_schema_is_noop() {
        let schema = Schema::empty();
        let values: Vec<ScalarValue> = vec![];
        // With an empty schema, just plan against `t` directly (no args table).
        let ctx = SessionContext::new();
        register_empty_table(&ctx, "t", &[("id", DataType::Int64)]);
        let plan = ctx
            .sql("SELECT * FROM t")
            .await
            .expect("plan")
            .into_unoptimized_plan();
        let original = plan_str(&plan);
        let rewritten = inline_args_into_plan(plan, &schema, &values).expect("rewrite");
        assert_eq!(original, plan_str(&rewritten));
    }

    #[tokio::test]
    async fn inline_handles_boolean_and_null() {
        let schema = Schema::new(vec![ArrowField::new("flag", DataType::Boolean, true)]);
        let body = "SELECT * FROM t WHERE active = (SELECT flag FROM args)";

        // Cover both a concrete boolean and a NULL argument value.
        let cases = [
            (ScalarValue::Boolean(Some(true)), "Boolean(true)"),
            (ScalarValue::Boolean(None), "Boolean(NULL)"),
        ];
        for (value, expected) in cases {
            let values = vec![value];
            let s = rewritten_plan(body, &schema, &values).await;
            assert!(s.contains(expected), "Expected inlined {expected}: {s}");
        }
    }
}

0 commit comments

Comments
 (0)