diff --git a/src/query/functions/src/cast_rules.rs b/src/query/functions/src/cast_rules.rs index d9fb9b2a3f0b..2788e218f706 100644 --- a/src/query/functions/src/cast_rules.rs +++ b/src/query/functions/src/cast_rules.rs @@ -49,7 +49,15 @@ pub fn register(registry: &mut FunctionRegistry) { for func_name in ALL_STRING_FUNC_NAMES { registry.register_additional_cast_rules(func_name, GENERAL_CAST_RULES.iter().cloned()); - registry.register_additional_cast_rules(func_name, CAST_FROM_STRING_RULES.iter().cloned()); + if ["concat", "concat_ws"].contains(func_name) { + registry.register_additional_cast_rules( + func_name, + get_cast_int_to_string_rules().into_iter(), + ) + } else { + registry + .register_additional_cast_rules(func_name, CAST_FROM_STRING_RULES.iter().cloned()); + } registry.register_additional_cast_rules(func_name, CAST_FROM_VARIANT_RULES()); registry.register_additional_cast_rules(func_name, CAST_INT_TO_UINT64.iter().cloned()); } @@ -366,3 +374,10 @@ pub const CAST_INT_TO_UINT64: AutoCastRules = &[ DataType::Number(NumberDataType::UInt64), ), ]; + +pub fn get_cast_int_to_string_rules() -> Vec<(DataType, DataType)> { + ALL_NUMERICS_TYPES + .iter() + .map(|ty| (DataType::Number(*ty), DataType::String)) + .collect() +} diff --git a/src/query/functions/src/scalars/array.rs b/src/query/functions/src/scalars/array.rs index ce91a7a91f01..93c9543c7996 100644 --- a/src/query/functions/src/scalars/array.rs +++ b/src/query/functions/src/scalars/array.rs @@ -876,6 +876,42 @@ pub fn register(registry: &mut FunctionRegistry) { } }), ); + + registry.register_passthrough_nullable_2_arg::>, ArrayType>, ArrayType> , _, _>( + "array_intersection", + |_, _, _| FunctionDomain::MayThrow, + vectorize_with_builder_2_arg::>, ArrayType>, ArrayType> >( + |left, right, output, _| { + let mut set: StackHashSet = StackHashSet::with_capacity(left.len()); + let builder = &mut output.builder; + for val in left.iter() { + if val == ScalarRef::Null { + continue; + } + let mut hasher = SipHasher24::new(); + val.hash(&mut hasher); + let hash128 = hasher.finish128(); + let key = hash128.into(); + let _ = set.set_insert(key); + } + + for val in right.iter() { + if val == ScalarRef::Null { + continue; + } + let mut hasher = SipHasher24::new(); + val.hash(&mut hasher); + let hash128 = hasher.finish128(); + let key = hash128.into(); + + if set.contains(&key) { + builder.push(val); + } + } + output.commit_row() + }, + ), + ); } fn register_array_aggr(registry: &mut FunctionRegistry) { diff --git a/src/query/functions/tests/it/scalars/array.rs b/src/query/functions/tests/it/scalars/array.rs index 28a81fc69984..21aa11869604 100644 --- a/src/query/functions/tests/it/scalars/array.rs +++ b/src/query/functions/tests/it/scalars/array.rs @@ -38,7 +38,7 @@ fn test_array() { test_array_append(file); test_array_indexof(file); test_array_unique(file); - test_array_distinct(file); + test_array_distinct_intersection(file); test_array_sum(file); test_array_avg(file); test_array_count(file); @@ -316,7 +316,7 @@ fn test_array_unique(file: &mut impl Write) { ]); } -fn test_array_distinct(file: &mut impl Write) { +fn test_array_distinct_intersection(file: &mut impl Write) { run_ast(file, "array_distinct([])", &[]); run_ast(file, "array_distinct([1, 1, 2, 2, 3, NULL])", &[]); run_ast( @@ -331,6 +331,12 @@ fn test_array_distinct(file: &mut impl Write) { ("c", Int16Type::from_data(vec![3i16, 1, 3, 4])), ("d", Int16Type::from_data(vec![4i16, 2, 3, 4])), ]); + + run_ast( + file, + "array_intersection(['a', NULL, 'a', 'b', NULL, 'c', 'd'], ['a', 'd'])", + &[], + ); } fn test_array_sum(file: &mut impl Write) { diff --git a/src/query/functions/tests/it/scalars/string.rs b/src/query/functions/tests/it/scalars/string.rs index c4d68a6f371a..6849ec62aa36 100644 --- a/src/query/functions/tests/it/scalars/string.rs +++ b/src/query/functions/tests/it/scalars/string.rs @@ -373,6 +373,7 @@ fn test_trim(file: &mut impl Write) { } fn test_concat(file: &mut impl Write) { + run_ast(file, "concat('5', 3, 4)", &[]); run_ast(file, "concat('5', '3', '4')", &[]); run_ast(file, "concat(NULL, '3', '4')", &[]); run_ast( diff --git a/src/query/functions/tests/it/scalars/testdata/array.txt b/src/query/functions/tests/it/scalars/testdata/array.txt index 6248be64eade..e513b477853b 100644 --- a/src/query/functions/tests/it/scalars/testdata/array.txt +++ b/src/query/functions/tests/it/scalars/testdata/array.txt @@ -1206,6 +1206,15 @@ evaluation (internal): +--------+--------------------------------------------------------------------------------------+ +ast : array_intersection(['a', NULL, 'a', 'b', NULL, 'c', 'd'], ['a', 'd']) +raw expr : array_intersection(array('a', NULL, 'a', 'b', NULL, 'c', 'd'), array('a', 'd')) +checked expr : array_intersection(array(CAST("a" AS String NULL), CAST(NULL AS String NULL), CAST("a" AS String NULL), CAST("b" AS String NULL), CAST(NULL AS String NULL), CAST("c" AS String NULL), CAST("d" AS String NULL)), CAST(array("a", "d") AS Array(String NULL))) +optimized expr : ['a', 'd'] +output type : Array(String NULL) +output domain : [{"a"..="d"}] +output : ['a', 'd'] + + ast : array_sum([]) raw expr : array_sum(array()) checked expr : array_sum(array<>()) diff --git a/src/query/functions/tests/it/scalars/testdata/function_list.txt b/src/query/functions/tests/it/scalars/testdata/function_list.txt index fb4bb55d4216..05679bfee4b5 100644 --- a/src/query/functions/tests/it/scalars/testdata/function_list.txt +++ b/src/query/functions/tests/it/scalars/testdata/function_list.txt @@ -155,6 +155,8 @@ Functions overloads: 0 array_indexof(NULL, NULL) :: NULL 1 array_indexof(Array(T0), T0) :: UInt64 2 array_indexof(Array(T0) NULL, T0 NULL) :: UInt64 NULL +0 array_intersection(Array(T0), Array(T0)) :: Array(T0) +1 array_intersection(Array(T0) NULL, Array(T0) NULL) :: Array(T0) NULL 0 array_kurtosis FACTORY 0 array_max FACTORY 0 array_median FACTORY diff --git a/src/query/functions/tests/it/scalars/testdata/string.txt b/src/query/functions/tests/it/scalars/testdata/string.txt index b44d47547661..90e38951e38f 100644 --- a/src/query/functions/tests/it/scalars/testdata/string.txt +++ b/src/query/functions/tests/it/scalars/testdata/string.txt @@ -1858,6 +1858,15 @@ evaluation (internal): +--------+----------------------------+ +ast : concat('5', 3, 4) +raw expr : concat('5', 3, 4) +checked expr : concat("5", CAST(3_u8 AS String), CAST(4_u8 AS String)) +optimized expr : "534" +output type : String +output domain : {"534"..="534"} +output : '534' + + ast : concat('5', '3', '4') raw expr : concat('5', '3', '4') checked expr : concat("5", "3", "4") diff --git a/tests/sqllogictests/suites/ee/05_ee_ddl/05_0003_ddl_create_add_computed_column.test b/tests/sqllogictests/suites/ee/05_ee_ddl/05_0003_ddl_create_add_computed_column.test index 3198e85afbfd..6af823a11112 100644 --- a/tests/sqllogictests/suites/ee/05_ee_ddl/05_0003_ddl_create_add_computed_column.test +++ b/tests/sqllogictests/suites/ee/05_ee_ddl/05_0003_ddl_create_add_computed_column.test @@ -87,7 +87,7 @@ alter table t2 modify column b drop stored statement ok create table t3(a string, b string as (concat(a, '-')) stored) -statement error 1117 +statement ok alter table t3 modify column a float statement ok