diff --git a/plugin/trino-teradata/src/main/java/io/trino/plugin/teradata/RewriteSubstring.java b/plugin/trino-teradata/src/main/java/io/trino/plugin/teradata/RewriteSubstring.java new file mode 100644 index 000000000000..34caa256b3b0 --- /dev/null +++ b/plugin/trino-teradata/src/main/java/io/trino/plugin/teradata/RewriteSubstring.java @@ -0,0 +1,95 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.teradata; + +import io.trino.matching.Capture; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.plugin.base.expression.ConnectorExpressionRule; +import io.trino.plugin.jdbc.QueryParameter; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; +import io.trino.spi.expression.Call; +import io.trino.spi.expression.ConnectorExpression; +import io.trino.spi.expression.FunctionName; + +import java.util.List; +import java.util.Optional; +import java.util.stream.Stream; + +import static io.trino.matching.Capture.newCapture; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argument; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argumentCount; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.call; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.expression; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.functionName; +import static java.lang.String.format; +import static java.util.stream.Collectors.toList; + +public class RewriteSubstring + implements ConnectorExpressionRule +{ + private static final Capture VALUE = newCapture(); + private static final Capture START = newCapture(); + + public static final FunctionName SUBSTRING_FUNCTION_NAME = new FunctionName("substring"); + + @Override + public Pattern getPattern() + { + return call() + .with(functionName().equalTo(SUBSTRING_FUNCTION_NAME)) + .with(argumentCount().matching(count -> count == 2 || count == 3)) + .with(argument(0).matching(expression().capturedAs(VALUE))) + .with(argument(1).matching(expression().capturedAs(START))); + } + + @Override + public Optional rewrite(Call call, Captures captures, RewriteContext context) + { + Optional value = context.defaultRewrite(captures.get(VALUE)); + Optional start = context.defaultRewrite(captures.get(START)); + + if (value.isEmpty() || start.isEmpty()) { + return Optional.empty(); + } + + if (call.getArguments().size() == 3) { + Optional length = context.defaultRewrite(call.getArguments().get(2)); + if (length.isEmpty()) { + return Optional.empty(); + } + + return Optional.of(new ParameterizedExpression( + format("SUBSTRING(%s FROM %s FOR %s)", + value.get().expression(), + start.get().expression(), + length.get().expression()), + combineParameters(value.get(), start.get(), length.get()))); + } + else { + return Optional.of(new ParameterizedExpression( + format("SUBSTRING(%s FROM %s)", + value.get().expression(), + start.get().expression()), + combineParameters(value.get(), start.get()))); + } + } + + private List combineParameters(ParameterizedExpression... expressions) + { + return Stream.of(expressions) + .flatMap(expr -> expr.parameters().stream()) + .collect(toList()); + } +} diff --git a/plugin/trino-teradata/src/main/java/io/trino/plugin/teradata/RewriteSubstringFunction.java b/plugin/trino-teradata/src/main/java/io/trino/plugin/teradata/RewriteSubstringFunction.java new file mode 100644 index 000000000000..7f7483475dab --- /dev/null +++ b/plugin/trino-teradata/src/main/java/io/trino/plugin/teradata/RewriteSubstringFunction.java @@ -0,0 +1,135 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.teradata; + +import com.google.common.collect.ImmutableList; +import io.trino.matching.Capture; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.plugin.base.projection.ProjectFunctionRule; +import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.plugin.jdbc.QueryParameter; +import io.trino.plugin.jdbc.expression.ParameterizedExpression; +import io.trino.spi.connector.ConnectorTableHandle; +import io.trino.spi.expression.Call; +import io.trino.spi.expression.ConnectorExpression; +import io.trino.spi.expression.FunctionName; +import io.trino.spi.expression.Variable; +import io.trino.spi.type.VarcharType; + +import java.sql.JDBCType; +import java.util.List; +import java.util.Optional; +import java.util.stream.Stream; + +import static io.trino.matching.Capture.newCapture; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argument; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.argumentCount; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.call; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.expression; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.functionName; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.type; +import static java.lang.String.format; +import static java.util.stream.Collectors.toList; + +public class RewriteSubstringFunction + implements ProjectFunctionRule +{ + private static final Capture VALUE = newCapture(); + + private static final Pattern PATTERN = call() + .with(functionName().equalTo(new FunctionName("substring"))) + .with(type().matching(type -> type instanceof VarcharType)) + .with(argumentCount().matching(count -> count >= 2 && count <= 3)) + .with(argument(0).matching(expression().capturedAs(VALUE).with(type().matching(type -> type instanceof VarcharType)))); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Optional rewrite(ConnectorTableHandle handle, ConnectorExpression projectionExpression, Captures captures, RewriteContext context) + { + Call call = (Call) projectionExpression; + ConnectorExpression valueExpression = captures.get(VALUE); + + // Get JDBC type handle for the value expression + JdbcTypeHandle typeHandle = getTypeHandle(valueExpression, context); + if (typeHandle == null) { + return Optional.empty(); + } + + // Only rewrite for plain VARCHAR JDBC type named "varchar" + if (JDBCType.valueOf(typeHandle.jdbcType()) != JDBCType.VARCHAR || + !typeHandle.jdbcTypeName().map(name -> name.equalsIgnoreCase("varchar")).orElse(false)) { + return Optional.empty(); + } + + Optional value = context.rewriteExpression(valueExpression); + if (value.isEmpty()) { + return Optional.empty(); + } + + String expression; + List parameters; + if (call.getArguments().size() == 2) { + // Two argument SUBSTRING(value, start) + Optional start = context.rewriteExpression(call.getArguments().get(1)); + if (start.isEmpty()) { + return Optional.empty(); + } + expression = format("SUBSTRING(%s FROM %s)", value.get().expression(), start.get().expression()); + parameters = combineParameters(value.get(), start.get()); + } + else if (call.getArguments().size() == 3) { + // Three argument SUBSTRING(value, start, length) + Optional start = context.rewriteExpression(call.getArguments().get(1)); + Optional length = context.rewriteExpression(call.getArguments().get(2)); + if (start.isEmpty() || length.isEmpty()) { + return Optional.empty(); + } + expression = format("SUBSTRING(%s FROM %s FOR %s)", + value.get().expression(), + start.get().expression(), + length.get().expression()); + parameters = combineParameters(value.get(), start.get(), length.get()); + } + else { + return Optional.empty(); + } + + return Optional.of(new JdbcExpression(expression, ImmutableList.copyOf(parameters), typeHandle)); + } + + private JdbcTypeHandle getTypeHandle(ConnectorExpression expression, RewriteContext context) + { + if (expression instanceof Variable variable) { + return ((JdbcColumnHandle) context.getAssignment(variable.getName())).getJdbcTypeHandle(); + } + // For non-variable expressions, we might need to derive the type handle differently + // This is a simplified approach - you might need more sophisticated type handling + return new JdbcTypeHandle(JDBCType.VARCHAR.getVendorTypeNumber(), Optional.of("varchar"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); + } + + private List combineParameters(ParameterizedExpression... expressions) + { + return Stream.of(expressions) + .flatMap(expr -> expr.parameters().stream()) + .collect(toList()); + } +} diff --git a/plugin/trino-teradata/src/main/java/io/trino/plugin/teradata/TeradataClient.java b/plugin/trino-teradata/src/main/java/io/trino/plugin/teradata/TeradataClient.java index 7da284b393a0..4e13de72f095 100644 --- a/plugin/trino-teradata/src/main/java/io/trino/plugin/teradata/TeradataClient.java +++ b/plugin/trino-teradata/src/main/java/io/trino/plugin/teradata/TeradataClient.java @@ -210,8 +210,8 @@ public class TeradataClient extends BaseJdbcClient { - private static final PredicatePushdownController TERADATA_STRING_PUSHDOWN = FULL_PUSHDOWN; + private static final PredicatePushdownController TERADATA_STRING_PUSHDOWN = FULL_PUSHDOWN; private static final long MAX_FALLBACK_NDV = 1_000_000L; private static final double DEFAULT_FALLBACK_FRACTION = 0.1; @@ -230,6 +230,7 @@ public class TeradataClient private ProjectFunctionRewriter projectFunctionRewriter; + @Inject public TeradataClient(BaseJdbcConfig config, TeradataConfig teradataConfig, @@ -747,6 +748,7 @@ private void buildExpressionRewriter() .add(new RewriteIn()) .add(new RewriteLikeWithCaseSensitivity()) .add(new RewriteLikeEscapeWithCaseSensitivity()) + .add(new RewriteSubstring()) .add(new RewriteLower()) .withTypeClass("integer_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint")) .withTypeClass("numeric_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint", "decimal", "real", "double")) diff --git a/plugin/trino-teradata/src/test/java/io/trino/plugin/unit/TestRewriteSubstring.java b/plugin/trino-teradata/src/test/java/io/trino/plugin/unit/TestRewriteSubstring.java new file mode 100644 index 000000000000..57f40b3f1230 --- /dev/null +++ b/plugin/trino-teradata/src/test/java/io/trino/plugin/unit/TestRewriteSubstring.java @@ -0,0 +1,115 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.unit; + +import io.trino.plugin.teradata.RewriteSubstring; +import io.trino.spi.expression.Call; +import io.trino.spi.expression.Constant; +import io.trino.spi.expression.FunctionName; +import io.trino.spi.expression.Variable; +import io.trino.spi.type.IntegerType; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static io.trino.spi.type.VarcharType.VARCHAR; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestRewriteSubstring +{ + private final RewriteSubstring rewriteSubstring = new RewriteSubstring(); + + @Test + public void testPatternMatchesSubstringWithTwoArguments() + { + Variable value = new Variable("test_column", VARCHAR); + Constant start = new Constant(1L, IntegerType.INTEGER); + + Call substringCall = new Call( + VARCHAR, + new FunctionName("substring"), + List.of(value, start)); + boolean matches = rewriteSubstring.getPattern().match(substringCall).findFirst().isPresent(); + assertThat(matches).isTrue(); + } + + @Test + public void testPatternMatchesSubstringWithThreeArguments() + { + Variable value = new Variable("test_column", VARCHAR); + Constant start = new Constant(1L, IntegerType.INTEGER); + Constant length = new Constant(5L, IntegerType.INTEGER); + + Call substringCall = new Call( + VARCHAR, + new FunctionName("substring"), + List.of(value, start, length)); + boolean matches = rewriteSubstring.getPattern().match(substringCall).findFirst().isPresent(); + assertThat(matches).isTrue(); + } + + @Test + public void testPatternDoesNotMatchOtherFunctions() + { + Variable value = new Variable("test_column", VARCHAR); + Constant start = new Constant(1L, IntegerType.INTEGER); + + Call upperCall = new Call( + VARCHAR, + new FunctionName("upper"), + List.of(value, start)); + boolean matches = rewriteSubstring.getPattern().match(upperCall).findFirst().isPresent(); + assertThat(matches).isFalse(); + } + + @Test + public void testPatternDoesNotMatchWithOneArgument() + { + Variable value = new Variable("test_column", VARCHAR); + + Call substringCall = new Call( + VARCHAR, + new FunctionName("substring"), + List.of(value)); + boolean matches = rewriteSubstring.getPattern().match(substringCall).findFirst().isPresent(); + assertThat(matches).isFalse(); + } + + @Test + public void testPatternDoesNotMatchWithFourArguments() + { + Variable value = new Variable("test_column", VARCHAR); + Constant start = new Constant(1L, IntegerType.INTEGER); + Constant length = new Constant(5L, IntegerType.INTEGER); + Constant extra = new Constant(10L, IntegerType.INTEGER); + + Call substringCall = new Call( + VARCHAR, + new FunctionName("substring"), + List.of(value, start, length, extra)); + boolean matches = rewriteSubstring.getPattern().match(substringCall).findFirst().isPresent(); + assertThat(matches).isFalse(); + } + + @Test + public void testPatternDoesNotMatchEmptyArguments() + { + Call substringCall = new Call( + VARCHAR, + new FunctionName("substring"), + List.of()); + boolean matches = rewriteSubstring.getPattern().match(substringCall).findFirst().isPresent(); + assertThat(matches).isFalse(); + } +} diff --git a/plugin/trino-teradata/src/test/java/io/trino/plugin/unit/TestRewriteSubstringFunction.java b/plugin/trino-teradata/src/test/java/io/trino/plugin/unit/TestRewriteSubstringFunction.java new file mode 100644 index 000000000000..71ad5e39505b --- /dev/null +++ b/plugin/trino-teradata/src/test/java/io/trino/plugin/unit/TestRewriteSubstringFunction.java @@ -0,0 +1,129 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.unit; + +import io.trino.plugin.teradata.RewriteSubstringFunction; +import io.trino.spi.expression.Call; +import io.trino.spi.expression.Constant; +import io.trino.spi.expression.FunctionName; +import io.trino.spi.expression.Variable; +import io.trino.spi.type.IntegerType; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static io.trino.spi.type.VarcharType.VARCHAR; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestRewriteSubstringFunction +{ + private final RewriteSubstringFunction rewriteSubstringFunction = new RewriteSubstringFunction(); + + @Test + public void testPatternMatchesSubstringFunctionWithTwoArguments() + { + Variable value = new Variable("test_column", VARCHAR); + Constant start = new Constant(1L, IntegerType.INTEGER); + + Call substringCall = new Call( + VARCHAR, + new FunctionName("substring"), + List.of(value, start)); + boolean matches = rewriteSubstringFunction.getPattern().match(substringCall).findFirst().isPresent(); + assertThat(matches).isTrue(); + } + + @Test + public void testPatternMatchesSubstringFunctionWithThreeArguments() + { + Variable value = new Variable("test_column", VARCHAR); + Constant start = new Constant(1L, IntegerType.INTEGER); + Constant length = new Constant(5L, IntegerType.INTEGER); + + Call substringCall = new Call( + VARCHAR, + new FunctionName("substring"), + List.of(value, start, length)); + boolean matches = rewriteSubstringFunction.getPattern().match(substringCall).findFirst().isPresent(); + assertThat(matches).isTrue(); + } + + @Test + public void testPatternDoesNotMatchNonVarcharFirstArgument() + { + Variable value = new Variable("test_column", IntegerType.INTEGER); + Constant start = new Constant(1L, IntegerType.INTEGER); + + Call substringCall = new Call( + VARCHAR, + new FunctionName("substring"), + List.of(value, start)); + boolean matches = rewriteSubstringFunction.getPattern().match(substringCall).findFirst().isPresent(); + assertThat(matches).isFalse(); + } + + @Test + public void testPatternDoesNotMatchOtherFunctions() + { + Variable value = new Variable("test_column", VARCHAR); + Constant start = new Constant(1L, IntegerType.INTEGER); + + Call lowerCall = new Call( + VARCHAR, + new FunctionName("lower"), + List.of(value, start)); + boolean matches = rewriteSubstringFunction.getPattern().match(lowerCall).findFirst().isPresent(); + assertThat(matches).isFalse(); + } + + @Test + public void testPatternDoesNotMatchWithOneArgument() + { + Variable value = new Variable("test_column", VARCHAR); + + Call substringCall = new Call( + VARCHAR, + new FunctionName("substring"), + List.of(value)); + boolean matches = rewriteSubstringFunction.getPattern().match(substringCall).findFirst().isPresent(); + assertThat(matches).isFalse(); + } + + @Test + public void testPatternDoesNotMatchWithFourArguments() + { + Variable value = new Variable("test_column", VARCHAR); + Constant start = new Constant(1L, IntegerType.INTEGER); + Constant length = new Constant(5L, IntegerType.INTEGER); + Constant extra = new Constant(10L, IntegerType.INTEGER); + + Call substringCall = new Call( + VARCHAR, + new FunctionName("substring"), + List.of(value, start, length, extra)); + boolean matches = rewriteSubstringFunction.getPattern().match(substringCall).findFirst().isPresent(); + assertThat(matches).isFalse(); + } + + @Test + public void testPatternDoesNotMatchEmptyArguments() + { + Call substringCall = new Call( + VARCHAR, + new FunctionName("substring"), + List.of()); + boolean matches = rewriteSubstringFunction.getPattern().match(substringCall).findFirst().isPresent(); + assertThat(matches).isFalse(); + } +}