diff --git a/be/src/vec/aggregate_functions/aggregate_function_window.cpp b/be/src/vec/aggregate_functions/aggregate_function_window.cpp index 9df45611f0fc95..08c70fa2d26c21 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_window.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_window.cpp @@ -44,7 +44,8 @@ AggregateFunctionPtr create_function_lead_lag_first_last(const String& name, // FE have rewrite case first_value(k1,false)--->first_value(k1) // so size is 2, must will be arg_ignore_null_value if (argument_types.size() == 2) { - DCHECK(name == "first_value" || name == "last_value") << "invalid function name: " << name; + DCHECK(name == "first_value" || name == "last_value" || name == "nth_value") + << "invalid function name: " << name; arg_ignore_null_value = true; } @@ -101,6 +102,8 @@ CREATE_WINDOW_FUNCTION_WITH_NAME_AND_DATA(create_aggregate_function_window_first WindowFunctionFirstImpl); CREATE_WINDOW_FUNCTION_WITH_NAME_AND_DATA(create_aggregate_function_window_last, FirstLastData, WindowFunctionLastImpl); +CREATE_WINDOW_FUNCTION_WITH_NAME_AND_DATA(create_aggregate_function_window_nth_value, FirstLastData, + WindowFunctionNthValueImpl); void register_aggregate_function_window_rank(AggregateFunctionSimpleFactory& factory) { factory.register_function("dense_rank", creator_without_type::creator); @@ -118,6 +121,7 @@ void register_aggregate_function_window_lead_lag_first_last( factory.register_function_both("lag", create_aggregate_function_window_lag); factory.register_function_both("first_value", create_aggregate_function_window_first); factory.register_function_both("last_value", create_aggregate_function_window_last); + factory.register_function_both("nth_value", create_aggregate_function_window_nth_value); } } // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_window.h b/be/src/vec/aggregate_functions/aggregate_function_window.h index 9bc05f9f8684dc..73bb44892d984e 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_window.h +++ b/be/src/vec/aggregate_functions/aggregate_function_window.h @@ -592,6 +592,26 @@ struct WindowFunctionLastImpl : Data { static const char* name() { return "last_value"; } }; +template +struct WindowFunctionNthValueImpl : Data { + void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start, + int64_t frame_end, const IColumn** columns) { + DCHECK_LE(frame_start, frame_end); + frame_start = std::max(frame_start, partition_start); + frame_end = std::min(frame_end, partition_end); + int64_t offset = assert_cast(*columns[1]) + .get_data()[0] - + 1; + if (frame_end - frame_start <= offset) { + this->set_is_null(); + return; + } + this->set_value(columns, offset + frame_start); + } + + static const char* name() { return "nth_value"; } +}; + template class WindowFunctionData final : public IAggregateFunctionDataHelper> { diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinWindowFunctions.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinWindowFunctions.java index b35903d29fc4e4..65fd4ba3229ca7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinWindowFunctions.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinWindowFunctions.java @@ -23,6 +23,7 @@ import org.apache.doris.nereids.trees.expressions.functions.window.Lag; import org.apache.doris.nereids.trees.expressions.functions.window.LastValue; import org.apache.doris.nereids.trees.expressions.functions.window.Lead; +import org.apache.doris.nereids.trees.expressions.functions.window.NthValue; import org.apache.doris.nereids.trees.expressions.functions.window.Ntile; import org.apache.doris.nereids.trees.expressions.functions.window.PercentRank; import org.apache.doris.nereids.trees.expressions.functions.window.Rank; @@ -45,6 +46,7 @@ public class BuiltinWindowFunctions implements FunctionHelper { window(LastValue.class, "last_value"), window(Lead.class, "lead"), window(Ntile.class, "ntile"), + window(NthValue.class, "nth_value"), window(PercentRank.class, "percent_rank"), window(Rank.class, "rank"), window(RowNumber.class, "row_number"), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/WindowFunctionChecker.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/WindowFunctionChecker.java index 2da28269fd711c..cbc5061eaccf6a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/WindowFunctionChecker.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/WindowFunctionChecker.java @@ -33,6 +33,7 @@ import org.apache.doris.nereids.trees.expressions.functions.window.Lag; import org.apache.doris.nereids.trees.expressions.functions.window.LastValue; import org.apache.doris.nereids.trees.expressions.functions.window.Lead; +import org.apache.doris.nereids.trees.expressions.functions.window.NthValue; import org.apache.doris.nereids.trees.expressions.functions.window.Ntile; import org.apache.doris.nereids.trees.expressions.functions.window.PercentRank; import org.apache.doris.nereids.trees.expressions.functions.window.Rank; @@ -433,6 +434,12 @@ public Ntile visitNtile(Ntile ntile, Void ctx) { return ntile; } + @Override + public NthValue visitNthValue(NthValue nthValue, Void ctx) { + NthValue.checkSecondParameter(nthValue); + return nthValue; + } + /** * check if the current WindowFrame equals with the required WindowFrame; if current WindowFrame is null, * the requiredFrame should be used as default frame. diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/window/NthValue.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/window/NthValue.java new file mode 100644 index 00000000000000..e00a308e7e57a3 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/window/NthValue.java @@ -0,0 +1,92 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions.functions.window; + +import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.exceptions.AnalysisException; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.coercion.AnyDataType; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** + * class for nth_value(column, offset) + */ +public class NthValue extends WindowFunction + implements AlwaysNullable, ExplicitlyCastableSignature { + + private static final List SIGNATURES = ImmutableList.of( + FunctionSignature.retArgType(0).args(AnyDataType.INSTANCE_WITHOUT_INDEX, BigIntType.INSTANCE) + ); + + public NthValue(Expression child, Expression offset) { + super("nth_value", child, offset); + } + + public NthValue(List children) { + super("nth_value", children); + } + + @Override + public List getSignatures() { + return SIGNATURES; + } + + @Override + public NthValue withChildren(List children) { + Preconditions.checkArgument(children.size() == 2); + return new NthValue(children); + } + + @Override + public R accept(ExpressionVisitor visitor, C context) { + return visitor.visitNthValue(this, context); + } + + @Override + public DataType getDataType() { + return child(0).getDataType(); + } + + /** + * Check the second parameter of NthValue function. + * The second parameter must be a constant positive integer. + */ + public static void checkSecondParameter(NthValue nthValue) { + Preconditions.checkArgument(nthValue.arity() == 2); + Expression offset = nthValue.child(1); + if (offset instanceof Literal) { + if (((Literal) offset).getDouble() <= 0) { + throw new AnalysisException( + "The offset parameter of NthValue must be a constant positive integer: " + offset); + } + } else { + throw new AnalysisException( + "The offset parameter of NthValue must be a constant positive integer: " + offset); + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/WindowFunctionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/WindowFunctionVisitor.java index 90782adc8c685d..546d81d77763f3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/WindowFunctionVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/WindowFunctionVisitor.java @@ -23,6 +23,7 @@ import org.apache.doris.nereids.trees.expressions.functions.window.Lag; import org.apache.doris.nereids.trees.expressions.functions.window.LastValue; import org.apache.doris.nereids.trees.expressions.functions.window.Lead; +import org.apache.doris.nereids.trees.expressions.functions.window.NthValue; import org.apache.doris.nereids.trees.expressions.functions.window.Ntile; import org.apache.doris.nereids.trees.expressions.functions.window.PercentRank; import org.apache.doris.nereids.trees.expressions.functions.window.Rank; @@ -58,6 +59,10 @@ default R visitNtile(Ntile ntile, C context) { return visitWindowFunction(ntile, context); } + default R visitNthValue(NthValue nthValue, C context) { + return visitWindowFunction(nthValue, context); + } + default R visitPercentRank(PercentRank percentRank, C context) { return visitWindowFunction(percentRank, context); } diff --git a/regression-test/data/query_p0/sql_functions/window_functions/test_nthvalue_function.out b/regression-test/data/query_p0/sql_functions/window_functions/test_nthvalue_function.out new file mode 100644 index 00000000000000..7f0970be490c7e --- /dev/null +++ b/regression-test/data/query_p0/sql_functions/window_functions/test_nthvalue_function.out @@ -0,0 +1,94 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !select -- +16 + +-- !select_1 -- +\N \N \N \N +1 1989 1001 \N +2 1986 1001 \N +3 1989 1002 \N +4 1991 3021 \N +5 1985 5014 \N +6 32767 3021 \N +7 -32767 1002 \N +8 255 2147483647 \N +9 1991 -2147483647 \N +10 1991 5014 \N +11 1989 25699 \N +12 32767 -2147483647 \N +13 -32767 2147483647 \N +14 255 103 \N +15 1992 3021 \N + +-- !select_2 -- +\N \N \N \N +1 1989 1001 \N +2 1986 1001 \N +3 1989 1002 \N +4 1991 3021 \N +5 1985 5014 \N +6 32767 3021 \N +7 -32767 1002 \N +8 255 2147483647 \N +9 1991 -2147483647 \N +10 1991 5014 \N +11 1989 25699 \N +12 32767 -2147483647 \N +13 -32767 2147483647 \N +14 255 103 \N +15 1992 3021 \N + +-- !select_3 -- +\N \N \N 13 +-32767 7 1002 13 +-32767 13 2147483647 13 +255 8 2147483647 8 +255 14 103 14 +1985 5 5014 5 +1986 2 1001 2 +1989 1 1001 1 +1989 3 1002 3 +1989 11 25699 11 +1991 4 3021 4 +1991 9 -2147483647 9 +1991 10 5014 10 +1992 15 3021 15 +32767 6 3021 6 +32767 12 -2147483647 12 + +-- !select_4 -- +\N \N \N \N +-2147483647 1991 9 \N +-2147483647 32767 12 \N +103 255 14 \N +1001 1986 2 \N +1001 1989 1 \N +1002 -32767 7 \N +1002 1989 3 \N +3021 1991 4 \N +3021 1992 15 \N +3021 32767 6 \N +5014 1985 5 \N +5014 1991 10 \N +25699 1989 11 \N +2147483647 -32767 13 \N +2147483647 255 8 \N + +-- !select_6 -- +\N \N \N \N +1002 -32767 7 \N +2147483647 -32767 13 \N +103 255 14 \N +1001 1986 2 \N +1002 1989 3 \N +3021 1991 4 \N +5014 1991 10 \N +-2147483647 32767 12 14 +2147483647 255 8 \N +5014 1985 5 \N +1001 1989 1 \N +25699 1989 11 \N +-2147483647 1991 9 \N +3021 1992 15 \N +3021 32767 6 \N + diff --git a/regression-test/suites/query_p0/sql_functions/window_functions/test_nthvalue_function.groovy b/regression-test/suites/query_p0/sql_functions/window_functions/test_nthvalue_function.groovy new file mode 100644 index 00000000000000..ee7dcf007338de --- /dev/null +++ b/regression-test/suites/query_p0/sql_functions/window_functions/test_nthvalue_function.groovy @@ -0,0 +1,82 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +suite("test_nthvalue_function") { + + def dbName = "test_nthvalue_function_db" + sql "DROP DATABASE IF EXISTS ${dbName}" + sql "CREATE DATABASE ${dbName}" + sql "USE $dbName" + + sql "DROP TABLE IF EXISTS test_nthvalue_function" + sql """ + CREATE TABLE IF NOT EXISTS `test_nthvalue_function` ( + `k0` boolean null comment "", + `k1` tinyint(4) null comment "", + `k2` smallint(6) null comment "", + `k3` int(11) null comment "", + `k4` bigint(20) null comment "", + `k5` decimal(10, 6) null comment "", + `k6` char(5) null comment "", + `k10` date null comment "", + `k11` datetime null comment "", + `k7` varchar(20) null comment "", + `k8` double max null comment "", + `k9` float sum null comment "", + `k12` string replace null comment "", + `k13` largeint(40) replace null comment "" + ) engine=olap + DISTRIBUTED BY HASH(`k1`) BUCKETS 5 properties("replication_num" = "1") + """ + + streamLoad { + table "test_nthvalue_function" + db dbName + set 'column_separator', ',' + file "../../baseall.txt" + } + sql "sync" + + qt_select "select count() from test_nthvalue_function;" + + test { + sql "select k1, k2, k3, nth_value(k1,0) over (partition by k1 order by k2) as ntile from test_nthvalue_function order by k1, k2, k3 desc;" + exception "positive" + } + + test { + sql "select k1, k2, k3, nth_value(k1,-1) over (partition by k1 order by k2) as ntile from test_nthvalue_function order by k1, k2, k3 desc;" + exception "positive" + } + + test { + sql "select k1, k2, k3, nth_value(k1,NULL) over (partition by k1 order by k2) as ntile from test_nthvalue_function order by k1, k2, k3 desc;" + exception "positive" + } + + qt_select_1 "select k1, k2, k3, nth_value(k1,3) over (partition by k1 order by k2) from test_nthvalue_function order by k1, k2, k3 desc;" + qt_select_2 "select k1, k2, k3, nth_value(k1,5) over (partition by k1 order by k2) from test_nthvalue_function order by k1, k2, k3 desc;" + qt_select_3 "select k2, k1, k3, nth_value(k1,3) over (order by k2 rows BETWEEN 2 PRECEDING AND 2 following) from test_nthvalue_function order by k2,k1;" + qt_select_4 "select k3, k2, k1, nth_value(k1,3) over (partition by k3 order by k2) from test_nthvalue_function order by k3, k2, k1;" + qt_select_6 "select k3, k2, k1, nth_value(k1,3) over (partition by k6 order by k2 rows between 10 preceding and 5 preceding) as res from test_nthvalue_function order by k6, k2, k1,res;" + + +} + + + + +