Skip to content

[Feature][transform-v2] sql transform support multi_if function #9154

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions docs/en/transform-v2/sql-functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -959,6 +959,16 @@ Example:

NULLIF(A, B)


### MULTI_IF
```MULTI_IF(condition1, value1, condition2, value2, ... conditionN, valueN, bValue)```

Returns the value paired with the first condition that evaluates to true. If none of the conditions is true, it returns the last argument (`bValue`), which acts as the default value.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


Example:

MULTI_IF(A > 1, 'A', B > 1, 'B', C > 1, 'C', 'D')

### CASE WHEN

```
Expand Down
9 changes: 9 additions & 0 deletions docs/zh/transform-v2/sql-functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -952,6 +952,15 @@ IFNULL(A, B)

NULLIF(A, B)

### MULTI_IF
```MULTI_IF(condition1, value1, condition2, value2, ... conditionN, valueN, bValue)```

返回第一个满足相应条件的值。如果所有条件均为假,则返回最后一个值。

示例:

MULTI_IF(A > 1, 'A', B > 1, 'B', C > 1, 'C', 'D')

### CASE WHEN

```
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ public void testSQLTransform(TestContainer container) throws IOException, Interr
Container.ExecResult maxMinSql =
container.executeJob("/sql_transform/func_array_max_min.conf");
Assertions.assertEquals(0, maxMinSql.getExitCode());

Container.ExecResult multiIfSql = container.executeJob("/sql_transform/func_multi_if.conf");
Assertions.assertEquals(0, multiIfSql.getExitCode());
}

@TestTemplate
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
######
###### This config file is a demonstration of batch processing in seatunnel config
######

# Job-level settings: a single parallel task running as a bounded (BATCH) job.
env {
parallelism = 1
job.mode = "BATCH"
checkpoint.interval = 10000
}

# Input: FakeSource publishes one hand-written INSERT row under the name "fake".
source {
FakeSource {
plugin_output = "fake"
schema = {
fields {
id = "int"
age = "int"
score = "double"
name = "string"
}
}
# Single test row; values presumably map to the schema fields in declaration order
# (id=1, age=15, score=85.5, name="Alice") - confirm against FakeSource docs.
rows = [
{fields = [1, 15, 85.5, "Alice"], kind = INSERT}
]
}
}

# Reads the "fake" rows, adds three MULTI_IF-derived columns and publishes the result as "fake1".
transform {
Sql {
plugin_input = "fake"
plugin_output = "fake1"
# The query exercises MULTI_IF with 3, 4 and 1 condition/value pairs respectively,
# each call ending with a trailing default value.
query = """
SELECT
id,
age,
score,
name,
MULTI_IF(age < 18, 'Minor', age < 30, 'Young Adult', age < 40, 'Adult', 'Senior') as age_category,
MULTI_IF(score >= 90, 'A', score >= 80, 'B', score >= 70, 'C', score >= 60, 'D', 'F') as grade,
MULTI_IF(score >= 90, 'excellent', 'pass') as grade_category
FROM fake
"""
}
}

# Verifies the transformed "fake1" output with the Assert sink.
sink {
Assert {
plugin_input = "fake1"
rules = {
# MIN_ROW and MAX_ROW are both 1, so exactly one output row is expected.
row_rules = [
{
rule_type = "MIN_ROW"
rule_value = 1
},
{
rule_type = "MAX_ROW"
rule_value = 1
}
],
field_rules = [
{
field_name = "id"
field_type = "int"
field_value = [
{equals_to = 1}
]
},
# age = 15 satisfies the first condition (age < 18), so the first value is expected.
{
field_name = "age_category"
field_type = "string"
field_value = [
{equals_to = "Minor"}
]
},
# score = 85.5 fails "score >= 90" but satisfies "score >= 80", so the second value is expected.
{
field_name = "grade"
field_type = "string"
field_value = [
{equals_to = "B"}
]
},
# score = 85.5 fails the only condition (score >= 90), so the trailing default is expected.
{
field_name = "grade_category"
field_type = "string"
field_value = [
{equals_to = "pass"}
]
}
]
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ public class ZetaSQLFunction {
public static final String COALESCE = "COALESCE";
public static final String IFNULL = "IFNULL";
public static final String NULLIF = "NULLIF";
public static final String MULTI_IF = "MULTI_IF";

public static final String UUID = "UUID";

Expand Down Expand Up @@ -313,6 +314,14 @@ public Object computeForValue(Expression expression, Object[] inputFields) {
}
if (expression instanceof Function) {
Function function = (Function) expression;
String functionName = function.getName();

// Special handling for MULTI_IF to properly evaluate comparison expressions
if (MULTI_IF.equalsIgnoreCase(functionName)) {
return multiIfFunction(function, inputFields);
}

// Standard handling for other functions
ExpressionList<Expression> expressionList =
(ExpressionList<Expression>) function.getParameters();
List<Object> functionArgs = new ArrayList<>();
Expand All @@ -321,7 +330,7 @@ public Object computeForValue(Expression expression, Object[] inputFields) {
functionArgs.add(computeForValue(funcArgExpression, inputFields));
}
}
return executeFunctionExpr(function.getName(), functionArgs);
return executeFunctionExpr(functionName, functionArgs);
}
if (expression instanceof TimeKeyExpression) {
return executeTimeKeyExpr(((TimeKeyExpression) expression).getStringValue());
Expand Down Expand Up @@ -897,4 +906,47 @@ public SeaTunnelRowType lateralViewMapping(
}
return new SeaTunnelRowType(fieldNames, seaTunnelDataTypes);
}

/**
 * Evaluates a {@code MULTI_IF(cond1, val1, ..., condN, valN, default)} call against one row.
 *
 * <p>Arguments are read as condition/value pairs followed by one trailing default. Conditions
 * are checked left to right; the value expression of the first condition that yields {@code
 * Boolean.TRUE} is evaluated and returned. Value expressions of other pairs are never
 * evaluated. When no condition matches, the trailing default expression is evaluated instead.
 *
 * @param function the parsed MULTI_IF call whose parameters are evaluated
 * @param inputFields the field values of the current row
 * @return the evaluated value of the selected branch, or of the default argument
 * @throws TransformException if the call has no parameters, fewer than three arguments, or an
 *     even number of arguments
 */
private Object multiIfFunction(Function function, Object[] inputFields) {
    ExpressionList<Expression> parameters =
            (ExpressionList<Expression>) function.getParameters();
    boolean hasNoArguments =
            parameters == null
                    || parameters.getExpressions() == null
                    || parameters.getExpressions().isEmpty();
    if (hasNoArguments) {
        throw new TransformException(
                CommonErrorCodeDeprecated.UNSUPPORTED_OPERATION,
                "MULTI_IF function requires parameters");
    }

    List<Expression> args = parameters.getExpressions();
    int argCount = args.size();
    // Valid shape: one or more (condition, value) pairs plus a default => odd count, >= 3.
    boolean validArity = argCount >= 3 && argCount % 2 != 0;
    if (!validArity) {
        throw new TransformException(
                CommonErrorCodeDeprecated.UNSUPPORTED_OPERATION,
                String.format(
                        "MULTI_IF function requires at least 3 arguments and an odd number of arguments: %s",
                        function));
    }

    int defaultIndex = argCount - 1;
    int conditionIndex = 0;
    while (conditionIndex < defaultIndex) {
        Expression condition = args.get(conditionIndex);

        // Binary expressions that the filter recognises as conditions are delegated to the
        // filter; every other expression goes through the regular value evaluation.
        boolean handledByFilter =
                condition instanceof BinaryExpression
                        && zetaSQLFilter.isConditionExpr(condition);
        Object outcome =
                handledByFilter
                        ? zetaSQLFilter.executeFilter(condition, inputFields)
                        : computeForValue(condition, inputFields);

        // Only a Boolean TRUE selects the branch; null or non-Boolean results are skipped.
        if (Boolean.TRUE.equals(outcome)) {
            return computeForValue(args.get(conditionIndex + 1), inputFields);
        }
        conditionIndex += 2;
    }

    // No condition matched: fall back to the trailing default argument.
    return computeForValue(args.get(defaultIndex), inputFields);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,35 @@ private SeaTunnelDataType<?> getFunctionType(Function function) {
case ZetaSQLFunction.IFNULL:
// Result has the same type as first argument
return getExpressionType(function.getParameters().getExpressions().get(0));
case ZetaSQLFunction.MULTI_IF:
ExpressionList multiIfExpressionList = function.getParameters();
if (multiIfExpressionList == null) {
throw new TransformException(
CommonErrorCodeDeprecated.UNSUPPORTED_OPERATION,
"MULTI_IF function requires parameters");
}

List<Expression> multiIfExpressions = multiIfExpressionList.getExpressions();
if (multiIfExpressions == null || multiIfExpressions.isEmpty()) {
throw new TransformException(
CommonErrorCodeDeprecated.UNSUPPORTED_OPERATION,
"MULTI_IF function requires parameters");
}

if (multiIfExpressions.size() < 3 || multiIfExpressions.size() % 2 == 0) {
throw new TransformException(
CommonErrorCodeDeprecated.UNSUPPORTED_OPERATION,
String.format(
"MULTI_IF function requires at least 3 arguments and an odd number of arguments"));
}

List<SeaTunnelDataType<?>> resultTypes = new ArrayList<>();
for (int i = 1; i < multiIfExpressions.size() - 1; i += 2) {
resultTypes.add(getExpressionType(multiIfExpressions.get(i)));
}
resultTypes.add(
getExpressionType(multiIfExpressions.get(multiIfExpressions.size() - 1)));
return getMaxType(resultTypes);
case ZetaSQLFunction.MOD:
// Result has the same type as second argument
return getExpressionType(function.getParameters().getExpressions().get(1));
Expand Down
Loading