Skip to content

Commit d2df862

Browse files
committed
adapter: Automatically cast parameters of EXECUTE
1 parent 3a787ca commit d2df862

File tree

11 files changed

+180
-73
lines changed

11 files changed

+180
-73
lines changed

src/adapter/src/coord/sql.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,12 @@ impl Coordinator {
6868
now: EpochMillis,
6969
) -> Result<(), AdapterError> {
7070
let param_types = params
71-
.types
71+
.actual_types
7272
.iter()
7373
.map(|ty| Some(ty.clone()))
7474
.collect::<Vec<_>>();
7575
let desc = describe(catalog, stmt.clone(), &param_types, session)?;
76-
let params = params.datums.into_iter().zip(params.types).collect();
76+
let params = params.datums.into_iter().zip(params.actual_types).collect();
7777
let result_formats = vec![mz_pgwire_common::Format::Text; desc.arity()];
7878
let logging = session.mint_logging(sql, Some(&stmt), now);
7979
session.set_portal(

src/adapter/src/coord/statement_logging.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -727,7 +727,7 @@ impl Coordinator {
727727
now,
728728
);
729729

730-
let params = std::iter::zip(params.types.iter(), params.datums.iter())
730+
let params = std::iter::zip(params.actual_types.iter(), params.datums.iter())
731731
.map(|(r#type, datum)| {
732732
mz_pgrepr::Value::from_datum(datum, r#type).map(|val| {
733733
let mut buf = BytesMut::new();

src/adapter/src/session.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,7 @@ impl<T: TimestampManipulation> Session<T> {
710710
if !portal_name.is_empty() && self.portals.contains_key(&portal_name) {
711711
return Err(AdapterError::DuplicateCursor(portal_name));
712712
}
713+
let param_types = desc.param_types.clone();
713714
self.portals.insert(
714715
portal_name,
715716
Portal {
@@ -718,7 +719,8 @@ impl<T: TimestampManipulation> Session<T> {
718719
catalog_revision,
719720
parameters: Params {
720721
datums: Row::pack(params.iter().map(|(d, _t)| d)),
721-
types: params.into_iter().map(|(_d, t)| t).collect(),
722+
actual_types: params.into_iter().map(|(_d, t)| t).collect(),
723+
expected_types: param_types,
722724
},
723725
result_formats,
724726
state: PortalState::NotStarted,

src/sql/src/plan.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2042,15 +2042,17 @@ impl TryFrom<ClusterAlterOptionExtracted> for AlterClusterPlanStrategy {
20422042
#[derive(Debug, Clone)]
20432043
pub struct Params {
20442044
pub datums: Row,
2045-
pub types: Vec<ScalarType>,
2045+
pub actual_types: Vec<ScalarType>,
2046+
pub expected_types: Vec<ScalarType>,
20462047
}
20472048

20482049
impl Params {
20492050
/// Returns a `Params` with no parameters.
20502051
pub fn empty() -> Params {
20512052
Params {
20522053
datums: Row::pack_slice(&[]),
2053-
types: vec![],
2054+
actual_types: vec![],
2055+
expected_types: vec![],
20542056
}
20552057
}
20562058
}

src/sql/src/plan/error.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ pub enum PlanError {
101101
},
102102
UnknownParameter(usize),
103103
ParameterNotAllowed(String),
104+
WrongParameterType(usize, String, String),
104105
RecursionLimit(RecursionLimitError),
105106
StrconvParse(strconv::ParseError),
106107
Catalog(CatalogError),
@@ -479,6 +480,9 @@ impl PlanError {
479480
Self::NetworkPolicyInUse => {
480481
Some("Use ALTER SYSTEM SET 'network_policy' to change the default network policy.".into())
481482
}
483+
Self::WrongParameterType(_, _, _) => {
484+
Some("EXECUTE automatically inserts only such casts that are allowed in an assignment cast context. Try adding an explicit cast.".into())
485+
}
482486
_ => None,
483487
}
484488
}
@@ -577,6 +581,7 @@ impl fmt::Display for PlanError {
577581
}
578582
Self::UnknownParameter(n) => write!(f, "there is no parameter ${}", n),
579583
Self::ParameterNotAllowed(object_type) => write!(f, "{} cannot have parameters", object_type),
584+
Self::WrongParameterType(i, expected_ty, actual_ty) => write!(f, "unable to cast given parameter ${}: expected {}, got {}", i, expected_ty, actual_ty),
580585
Self::RecursionLimit(e) => write!(f, "{}", e),
581586
Self::StrconvParse(e) => write!(f, "{}", e),
582587
Self::Catalog(e) => write!(f, "{}", e),

src/sql/src/plan/hir.rs

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@ use mz_repr::adt::numeric::NumericMaxScale;
3636
use mz_repr::*;
3737
use serde::{Deserialize, Serialize};
3838

39-
use crate::plan::Params;
4039
use crate::plan::error::PlanError;
41-
use crate::plan::query::ExprContext;
42-
use crate::plan::typeconv::{self, CastContext};
40+
use crate::plan::query::{EXECUTE_CAST_CONTEXT, ExprContext, execute_expr_context};
41+
use crate::plan::typeconv::{self, CastContext, plan_cast};
42+
use crate::plan::{Params, QueryContext, QueryLifetime, StatementContext};
4343

4444
use super::plan_utils::GroupSizeHints;
4545

@@ -167,24 +167,18 @@ pub enum HirRelationExpr {
167167
/// Column indices used to order rows within groups.
168168
order_key: Vec<ColumnOrder>,
169169
/// Number of records to retain.
170-
/// It is of ScalarType::Int64. (It's not entirely clear why not UInt64, see below at
171-
/// `offset`.)
170+
/// It is of ScalarType::Int64.
171+
/// (UInt64 would make sense in theory: Then we wouldn't need to manually check
172+
/// non-negativity, but would just get this for free when casting to UInt64. However, Int64
173+
/// is better for Postgres compat. This is because if there is a $1 here, then when external
174+
/// tools `describe` the prepared statement, they discover this type. If what they find
175+
/// were UInt64, then they might have trouble calling the prepared statement, because the
176+
/// unsigned types are non-standard, and also don't exist even in Postgres.)
172177
limit: Option<HirScalarExpr>,
173178
/// Number of records to skip.
174179
/// It is of ScalarType::Int64.
175180
/// This can contain parameters at first, but by the time we reach lowering, this should
176181
/// already be simply a Literal.
177-
///
178-
/// TODO: It's not clear why this is Int64 instead of UInt64. If it were UInt64, we wouldn't
179-
/// need to manually check non-negativity, but would just get this for free when casting to
180-
/// UInt64.
181-
/// There are some arguments for Postgres-compatibility, because Postgres expects an
182-
/// Int64. But it's not clear whether there is an actual situation where having a UInt64
183-
/// here would introduce a Postgres-incompatibility.
184-
/// Another thing is that our prepared statements currently have the limitation that the
185-
/// argument passed to EXECUTE must have exactly the same type as the parameter, i.e., it's
186-
/// not enough if it's castable. Since UInt64 is not a standard type, various tools might
187-
/// have trouble passing it.
188182
offset: HirScalarExpr,
189183
/// User-supplied hint: how many rows will have the same group key.
190184
expected_group_size: Option<u64>,
@@ -2112,10 +2106,15 @@ impl HirRelationExpr {
21122106

21132107
/// Replaces any parameter references in the expression with the
21142108
/// corresponding datum from `params`.
2115-
pub fn bind_parameters(&mut self, params: &Params) -> Result<(), PlanError> {
2109+
pub fn bind_parameters(
2110+
&mut self,
2111+
scx: &StatementContext,
2112+
lifetime: QueryLifetime,
2113+
params: &Params,
2114+
) -> Result<(), PlanError> {
21162115
#[allow(deprecated)]
21172116
self.visit_scalar_expressions_mut(0, &mut |e: &mut HirScalarExpr, _: usize| {
2118-
e.bind_parameters(params)
2117+
e.bind_parameters(scx, lifetime, params)
21192118
})
21202119
}
21212120

@@ -2984,15 +2983,20 @@ impl HirScalarExpr {
29842983

29852984
/// Replaces any parameter references in the expression with the
29862985
/// corresponding datum in `params`.
2987-
pub fn bind_parameters(&mut self, params: &Params) -> Result<(), PlanError> {
2986+
pub fn bind_parameters(
2987+
&mut self,
2988+
scx: &StatementContext,
2989+
lifetime: QueryLifetime,
2990+
params: &Params,
2991+
) -> Result<(), PlanError> {
29882992
#[allow(deprecated)]
29892993
self.visit_recursively_mut(0, &mut |_: usize, e: &mut HirScalarExpr| {
29902994
if let HirScalarExpr::Parameter(n, name) = e {
29912995
let datum = match params.datums.iter().nth(*n - 1) {
29922996
None => return Err(PlanError::UnknownParameter(*n)),
29932997
Some(datum) => datum,
29942998
};
2995-
let scalar_type = &params.types[*n - 1];
2999+
let scalar_type = &params.actual_types[*n - 1];
29963000
let row = Row::pack([datum]);
29973001
let column_type = scalar_type.clone().nullable(datum.is_null());
29983002

@@ -3002,7 +3006,16 @@ impl HirScalarExpr {
30023006
Some(Arc::from(format!("${n}")))
30033007
};
30043008

3005-
*e = HirScalarExpr::Literal(row, column_type, TreatAsEqual(name));
3009+
let qcx = QueryContext::root(scx, lifetime);
3010+
let ecx = execute_expr_context(&qcx);
3011+
3012+
*e = plan_cast(
3013+
&ecx,
3014+
*EXECUTE_CAST_CONTEXT,
3015+
HirScalarExpr::Literal(row, column_type, TreatAsEqual(name)),
3016+
&params.expected_types[*n - 1],
3017+
)
3018+
.expect("checked in plan_params");
30063019
}
30073020
Ok(())
30083021
})

src/sql/src/plan/query.rs

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ use std::collections::{BTreeMap, BTreeSet};
4242
use std::convert::{TryFrom, TryInto};
4343
use std::num::NonZeroU64;
4444
use std::rc::Rc;
45-
use std::sync::Arc;
45+
use std::sync::{Arc, LazyLock};
4646
use std::{iter, mem};
4747

4848
use itertools::Itertools;
@@ -97,7 +97,7 @@ use crate::plan::hir::{
9797
use crate::plan::plan_utils::{self, GroupSizeHints, JoinSide};
9898
use crate::plan::scope::{Scope, ScopeItem, ScopeUngroupedColumn};
9999
use crate::plan::statement::{StatementContext, StatementDesc, show};
100-
use crate::plan::typeconv::{self, CastContext};
100+
use crate::plan::typeconv::{self, CastContext, plan_hypothetical_cast};
101101
use crate::plan::{
102102
Params, PlanContext, QueryWhen, ShowCreatePlan, WebhookValidation, WebhookValidationSecret,
103103
literal, transform_ast,
@@ -1312,43 +1312,60 @@ pub fn plan_params<'a>(
13121312
}
13131313

13141314
let qcx = QueryContext::root(scx, QueryLifetime::OneShot);
1315-
let scope = Scope::empty();
1316-
let rel_type = RelationType::empty();
13171315

13181316
let mut datums = Row::default();
13191317
let mut packer = datums.packer();
13201318
let mut actual_types = Vec::new();
13211319
let temp_storage = &RowArena::new();
1322-
for (mut expr, expected_ty) in params.into_iter().zip(&desc.param_types) {
1320+
for (i, (mut expr, expected_ty)) in params.into_iter().zip(&desc.param_types).enumerate() {
13231321
transform_ast::transform(scx, &mut expr)?;
13241322

1325-
let ecx = &ExprContext {
1326-
qcx: &qcx,
1327-
name: "EXECUTE",
1328-
scope: &scope,
1329-
relation_type: &rel_type,
1330-
allow_aggregates: false,
1331-
allow_subqueries: false,
1332-
allow_parameters: false,
1333-
allow_windows: false,
1334-
};
1335-
let ex = plan_expr(ecx, &expr)?.type_as_any(ecx)?;
1323+
let ecx = execute_expr_context(&qcx);
1324+
let ex = plan_expr(&ecx, &expr)?.type_as_any(&ecx)?;
13361325
let actual_ty = ecx.scalar_type(&ex);
1337-
if actual_ty != *expected_ty {
1338-
sql_bail!(
1339-
"mismatched parameter type: expected {}, got {}",
1326+
if plan_hypothetical_cast(&ecx, *EXECUTE_CAST_CONTEXT, &actual_ty, expected_ty).is_none() {
1327+
return Err(PlanError::WrongParameterType(
1328+
i + 1,
13401329
ecx.humanize_scalar_type(expected_ty, false),
13411330
ecx.humanize_scalar_type(&actual_ty, false),
1342-
);
1331+
));
13431332
}
13441333
let ex = ex.lower_uncorrelated()?;
13451334
let evaled = ex.eval(&[], temp_storage)?;
13461335
packer.push(evaled);
13471336
actual_types.push(actual_ty);
13481337
}
1349-
Ok(Params { datums, types: actual_types })
1338+
Ok(Params {
1339+
datums,
1340+
actual_types,
1341+
expected_types: desc.param_types.clone(),
1342+
})
1343+
}
1344+
1345+
static EXECUTE_CONTEXT_SCOPE: LazyLock<Scope> = LazyLock::new(Scope::empty);
1346+
static EXECUTE_CONTEXT_REL_TYPE: LazyLock<RelationType> = LazyLock::new(RelationType::empty);
1347+
1348+
/// Returns an `ExprContext` for the expressions in the parameters of an EXECUTE statement.
1349+
pub(crate) fn execute_expr_context<'a>(qcx: &'a QueryContext<'a>) -> ExprContext<'a> {
1350+
ExprContext {
1351+
qcx,
1352+
name: "EXECUTE",
1353+
scope: &EXECUTE_CONTEXT_SCOPE,
1354+
relation_type: &EXECUTE_CONTEXT_REL_TYPE,
1355+
allow_aggregates: false,
1356+
allow_subqueries: false,
1357+
allow_parameters: false,
1358+
allow_windows: false,
1359+
}
13501360
}
13511361

1362+
/// The CastContext used when matching up the types of parameters passed to EXECUTE.
1363+
///
1364+
/// This is an assignment cast also in Postgres, see
1365+
/// <https://github.com/MaterializeInc/database-issues/issues/9266>
1366+
pub(crate) static EXECUTE_CAST_CONTEXT: LazyLock<CastContext> =
1367+
LazyLock::new(|| CastContext::Assignment);
1368+
13521369
pub fn plan_index_exprs<'a>(
13531370
scx: &'a StatementContext,
13541371
on_desc: &RelationDesc,

src/sql/src/plan/side_effecting_func.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ pub fn plan_select_if_side_effecting(
102102
let temp_storage = RowArena::new();
103103
let mut args = vec![];
104104
for mut arg in sef_call.args {
105-
arg.bind_parameters(params)?;
105+
arg.bind_parameters(scx, QueryLifetime::OneShot, params)?;
106106
let arg = arg.lower_uncorrelated()?;
107107
args.push(arg);
108108
}

src/sql/src/plan/statement.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,10 @@ pub fn plan(
278278
resolved_ids: &ResolvedIds,
279279
) -> Result<Plan, PlanError> {
280280
let param_types = params
281-
.types
281+
// We need the `expected_types` here, not the `actual_types`! This is because
282+
// `expected_types` is how the parameter expression (e.g. `$1`) looks "from the outside":
283+
// `bind_parameters` will insert a cast from the actual type to the expected type.
284+
.expected_types
282285
.iter()
283286
.enumerate()
284287
.map(|(i, ty)| (i + 1, ty.clone()))

0 commit comments

Comments
 (0)