Skip to content

Commit 3e4fff4

Browse files
committed
fix(mysql): infer LIMIT placeholders in prepare
Signed-off-by: discord9 <discord9@163.com>
1 parent 15fc148 commit 3e4fff4

2 files changed

Lines changed: 79 additions & 9 deletions

File tree

src/query/src/planner.rs

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -494,14 +494,45 @@ impl DfLogicalPlanner {
494494
Ok(())
495495
}
496496

497+
fn infer_limit_placeholder_types(
498+
plan: &LogicalPlan,
499+
placeholder_types: &mut HashMap<String, Option<DataType>>,
500+
) -> Result<()> {
501+
plan.apply(|node| {
502+
if let LogicalPlan::Limit(limit) = node {
503+
for expr in limit.skip.iter().chain(limit.fetch.iter()) {
504+
expr.apply(|e| {
505+
if let DfExpr::Placeholder(ph) = e {
506+
placeholder_types
507+
.entry(ph.id.clone())
508+
.and_modify(|existing| {
509+
if existing.is_none() {
510+
*existing = Some(DataType::Int64);
511+
}
512+
})
513+
.or_insert(Some(DataType::Int64));
514+
}
515+
516+
Ok(TreeNodeRecursion::Continue)
517+
})?;
518+
}
519+
}
520+
521+
Ok(TreeNodeRecursion::Continue)
522+
})?;
523+
524+
Ok(())
525+
}
526+
497527
/// Gets inferred parameter types from a logical plan.
498528
/// Returns a map where each parameter ID is mapped to:
499529
/// - Some(DataType) if the parameter type could be inferred
500530
/// - None if the parameter type could not be inferred
501531
///
502532
/// This function first uses DataFusion's `get_parameter_types()` to infer types.
503533
/// If any parameters have `None` values (i.e., DataFusion couldn't infer their types),
504-
/// it falls back to using `extract_placeholder_cast_types()` to detect explicit casts.
534+
/// it falls back to using `extract_placeholder_cast_types()` to detect explicit casts
535+
/// and applies context-specific inference such as LIMIT/OFFSET placeholders.
505536
///
506537
/// This is because datafusion can only infer types for a limited cases.
507538
///
@@ -510,19 +541,15 @@ impl DfLogicalPlanner {
510541
pub fn get_inferred_parameter_types(
511542
plan: &LogicalPlan,
512543
) -> Result<HashMap<String, Option<DataType>>> {
513-
let param_types = plan.get_parameter_types().context(PlanSqlSnafu)?;
544+
let mut param_types = plan.get_parameter_types().context(PlanSqlSnafu)?;
514545

515546
let has_none = param_types.values().any(|v| v.is_none());
516547

517-
if !has_none {
518-
Ok(param_types)
519-
} else {
548+
if has_none {
520549
let cast_types = Self::extract_placeholder_cast_types(plan)?;
521550

522-
let mut merged = param_types;
523-
524551
for (id, opt_type) in cast_types {
525-
merged
552+
param_types
526553
.entry(id)
527554
.and_modify(|existing| {
528555
if existing.is_none() {
@@ -532,8 +559,10 @@ impl DfLogicalPlanner {
532559
.or_insert(opt_type);
533560
}
534561

535-
Ok(merged)
562+
Self::infer_limit_placeholder_types(plan, &mut param_types)?;
536563
}
564+
565+
Ok(param_types)
537566
}
538567
}
539568

@@ -793,6 +822,15 @@ mod tests {
793822
assert_eq!(type_3, &Some(DataType::Int32));
794823
}
795824

825+
#[tokio::test]
826+
async fn test_get_inferred_parameter_types_limit_offset() {
827+
let plan = parse_sql_to_plan("SELECT id FROM test LIMIT $1 OFFSET $2").await;
828+
let types = DfLogicalPlanner::get_inferred_parameter_types(&plan).unwrap();
829+
830+
assert_eq!(types.get("$1"), Some(&Some(DataType::Int64)));
831+
assert_eq!(types.get("$2"), Some(&Some(DataType::Int64)));
832+
}
833+
796834
#[tokio::test]
797835
async fn test_plan_pql_applies_extension_rules() {
798836
for inner_agg in ["count", "sum", "avg", "min", "max", "stddev", "stdvar"] {

src/servers/tests/mysql/mysql_server_test.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,38 @@ async fn test_query_prepared() -> Result<()> {
516516
_ => unreachable!(),
517517
}
518518

519+
// Regression test for #8142: LIMIT ? should work in prepared statements.
520+
// The LIMIT placeholder should be inferred as Int64 so the MySQL prepare
521+
// response advertises the correct parameter count.
522+
{
523+
let stmt = connection
524+
.prep("SELECT uint32s FROM all_datatypes LIMIT ?")
525+
.await
526+
.unwrap();
527+
let rows: Vec<Row> = connection
528+
.exec(stmt, vec![mysql_async::Value::Int(1)])
529+
.await
530+
.unwrap();
531+
assert_eq!(rows.len(), 1, "LIMIT 1 should return 1 row");
532+
}
533+
534+
// Also cover mixed placeholders: the WHERE placeholder is inferred from
535+
// the column type and the LIMIT placeholder is inferred from its context.
536+
{
537+
let stmt = connection
538+
.prep("SELECT uint32s FROM all_datatypes WHERE uint32s >= ? LIMIT ?")
539+
.await
540+
.unwrap();
541+
let rows: Vec<Row> = connection
542+
.exec(
543+
stmt,
544+
vec![mysql_async::Value::UInt(0), mysql_async::Value::UInt(1)],
545+
)
546+
.await
547+
.unwrap();
548+
assert_eq!(rows.len(), 1, "LIMIT 1 should return 1 row");
549+
}
550+
519551
Ok(())
520552
}
521553

0 commit comments

Comments
 (0)