|
| 1 | +/* |
| 2 | +Copyright 2026 The Spice.ai OSS Authors |
| 3 | +
|
| 4 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +you may not use this file except in compliance with the License. |
| 6 | +You may obtain a copy of the License at |
| 7 | +
|
| 8 | + https://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +
|
| 10 | +Unless required by applicable law or agreed to in writing, software |
| 11 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +See the License for the specific language governing permissions and |
| 14 | +limitations under the License. |
| 15 | +*/ |
| 16 | + |
| 17 | +//! Inline literal table-function arguments into a [`LogicalPlan`]. |
| 18 | +//! |
| 19 | +//! SQL table functions expose their scalar arguments via a one-row `args` |
| 20 | +//! `MemTable`. Because the `MemTable` is only populated at execution time, |
| 21 | +//! the planner sees column references rather than literals, which prevents |
| 22 | +//! filter pushdown for connectors that need concrete values (e.g. the HTTP |
| 23 | +//! connector's `request_path`). |
| 24 | +//! |
| 25 | +//! Since all scalar args are guaranteed to be literals (enforced by |
| 26 | +//! [`super::sql::literal_arg`]), this module walks the *unoptimized* |
| 27 | +//! [`LogicalPlan`] produced by `ctx.sql(body)` and replaces every |
| 28 | +//! `Expr::Column` that references the `args` table with the corresponding |
| 29 | +//! `Expr::Literal`. The optimizer then sees constants and can fold / |
| 30 | +//! push down as usual. |
| 31 | +//! |
| 32 | +//! This follows the same pattern as `DataFusion`'s own |
| 33 | +//! [`LogicalPlan::replace_params_with_values`], which replaces |
| 34 | +//! `Expr::Placeholder` with `Expr::Literal`. |
| 35 | +
|
| 36 | +use std::collections::HashMap; |
| 37 | + |
| 38 | +use arrow::datatypes::Schema; |
| 39 | +use datafusion::{ |
| 40 | + common::{ |
| 41 | + Column, Result as DataFusionResult, |
| 42 | + tree_node::{Transformed, TreeNode}, |
| 43 | + }, |
| 44 | + logical_expr::{LogicalPlan, Projection, TableScan, expr::Alias}, |
| 45 | + prelude::Expr, |
| 46 | + scalar::ScalarValue, |
| 47 | +}; |
| 48 | + |
| 49 | +use super::sql::SQL_TABLE_ARGS_TABLE_NAME; |
| 50 | + |
| 51 | +/// Returns `true` if `table_name` refers to the `args` table. |
| 52 | +fn is_args_table_ref(table_name: &datafusion::sql::TableReference) -> bool { |
| 53 | + table_name |
| 54 | + .table() |
| 55 | + .eq_ignore_ascii_case(SQL_TABLE_ARGS_TABLE_NAME) |
| 56 | +} |
| 57 | + |
| 58 | +/// If `plan` is `Projection([single_expr], TableScan("args"))` and |
| 59 | +/// `single_expr` is a literal, return that literal. This detects the |
| 60 | +/// pattern left behind after column→literal replacement inside scalar |
| 61 | +/// subqueries. |
| 62 | +fn try_extract_literal_from_args_subquery(plan: &LogicalPlan) -> Option<Expr> { |
| 63 | + if let LogicalPlan::Projection(Projection { expr, input, .. }) = plan |
| 64 | + // Single projected expression over a TableScan on `args`. |
| 65 | + && let [expr] = expr.as_slice() |
| 66 | + && let LogicalPlan::TableScan(TableScan { table_name, .. }) = input.as_ref() |
| 67 | + && is_args_table_ref(table_name) |
| 68 | + { |
| 69 | + let inner = match expr { |
| 70 | + Expr::Alias(Alias { expr, .. }) => expr, |
| 71 | + other => other, |
| 72 | + }; |
| 73 | + if matches!(inner, Expr::Literal(..)) { |
| 74 | + return Some(inner.clone()); |
| 75 | + } |
| 76 | + } |
| 77 | + None |
| 78 | +} |
| 79 | + |
| 80 | +/// Recursively collapse `Expr::ScalarSubquery` nodes whose inner plan |
| 81 | +/// is `Projection([literal], TableScan("args"))`. |
| 82 | +/// |
| 83 | +/// `DataFusion`'s `Expr::transform_up` explicitly skips `ScalarSubquery` |
| 84 | +/// children, so we must walk the expression tree manually to find and |
| 85 | +/// replace them. |
| 86 | +fn collapse_args_subqueries(expr: Expr) -> Expr { |
| 87 | + match expr { |
| 88 | + Expr::ScalarSubquery(ref subquery) => { |
| 89 | + if let Some(literal) = try_extract_literal_from_args_subquery(&subquery.subquery) { |
| 90 | + literal |
| 91 | + } else { |
| 92 | + expr |
| 93 | + } |
| 94 | + } |
| 95 | + // Recurse into expression types that can contain ScalarSubquery. |
| 96 | + Expr::BinaryExpr(mut bin) => { |
| 97 | + *bin.left = collapse_args_subqueries(*bin.left); |
| 98 | + *bin.right = collapse_args_subqueries(*bin.right); |
| 99 | + Expr::BinaryExpr(bin) |
| 100 | + } |
| 101 | + Expr::Not(inner) => Expr::Not(Box::new(collapse_args_subqueries(*inner))), |
| 102 | + Expr::IsNotNull(inner) => Expr::IsNotNull(Box::new(collapse_args_subqueries(*inner))), |
| 103 | + Expr::IsNull(inner) => Expr::IsNull(Box::new(collapse_args_subqueries(*inner))), |
| 104 | + Expr::IsTrue(inner) => Expr::IsTrue(Box::new(collapse_args_subqueries(*inner))), |
| 105 | + Expr::IsFalse(inner) => Expr::IsFalse(Box::new(collapse_args_subqueries(*inner))), |
| 106 | + Expr::Negative(inner) => Expr::Negative(Box::new(collapse_args_subqueries(*inner))), |
| 107 | + Expr::Cast(mut cast) => { |
| 108 | + *cast.expr = collapse_args_subqueries(*cast.expr); |
| 109 | + Expr::Cast(cast) |
| 110 | + } |
| 111 | + Expr::TryCast(mut cast) => { |
| 112 | + *cast.expr = collapse_args_subqueries(*cast.expr); |
| 113 | + Expr::TryCast(cast) |
| 114 | + } |
| 115 | + Expr::Alias(mut alias) => { |
| 116 | + *alias.expr = collapse_args_subqueries(*alias.expr); |
| 117 | + Expr::Alias(alias) |
| 118 | + } |
| 119 | + Expr::ScalarFunction(mut func) => { |
| 120 | + func.args = func |
| 121 | + .args |
| 122 | + .into_iter() |
| 123 | + .map(collapse_args_subqueries) |
| 124 | + .collect(); |
| 125 | + Expr::ScalarFunction(func) |
| 126 | + } |
| 127 | + Expr::Case(mut case) => { |
| 128 | + case.expr = case.expr.map(|o| Box::new(collapse_args_subqueries(*o))); |
| 129 | + case.when_then_expr = case |
| 130 | + .when_then_expr |
| 131 | + .into_iter() |
| 132 | + .map(|(w, t)| { |
| 133 | + ( |
| 134 | + Box::new(collapse_args_subqueries(*w)), |
| 135 | + Box::new(collapse_args_subqueries(*t)), |
| 136 | + ) |
| 137 | + }) |
| 138 | + .collect(); |
| 139 | + case.else_expr = case |
| 140 | + .else_expr |
| 141 | + .map(|e| Box::new(collapse_args_subqueries(*e))); |
| 142 | + Expr::Case(case) |
| 143 | + } |
| 144 | + // For any other expression type, return as-is. |
| 145 | + other => other, |
| 146 | + } |
| 147 | +} |
| 148 | + |
| 149 | +/// Walk an unoptimized [`LogicalPlan`] and replace every `Expr::Column` |
| 150 | +/// referencing the `args` table with the corresponding `Expr::Literal`. |
| 151 | +/// |
| 152 | +/// Also collapses `Expr::ScalarSubquery` nodes that, after column |
| 153 | +/// replacement, reduce to a single literal projected from `args`. |
| 154 | +/// |
| 155 | +/// This operates on the plan produced *before* optimization, so the |
| 156 | +/// optimizer's filter-pushdown passes see concrete literal values instead |
| 157 | +/// of column references to a `MemTable`. |
| 158 | +pub(super) fn inline_args_into_plan( |
| 159 | + plan: LogicalPlan, |
| 160 | + schema: &Schema, |
| 161 | + values: &[ScalarValue], |
| 162 | +) -> DataFusionResult<LogicalPlan> { |
| 163 | + if schema.fields().is_empty() { |
| 164 | + return Ok(plan); |
| 165 | + } |
| 166 | + |
| 167 | + let arg_map: HashMap<String, ScalarValue> = schema |
| 168 | + .fields() |
| 169 | + .iter() |
| 170 | + .zip(values) |
| 171 | + .map(|(field, value)| (field.name().to_ascii_lowercase(), value.clone())) |
| 172 | + .collect(); |
| 173 | + |
| 174 | + // Pass 1: replace `Expr::Column` refs to `args` with literals inside |
| 175 | + // all plans (including subquery plans). |
| 176 | + let plan = plan |
| 177 | + .transform_up_with_subqueries(|plan| { |
| 178 | + plan.map_expressions(|expr| { |
| 179 | + expr.transform_up(|e| { |
| 180 | + if let Expr::Column(Column { |
| 181 | + ref relation, |
| 182 | + ref name, |
| 183 | + .. |
| 184 | + }) = e |
| 185 | + { |
| 186 | + let key = name.to_ascii_lowercase(); |
| 187 | + let should_replace = match relation { |
| 188 | + Some(r) => is_args_table_ref(r) && arg_map.contains_key(&key), |
| 189 | + None => arg_map.contains_key(&key), |
| 190 | + }; |
| 191 | + if should_replace && let Some(value) = arg_map.get(&key) { |
| 192 | + return Ok(Transformed::yes(Expr::Literal(value.clone(), None))); |
| 193 | + } |
| 194 | + } |
| 195 | + Ok(Transformed::no(e)) |
| 196 | + }) |
| 197 | + }) |
| 198 | + })? |
| 199 | + .data; |
| 200 | + |
| 201 | + // Pass 2: collapse `Expr::ScalarSubquery` nodes whose inner plan is |
| 202 | + // now `Projection([literal], TableScan("args"))`. This turns the |
| 203 | + // subquery into a bare literal so the optimizer can push it down. |
| 204 | + // |
| 205 | + // We use a manual expression walk because DataFusion's |
| 206 | + // `Expr::transform_up` explicitly skips `ScalarSubquery` children. |
| 207 | + plan.transform_up_with_subqueries(|plan| { |
| 208 | + plan.map_expressions(|expr| { |
| 209 | + let collapsed = collapse_args_subqueries(expr.clone()); |
| 210 | + if collapsed == expr { |
| 211 | + Ok(Transformed::no(expr)) |
| 212 | + } else { |
| 213 | + Ok(Transformed::yes(collapsed)) |
| 214 | + } |
| 215 | + }) |
| 216 | + }) |
| 217 | + .map(|res| res.data) |
| 218 | +} |
| 219 | + |
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{DataType, Field as ArrowField};
    use datafusion::datasource::MemTable;
    use datafusion::prelude::SessionContext;
    use std::sync::Arc;

    /// Text that a scalar-subquery expression is expected to render as in a
    /// displayed plan. Its absence after rewriting shows that the subquery
    /// was collapsed to a bare literal rather than merely rewritten inside.
    const SUBQUERY_MARKER: &str = "<subquery>";

    /// Register an empty in-memory table named `name` with the given fields.
    fn register_empty_table(ctx: &SessionContext, name: &str, fields: Vec<ArrowField>) {
        let schema = Arc::new(Schema::new(fields));
        let table = MemTable::try_new(schema, vec![vec![]]).expect("memtable");
        ctx.register_table(name, Arc::new(table))
            .expect("register");
    }

    /// Register a one-row `args` `MemTable` and plan the body SQL, returning
    /// the unoptimized plan.
    ///
    /// Two empty helper tables are also registered so body SQL can reference
    /// them: `raw_users(content, request_path)` and
    /// `t(id, name, active, col)`.
    async fn plan_body(body: &str, schema: &Schema, values: &[ScalarValue]) -> LogicalPlan {
        let ctx = SessionContext::new();
        let schema_ref = Arc::new(schema.clone());

        // Build the one-row args MemTable.
        let arrays: Vec<_> = values
            .iter()
            .map(|v| v.to_array().expect("to_array"))
            .collect();
        let batch = arrow::record_batch::RecordBatch::try_new(Arc::clone(&schema_ref), arrays)
            .expect("batch");
        let table = MemTable::try_new(schema_ref, vec![vec![batch]]).expect("memtable");
        ctx.register_table("args", Arc::new(table))
            .expect("register");

        register_empty_table(
            &ctx,
            "raw_users",
            vec![
                ArrowField::new("content", DataType::Utf8, true),
                ArrowField::new("request_path", DataType::Utf8, true),
            ],
        );
        register_empty_table(
            &ctx,
            "t",
            vec![
                ArrowField::new("id", DataType::Int64, true),
                ArrowField::new("name", DataType::Utf8, true),
                ArrowField::new("active", DataType::Boolean, true),
                ArrowField::new("col", DataType::Utf8, true),
            ],
        );

        ctx.sql(body)
            .await
            .expect("plan body")
            .into_unoptimized_plan()
    }

    /// Build a schema of nullable `Utf8` fields with the given names.
    fn utf8_schema(names: &[&str]) -> Schema {
        Schema::new(
            names
                .iter()
                .map(|n| ArrowField::new(*n, DataType::Utf8, true))
                .collect::<Vec<_>>(),
        )
    }

    /// Format the plan to a string for assertion.
    fn plan_str(plan: &LogicalPlan) -> String {
        plan.to_string()
    }

    #[tokio::test]
    async fn inline_replaces_scalar_subquery() {
        let schema = utf8_schema(&["username"]);
        let values = vec![ScalarValue::Utf8(Some("pg".into()))];
        let body = "SELECT content FROM raw_users WHERE request_path = (SELECT username FROM args)";
        let plan = plan_body(body, &schema, &values).await;
        let rewritten = inline_args_into_plan(plan, &schema, &values).expect("rewrite");
        let s = plan_str(&rewritten);
        assert!(
            s.contains("Utf8(\"pg\")"),
            "Expected inlined literal in plan: {s}"
        );
        assert!(
            !s.contains(SUBQUERY_MARKER),
            "Scalar subquery should have been collapsed to a literal: {s}"
        );
    }

    #[tokio::test]
    async fn inline_replaces_expression_over_args() {
        let schema = utf8_schema(&["username"]);
        let values = vec![ScalarValue::Utf8(Some("pg".into()))];
        let body = "SELECT content FROM raw_users WHERE request_path = (SELECT concat('/users/', username) FROM args)";
        let plan = plan_body(body, &schema, &values).await;
        let rewritten = inline_args_into_plan(plan, &schema, &values).expect("rewrite");
        let s = plan_str(&rewritten);
        // The projected expression is a function call, not a bare literal, so
        // only the column inside it is inlined; the subquery itself remains.
        assert!(
            s.contains("Utf8(\"pg\")"),
            "Expected inlined literal in plan: {s}"
        );
        assert!(s.contains("concat"), "Should still contain concat: {s}");
    }

    #[tokio::test]
    async fn inline_replaces_from_args_direct() {
        // Body uses `FROM args` directly — columns should still be replaced.
        let schema = Schema::new(vec![ArrowField::new("x", DataType::Int64, true)]);
        let values = vec![ScalarValue::Int64(Some(42))];
        let body = "SELECT x AS value, x * 2 AS doubled FROM args";
        let plan = plan_body(body, &schema, &values).await;
        let rewritten = inline_args_into_plan(plan, &schema, &values).expect("rewrite");
        let s = plan_str(&rewritten);
        assert!(
            s.contains("Int64(42)"),
            "Expected inlined literal 42 in plan: {s}"
        );
    }

    #[tokio::test]
    async fn inline_handles_multiple_args() {
        let schema = Schema::new(vec![
            ArrowField::new("a", DataType::Utf8, true),
            ArrowField::new("b", DataType::Int64, true),
        ]);
        let values = vec![
            ScalarValue::Utf8(Some("hello".into())),
            ScalarValue::Int64(Some(99)),
        ];
        let body =
            "SELECT * FROM t WHERE name = (SELECT a FROM args) AND id = (SELECT b FROM args)";
        let plan = plan_body(body, &schema, &values).await;
        let rewritten = inline_args_into_plan(plan, &schema, &values).expect("rewrite");
        let s = plan_str(&rewritten);
        assert!(
            s.contains("Utf8(\"hello\")"),
            "Expected inlined 'hello': {s}"
        );
        assert!(s.contains("Int64(99)"), "Expected inlined 99: {s}");
        assert!(
            !s.contains(SUBQUERY_MARKER),
            "Both scalar subqueries should have been collapsed: {s}"
        );
    }

    #[tokio::test]
    async fn inline_empty_schema_is_noop() {
        let schema = Schema::empty();
        let values: Vec<ScalarValue> = vec![];
        // With an empty schema, just plan against `t` directly (no args table).
        let ctx = SessionContext::new();
        register_empty_table(
            &ctx,
            "t",
            vec![ArrowField::new("id", DataType::Int64, true)],
        );
        let plan = ctx
            .sql("SELECT * FROM t")
            .await
            .expect("plan")
            .into_unoptimized_plan();
        let original = plan_str(&plan);
        let rewritten = inline_args_into_plan(plan, &schema, &values).expect("rewrite");
        assert_eq!(original, plan_str(&rewritten));
    }

    #[tokio::test]
    async fn inline_handles_boolean_and_null() {
        let schema = Schema::new(vec![ArrowField::new("flag", DataType::Boolean, true)]);
        let body = "SELECT * FROM t WHERE active = (SELECT flag FROM args)";

        // Cover both a concrete boolean and a typed NULL argument.
        let cases = [
            (ScalarValue::Boolean(Some(true)), "Boolean(true)"),
            (ScalarValue::Boolean(None), "Boolean(NULL)"),
        ];
        for (value, expected) in cases {
            let values = vec![value];
            let plan = plan_body(body, &schema, &values).await;
            let rewritten = inline_args_into_plan(plan, &schema, &values).expect("rewrite");
            let s = plan_str(&rewritten);
            assert!(s.contains(expected), "Expected inlined {expected}: {s}");
            assert!(
                !s.contains(SUBQUERY_MARKER),
                "Scalar subquery should have been collapsed to a literal: {s}"
            );
        }
    }
}
0 commit comments