Skip to content

Commit 6d10a20

Browse files
authored
chore(query): auto cast int to string for concat functions (#17891)
* chore(query): auto cast int to string for concat functions * chore(query): auto cast int to string for concat functions * chore(query): auto cast int to string for concat functions
1 parent 2b6aca0 commit 6d10a20

File tree

8 files changed, with 82 additions and 4 deletions.

src/query/functions/src/cast_rules.rs

+16-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,15 @@ pub fn register(registry: &mut FunctionRegistry) {
4949

5050
for func_name in ALL_STRING_FUNC_NAMES {
5151
registry.register_additional_cast_rules(func_name, GENERAL_CAST_RULES.iter().cloned());
52-
registry.register_additional_cast_rules(func_name, CAST_FROM_STRING_RULES.iter().cloned());
52+
if ["concat", "concat_ws"].contains(func_name) {
53+
registry.register_additional_cast_rules(
54+
func_name,
55+
get_cast_int_to_string_rules().into_iter(),
56+
)
57+
} else {
58+
registry
59+
.register_additional_cast_rules(func_name, CAST_FROM_STRING_RULES.iter().cloned());
60+
}
5361
registry.register_additional_cast_rules(func_name, CAST_FROM_VARIANT_RULES());
5462
registry.register_additional_cast_rules(func_name, CAST_INT_TO_UINT64.iter().cloned());
5563
}
@@ -366,3 +374,10 @@ pub const CAST_INT_TO_UINT64: AutoCastRules = &[
366374
DataType::Number(NumberDataType::UInt64),
367375
),
368376
];
377+
378+
pub fn get_cast_int_to_string_rules() -> Vec<(DataType, DataType)> {
379+
ALL_NUMERICS_TYPES
380+
.iter()
381+
.map(|ty| (DataType::Number(*ty), DataType::String))
382+
.collect()
383+
}

src/query/functions/src/scalars/array.rs

+36
Original file line numberDiff line numberDiff line change
@@ -876,6 +876,42 @@ pub fn register(registry: &mut FunctionRegistry) {
876876
}
877877
}),
878878
);
879+
880+
registry.register_passthrough_nullable_2_arg::<ArrayType<GenericType<0>>, ArrayType<GenericType<0>>, ArrayType<GenericType<0>> , _, _>(
881+
"array_intersection",
882+
|_, _, _| FunctionDomain::MayThrow,
883+
vectorize_with_builder_2_arg::<ArrayType<GenericType<0>>, ArrayType<GenericType<0>>, ArrayType<GenericType<0>> >(
884+
|left, right, output, _| {
885+
let mut set: StackHashSet<u128, 16> = StackHashSet::with_capacity(left.len());
886+
let builder = &mut output.builder;
887+
for val in left.iter() {
888+
if val == ScalarRef::Null {
889+
continue;
890+
}
891+
let mut hasher = SipHasher24::new();
892+
val.hash(&mut hasher);
893+
let hash128 = hasher.finish128();
894+
let key = hash128.into();
895+
let _ = set.set_insert(key);
896+
}
897+
898+
for val in right.iter() {
899+
if val == ScalarRef::Null {
900+
continue;
901+
}
902+
let mut hasher = SipHasher24::new();
903+
val.hash(&mut hasher);
904+
let hash128 = hasher.finish128();
905+
let key = hash128.into();
906+
907+
if set.contains(&key) {
908+
builder.push(val);
909+
}
910+
}
911+
output.commit_row()
912+
},
913+
),
914+
);
879915
}
880916

881917
fn register_array_aggr(registry: &mut FunctionRegistry) {

src/query/functions/tests/it/scalars/array.rs

+8-2
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ fn test_array() {
3838
test_array_append(file);
3939
test_array_indexof(file);
4040
test_array_unique(file);
41-
test_array_distinct(file);
41+
test_array_distinct_intersection(file);
4242
test_array_sum(file);
4343
test_array_avg(file);
4444
test_array_count(file);
@@ -316,7 +316,7 @@ fn test_array_unique(file: &mut impl Write) {
316316
]);
317317
}
318318

319-
fn test_array_distinct(file: &mut impl Write) {
319+
fn test_array_distinct_intersection(file: &mut impl Write) {
320320
run_ast(file, "array_distinct([])", &[]);
321321
run_ast(file, "array_distinct([1, 1, 2, 2, 3, NULL])", &[]);
322322
run_ast(
@@ -331,6 +331,12 @@ fn test_array_distinct(file: &mut impl Write) {
331331
("c", Int16Type::from_data(vec![3i16, 1, 3, 4])),
332332
("d", Int16Type::from_data(vec![4i16, 2, 3, 4])),
333333
]);
334+
335+
run_ast(
336+
file,
337+
"array_intersection(['a', NULL, 'a', 'b', NULL, 'c', 'd'], ['a', 'd'])",
338+
&[],
339+
);
334340
}
335341

336342
fn test_array_sum(file: &mut impl Write) {

src/query/functions/tests/it/scalars/string.rs

+1
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,7 @@ fn test_trim(file: &mut impl Write) {
373373
}
374374

375375
fn test_concat(file: &mut impl Write) {
376+
run_ast(file, "concat('5', 3, 4)", &[]);
376377
run_ast(file, "concat('5', '3', '4')", &[]);
377378
run_ast(file, "concat(NULL, '3', '4')", &[]);
378379
run_ast(

src/query/functions/tests/it/scalars/testdata/array.txt

+9
Original file line numberDiff line numberDiff line change
@@ -1206,6 +1206,15 @@ evaluation (internal):
12061206
+--------+--------------------------------------------------------------------------------------+
12071207

12081208

1209+
ast : array_intersection(['a', NULL, 'a', 'b', NULL, 'c', 'd'], ['a', 'd'])
1210+
raw expr : array_intersection(array('a', NULL, 'a', 'b', NULL, 'c', 'd'), array('a', 'd'))
1211+
checked expr : array_intersection<T0=String NULL><Array(T0), Array(T0)>(array<T0=String NULL><T0, T0, T0, T0, T0, T0, T0>(CAST<String>("a" AS String NULL), CAST<NULL>(NULL AS String NULL), CAST<String>("a" AS String NULL), CAST<String>("b" AS String NULL), CAST<NULL>(NULL AS String NULL), CAST<String>("c" AS String NULL), CAST<String>("d" AS String NULL)), CAST<Array(String)>(array<T0=String><T0, T0>("a", "d") AS Array(String NULL)))
1212+
optimized expr : ['a', 'd']
1213+
output type : Array(String NULL)
1214+
output domain : [{"a"..="d"}]
1215+
output : ['a', 'd']
1216+
1217+
12091218
ast : array_sum([])
12101219
raw expr : array_sum(array())
12111220
checked expr : array_sum<Array(Nothing)>(array<>())

src/query/functions/tests/it/scalars/testdata/function_list.txt

+2
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,8 @@ Functions overloads:
155155
0 array_indexof(NULL, NULL) :: NULL
156156
1 array_indexof(Array(T0), T0) :: UInt64
157157
2 array_indexof(Array(T0) NULL, T0 NULL) :: UInt64 NULL
158+
0 array_intersection(Array(T0), Array(T0)) :: Array(T0)
159+
1 array_intersection(Array(T0) NULL, Array(T0) NULL) :: Array(T0) NULL
158160
0 array_kurtosis FACTORY
159161
0 array_max FACTORY
160162
0 array_median FACTORY

src/query/functions/tests/it/scalars/testdata/string.txt

+9
Original file line numberDiff line numberDiff line change
@@ -1858,6 +1858,15 @@ evaluation (internal):
18581858
+--------+----------------------------+
18591859

18601860

1861+
ast : concat('5', 3, 4)
1862+
raw expr : concat('5', 3, 4)
1863+
checked expr : concat<String, String, String>("5", CAST<UInt8>(3_u8 AS String), CAST<UInt8>(4_u8 AS String))
1864+
optimized expr : "534"
1865+
output type : String
1866+
output domain : {"534"..="534"}
1867+
output : '534'
1868+
1869+
18611870
ast : concat('5', '3', '4')
18621871
raw expr : concat('5', '3', '4')
18631872
checked expr : concat<String, String, String>("5", "3", "4")

tests/sqllogictests/suites/ee/05_ee_ddl/05_0003_ddl_create_add_computed_column.test

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ alter table t2 modify column b drop stored
8787
statement ok
8888
create table t3(a string, b string as (concat(a, '-')) stored)
8989

90-
statement error 1117
90+
statement ok
9191
alter table t3 modify column a float
9292

9393
statement ok

0 commit comments

Comments
 (0)