Skip to content

Commit b79ddd1

Browse files
LiaCastaneda, gstvg, rluvaton, benbellick
authored and committed
Add lambda substrait support (apache#21193) (#134)
Part of apache#21172. Substrait support wasn't implemented in the core lambda support to reduce PR size. Substrait consuming and producing of higher-order functions, lambdas and lambda variables. Unit tests added to `datafusion/substrait/tests/cases/roundtrip_logical_plan.rs`. None. --------- (cherry picked from commit 9a6f67e) (cherry picked from commit 1ac2df1) Co-authored-by: gstvg <28798827+gstvg@users.noreply.github.com> Co-authored-by: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Co-authored-by: Ben Bellick <36523439+benbellick@users.noreply.github.com>
1 parent 08da279 commit b79ddd1

15 files changed

Lines changed: 1661 additions & 35 deletions

File tree

datafusion/substrait/src/logical_plan/consumer/expr/field_reference.rs

Lines changed: 84 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use datafusion::logical_expr::Expr;
2121
use std::sync::Arc;
2222
use substrait::proto::expression::FieldReference;
2323
use substrait::proto::expression::field_reference::ReferenceType::DirectReference;
24-
use substrait::proto::expression::field_reference::RootType;
24+
use substrait::proto::expression::field_reference::{LambdaParameterReference, RootType};
2525
use substrait::proto::expression::reference_segment::ReferenceType::StructField;
2626

2727
pub async fn from_field_reference(
@@ -56,9 +56,9 @@ pub(crate) fn from_substrait_field_reference(
5656
Some(RootType::Expression(_)) => not_impl_err!(
5757
"Expression root type in field reference is not supported"
5858
),
59-
Some(RootType::LambdaParameterReference(_)) => not_impl_err!(
60-
"Lambda parameter reference in field reference is not yet supported"
61-
),
59+
Some(RootType::LambdaParameterReference(
60+
LambdaParameterReference { steps_out },
61+
)) => consumer.lambda_variable(*steps_out as usize, field_idx),
6262
}
6363
}
6464
_ => not_impl_err!(
@@ -85,3 +85,83 @@ fn resolve_outer_reference(
8585
let col = Column::from((qualifier, field));
8686
Ok(Expr::OuterReferenceColumn(Arc::clone(field), col))
8787
}
88+
89+
#[cfg(test)]
mod tests {
    use datafusion::{
        common::{DFSchema, assert_contains},
        prelude::SessionContext,
    };
    use substrait::proto::{
        Type,
        expression::{
            FieldReference, ReferenceSegment,
            field_reference::{self, LambdaParameterReference, RootType},
            reference_segment::{ReferenceType, StructField},
        },
        r#type::{I64, Kind},
    };

    use crate::{
        extensions::Extensions,
        logical_plan::consumer::{
            DefaultSubstraitConsumer, SubstraitConsumer, from_field_reference,
        },
    };

    /// Builds a direct struct-field reference to `field` whose root is a
    /// lambda parameter reference located `steps_out` lambdas outwards.
    fn lambda_field_ref(field: i32, steps_out: u32) -> FieldReference {
        let struct_field = StructField { field, child: None };
        let segment = ReferenceSegment {
            reference_type: Some(ReferenceType::StructField(Box::new(struct_field))),
        };
        let root = RootType::LambdaParameterReference(LambdaParameterReference {
            steps_out,
        });

        FieldReference {
            reference_type: Some(field_reference::ReferenceType::DirectReference(
                segment,
            )),
            root_type: Some(root),
        }
    }

    /// Referencing a lambda when no lambda parameters were pushed must fail.
    #[tokio::test]
    async fn test_lambda_variable_invalid_steps_out() {
        let reference = lambda_field_ref(0, 99);
        let schema = DFSchema::empty();

        let extensions = Extensions::default();
        let session_state = SessionContext::new().state();
        let consumer = DefaultSubstraitConsumer::new(&extensions, &session_state);

        let result = from_field_reference(&consumer, &reference, &schema).await;
        let message = result.unwrap_err().to_string();

        assert_contains!(message, "No lambda at 99 steps out, got only 0");
    }

    /// Referencing field 1 of a lambda that has a single parameter must fail.
    #[tokio::test]
    async fn test_lambda_variable_invalid_field_idx() {
        let reference = lambda_field_ref(1, 0);
        let schema = DFSchema::empty();

        let extensions = Extensions::default();
        let session_state = SessionContext::new().state();
        let consumer = DefaultSubstraitConsumer::new(&extensions, &session_state);

        // Push a single i64 lambda parameter so that only index 0 is valid.
        let i64_type = Type {
            kind: Some(Kind::I64(I64::default())),
        };
        let _param_names = consumer
            .push_lambda_parameters(&[i64_type], &schema)
            .unwrap();

        let result = from_field_reference(&consumer, &reference, &schema).await;
        let message = result.unwrap_err().to_string();

        assert_contains!(
            message,
            "At lambda 0 steps out, no field at index 1, got only 1"
        );
    }
}
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use datafusion::{
19+
common::{DFSchema, substrait_err},
20+
prelude::{Expr, lambda},
21+
};
22+
use substrait::proto;
23+
24+
use crate::logical_plan::consumer::SubstraitConsumer;
25+
26+
pub async fn from_lambda(
27+
consumer: &impl SubstraitConsumer,
28+
expr: &proto::expression::Lambda,
29+
input_schema: &DFSchema,
30+
) -> datafusion::common::Result<Expr> {
31+
let Some(parameters) = expr.parameters.as_ref() else {
32+
return substrait_err!("Lambda expression without parameters is not allowed");
33+
};
34+
35+
let names = consumer.push_lambda_parameters(&parameters.types, input_schema)?;
36+
37+
let Some(body) = expr.body.as_ref() else {
38+
return substrait_err!("Lambda expression without body is not allowed");
39+
};
40+
41+
let body = consumer.consume_expression(body, input_schema).await?;
42+
43+
consumer.pop_lambda_parameters();
44+
45+
Ok(lambda(names, body))
46+
}
47+
48+
#[cfg(test)]
mod tests {
    use datafusion::{
        common::{DFSchema, assert_contains},
        prelude::SessionContext,
    };
    use substrait::proto::{self, Expression, r#type::Struct};

    use crate::{
        extensions::Extensions,
        logical_plan::consumer::{DefaultSubstraitConsumer, from_lambda},
    };

    /// Consumes `lambda_expr` with a default consumer and an empty schema,
    /// expecting it to fail, and returns the rendered error message.
    async fn consume_expecting_error(lambda_expr: &proto::expression::Lambda) -> String {
        let extensions = Extensions::default();
        let session_state = SessionContext::new().state();
        let consumer = DefaultSubstraitConsumer::new(&extensions, &session_state);
        let schema = DFSchema::empty();

        let err = from_lambda(&consumer, lambda_expr, &schema)
            .await
            .unwrap_err();
        err.to_string()
    }

    /// A lambda with parameters but no body is rejected.
    #[tokio::test]
    async fn test_lambda_without_body() {
        let lambda_expr = proto::expression::Lambda {
            parameters: Some(Struct::default()),
            body: None,
        };

        let message = consume_expecting_error(&lambda_expr).await;

        assert_contains!(message, "Lambda expression without body is not allowed");
    }

    /// A lambda with a body but no parameters is rejected.
    #[tokio::test]
    async fn test_lambda_without_parameters() {
        let lambda_expr = proto::expression::Lambda {
            parameters: None,
            body: Some(Box::new(Expression::default())),
        };

        let message = consume_expecting_error(&lambda_expr).await;

        assert_contains!(
            message,
            "Lambda expression without parameters is not allowed"
        );
    }
}

datafusion/substrait/src/logical_plan/consumer/expr/mod.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ mod cast;
2020
mod field_reference;
2121
mod function_arguments;
2222
mod if_then;
23+
mod lambda;
2324
mod literal;
2425
mod nested;
2526
mod scalar_function;
@@ -32,6 +33,7 @@ pub use cast::*;
3233
pub use field_reference::*;
3334
pub use function_arguments::*;
3435
pub use if_then::*;
36+
pub use lambda::*;
3537
pub use literal::*;
3638
pub use nested::*;
3739
pub use scalar_function::*;
@@ -95,8 +97,11 @@ pub async fn from_substrait_rex(
9597
RexType::DynamicParameter(expr) => {
9698
consumer.consume_dynamic_parameter(expr, input_schema).await
9799
}
98-
RexType::Lambda(_) | RexType::LambdaInvocation(_) => {
99-
not_impl_err!("Lambda expressions are not yet supported")
100+
RexType::Lambda(lambda) => {
101+
consumer.consume_lambda(lambda.as_ref(), input_schema).await
102+
}
103+
RexType::LambdaInvocation(_) => {
104+
not_impl_err!("Lambda invocations are not supported")
100105
}
101106
},
102107
None => substrait_err!("Expression must set rex_type: {expression:?}"),

datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ pub async fn from_scalar_function(
3030
f: &ScalarFunction,
3131
input_schema: &DFSchema,
3232
) -> Result<Expr> {
33-
//TODO: handle higher order functions, as they are also encoded as scalar functions
3433
let Some(fn_signature) = consumer
3534
.get_extensions()
3635
.functions
@@ -45,6 +44,20 @@ pub async fn from_scalar_function(
4544
let fn_name = substrait_fun_name(fn_signature);
4645
let args = from_substrait_func_args(consumer, &f.arguments, input_schema).await?;
4746

47+
let higher_order_func = consumer
48+
.get_function_registry()
49+
.higher_order_function(fn_name)
50+
.or_else(|e| {
51+
if let Some(alt_name) = substrait_to_df_name(fn_name) {
52+
consumer
53+
.get_function_registry()
54+
.higher_order_function(alt_name)
55+
.or(Err(e))
56+
} else {
57+
Err(e)
58+
}
59+
});
60+
4861
let udf_func = consumer.get_function_registry().udf(fn_name).or_else(|e| {
4962
if let Some(alt_name) = substrait_to_df_name(fn_name) {
5063
consumer.get_function_registry().udf(alt_name).or(Err(e))
@@ -53,9 +66,14 @@ pub async fn from_scalar_function(
5366
}
5467
});
5568

56-
// try to first match the requested function into registered udfs, then built-in ops
69+
// try to first match the requested function into registered higher-order functions, then udfs, built-in ops
5770
// and finally built-in expressions
58-
if let Ok(func) = udf_func {
71+
if let Ok(func) = higher_order_func {
72+
Ok(Expr::HigherOrderFunction(expr::HigherOrderFunction::new(
73+
func.to_owned(),
74+
args,
75+
)))
76+
} else if let Ok(func) = udf_func {
5977
Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf(
6078
func.to_owned(),
6179
args,

0 commit comments

Comments
 (0)