Skip to content

Add function metadata ability to push down struct argument in optimizer #25175

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,13 @@ protected QueryRunner createQueryRunner()
Optional.empty());
}

@Override
protected QueryRunner createExpectedQueryRunner()
throws Exception
{
return getQueryRunner();
}

@Test
public void testMetadataQueryOptimizationWithLimit()
{
Expand Down Expand Up @@ -1366,6 +1373,18 @@ public void testPushdownSubfields()
assertPushdownSubfields("SELECT x.a FROM test_pushdown_struct_subfields WHERE x.a > 10 AND x.b LIKE 'abc%'", "test_pushdown_struct_subfields",
ImmutableMap.of("x", toSubfields("x.a", "x.b")));

assertQuery("SELECT struct.b FROM (SELECT CUSTOM_STRUCT_WITH_PASSTHROUGH(x) AS struct FROM test_pushdown_struct_subfields)");
assertQuery("SELECT struct.b FROM (SELECT CUSTOM_STRUCT_WITHOUT_PASSTHROUGH(x) AS struct FROM test_pushdown_struct_subfields)");

assertPushdownSubfields("SELECT struct.b FROM (SELECT CUSTOM_STRUCT_WITH_PASSTHROUGH(x) AS struct FROM test_pushdown_struct_subfields)", "test_pushdown_struct_subfields",
ImmutableMap.of("x", toSubfields("x.b")));

assertPushdownSubfields("SELECT struct.b FROM (SELECT CUSTOM_STRUCT_WITHOUT_PASSTHROUGH(x) AS struct FROM test_pushdown_struct_subfields)", "test_pushdown_struct_subfields",
ImmutableMap.of());

assertPushdownSubfields("SELECT struct.b FROM (SELECT x AS struct FROM test_pushdown_struct_subfields)", "test_pushdown_struct_subfields",
ImmutableMap.of("x", toSubfields("x.b")));

// Join
assertPlan("SELECT l.orderkey, x.a, mod(x.d.d1, 2) FROM lineitem l, test_pushdown_struct_subfields a WHERE l.linenumber = a.id",
anyTree(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.TypeSignature;
import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor;
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.FunctionDescriptor;
import com.facebook.presto.spi.function.IsNull;
import com.facebook.presto.spi.function.LiteralParameters;
import com.facebook.presto.spi.function.LongVariableConstraint;
Expand Down Expand Up @@ -58,12 +60,16 @@
import static com.facebook.presto.common.function.OperatorType.NOT_EQUAL;
import static com.facebook.presto.common.type.StandardTypes.PARAMETRIC_TYPES;
import static com.facebook.presto.operator.annotations.ImplementationDependency.isImplementationDependencyAnnotation;
import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR;
import static com.facebook.presto.util.Failures.checkCondition;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.lang.reflect.Modifier.isPublic;
import static java.lang.reflect.Modifier.isStatic;
import static java.util.Arrays.asList;
import static java.util.Collections.emptyList;
import static java.util.Collections.emptySet;

public class FunctionsParserHelper
{
Expand Down Expand Up @@ -254,6 +260,30 @@ public static Optional<String> parseDescription(AnnotatedElement base)
return (description == null) ? Optional.empty() : Optional.of(description.value());
}

public static ComplexTypeFunctionDescriptor parseFunctionDescriptor(AnnotatedElement base)
{
FunctionDescriptor descriptor = base.getAnnotation(FunctionDescriptor.class);
if (descriptor == null) {
return ComplexTypeFunctionDescriptor.DEFAULT;
}

int pushdownSubfieldArgIndex = descriptor.pushdownSubfieldArgIndex();
Optional<Integer> descriptorPushdownIndex;
if (pushdownSubfieldArgIndex < 0) {
descriptorPushdownIndex = Optional.empty();
}
else {
descriptorPushdownIndex = Optional.of(pushdownSubfieldArgIndex);
}

return new ComplexTypeFunctionDescriptor(
true,
emptyList(),
Optional.of(emptySet()),
Optional.of(ComplexTypeFunctionDescriptor::allSubfieldsRequired),
descriptorPushdownIndex);
}

public static List<LongVariableConstraint> parseLongVariableConstraints(Method inputFunction)
{
return Stream.of(inputFunction.getAnnotationsByType(Constraint.class))
Expand All @@ -277,4 +307,25 @@ public static Map<String, Class<?>> getDeclaredSpecializedTypeParameters(Method
}
return specializedTypeParameters;
}

public static void checkPushdownSubfieldArgIndex(Method method, Signature signature, Optional<Integer> pushdownSubfieldArgIndex)
{
if (pushdownSubfieldArgIndex.isPresent()) {
Map<String, TypeVariableConstraint> typeConstraintMapping = new HashMap<>();
for (TypeVariableConstraint constraint : signature.getTypeVariableConstraints()) {
typeConstraintMapping.put(constraint.getName(), constraint);
}
checkCondition(signature.getArgumentTypes().size() > pushdownSubfieldArgIndex.get(), FUNCTION_IMPLEMENTATION_ERROR, "Method [%s] has out of range pushdown subfield arg index", method);
String typeVariableName = signature.getArgumentTypes().get(pushdownSubfieldArgIndex.get()).toString();

// The type variable must be directly a ROW type
// or (it is a type alias that is not bounded by a type)
// or (it is a type alias that maps to a row type)
boolean meetsTypeConstraint = (!typeConstraintMapping.containsKey(typeVariableName) && typeVariableName.equals(com.facebook.presto.common.type.StandardTypes.ROW)) ||
(typeConstraintMapping.containsKey(typeVariableName) && typeConstraintMapping.get(typeVariableName).getVariadicBound() == null && !typeConstraintMapping.get(typeVariableName).isNonDecimalNumericRequired()) ||
(typeConstraintMapping.containsKey(typeVariableName) && typeConstraintMapping.get(typeVariableName).getVariadicBound().equals(com.facebook.presto.common.type.StandardTypes.ROW));

checkCondition(meetsTypeConstraint, FUNCTION_IMPLEMENTATION_ERROR, "Method [%s] does not have a struct or row type as pushdown subfield arg", method);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.facebook.presto.operator.ParametricImplementationsGroup;
import com.facebook.presto.operator.scalar.annotations.ParametricScalarImplementation;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor;
import com.facebook.presto.spi.function.Signature;
import com.facebook.presto.spi.function.SqlFunctionVisibility;
import com.google.common.annotations.VisibleForTesting;
Expand Down Expand Up @@ -67,6 +68,12 @@ public boolean isCalledOnNullInput()
return details.isCalledOnNullInput();
}

@Override
public ComplexTypeFunctionDescriptor getComplexTypeFunctionDescriptor()
{
return details.getComplexTypeFunctionDescriptor();
}

@Override
public String getDescription()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package com.facebook.presto.operator.scalar;

import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor;
import com.facebook.presto.spi.function.SqlFunctionVisibility;

import java.util.Optional;
Expand All @@ -23,13 +24,15 @@ public class ScalarHeader
private final SqlFunctionVisibility visibility;
private final boolean deterministic;
private final boolean calledOnNullInput;
private final ComplexTypeFunctionDescriptor complexTypeFunctionDescriptor;

public ScalarHeader(Optional<String> description, SqlFunctionVisibility visibility, boolean deterministic, boolean calledOnNullInput)
public ScalarHeader(Optional<String> description, SqlFunctionVisibility visibility, boolean deterministic, boolean calledOnNullInput, ComplexTypeFunctionDescriptor complexTypeFunctionDescriptor)
{
this.description = description;
this.visibility = visibility;
this.deterministic = deterministic;
this.calledOnNullInput = calledOnNullInput;
this.complexTypeFunctionDescriptor = complexTypeFunctionDescriptor;
}

public Optional<String> getDescription()
Expand All @@ -51,4 +54,9 @@ public boolean isCalledOnNullInput()
{
return calledOnNullInput;
}

public ComplexTypeFunctionDescriptor getComplexTypeFunctionDescriptor()
{
return complexTypeFunctionDescriptor;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.BlockPosition;
import com.facebook.presto.spi.function.CodegenScalarFunction;
import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor;
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.FunctionKind;
import com.facebook.presto.spi.function.IsNull;
Expand All @@ -51,7 +52,9 @@
import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE;
import static com.facebook.presto.metadata.SignatureBinder.applyBoundVariables;
import static com.facebook.presto.operator.annotations.FunctionsParserHelper.checkPushdownSubfieldArgIndex;
import static com.facebook.presto.operator.annotations.FunctionsParserHelper.findPublicStaticMethods;
import static com.facebook.presto.operator.annotations.FunctionsParserHelper.parseFunctionDescriptor;
import static com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.ArgumentProperty.functionTypeArgumentProperty;
import static com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.ArgumentProperty.valueTypeArgumentProperty;
import static com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.NullConvention.RETURN_NULL_ON_NULL;
Expand Down Expand Up @@ -122,6 +125,8 @@ private static SqlScalarFunction createSqlScalarFunction(Method method)
Arrays.stream(method.getParameters()).map(p -> parseTypeSignature(p.getAnnotation(SqlType.class).value())).collect(toImmutableList()),
false);

ComplexTypeFunctionDescriptor descriptor = parseAndCheckFunctionDescriptor(method, signature);

return new SqlScalarFunction(signature)
{
@Override
Expand Down Expand Up @@ -166,6 +171,19 @@ public boolean isCalledOnNullInput()
{
return codegenScalarFunction.calledOnNullInput();
}

@Override
public ComplexTypeFunctionDescriptor getComplexTypeFunctionDescriptor()
{
return descriptor;
}
};
}

private static ComplexTypeFunctionDescriptor parseAndCheckFunctionDescriptor(Method method, Signature signature)
{
ComplexTypeFunctionDescriptor descriptor = parseFunctionDescriptor(method);
checkPushdownSubfieldArgIndex(method, signature, descriptor.getPushdownSubfieldArgIndex());
return descriptor;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.facebook.presto.operator.scalar.ParametricScalar;
import com.facebook.presto.operator.scalar.annotations.ParametricScalarImplementation.SpecializedSignature;
import com.facebook.presto.spi.function.CodegenScalarFunction;
import com.facebook.presto.spi.function.FunctionDescriptor;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.ScalarOperator;
import com.facebook.presto.spi.function.Signature;
Expand All @@ -35,6 +36,7 @@
import java.util.Optional;
import java.util.Set;

import static com.facebook.presto.operator.annotations.FunctionsParserHelper.checkPushdownSubfieldArgIndex;
import static com.facebook.presto.operator.scalar.annotations.OperatorValidator.validateOperator;
import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR;
import static com.facebook.presto.util.Failures.checkCondition;
Expand Down Expand Up @@ -88,7 +90,7 @@ private static List<ScalarHeaderAndMethods> findScalarsInFunctionSetClass(Class<
ImmutableList.Builder<ScalarHeaderAndMethods> builder = ImmutableList.builder();
for (Method method : FunctionsParserHelper.findPublicMethods(
annotated,
ImmutableSet.of(SqlType.class, ScalarFunction.class, ScalarOperator.class),
ImmutableSet.of(SqlType.class, ScalarFunction.class, ScalarOperator.class, FunctionDescriptor.class),
ImmutableSet.of(SqlInvokedScalarFunction.class, CodegenScalarFunction.class))) {
checkCondition((method.getAnnotation(ScalarFunction.class) != null) || (method.getAnnotation(ScalarOperator.class) != null),
FUNCTION_IMPLEMENTATION_ERROR, "Method [%s] annotated with @SqlType is missing @ScalarFunction or @ScalarOperator", method);
Expand All @@ -106,6 +108,7 @@ private static SqlScalarFunction parseParametricScalar(ScalarHeaderAndMethods sc
Map<SpecializedSignature, ParametricScalarImplementation.Builder> signatures = new HashMap<>();
for (Method method : scalar.getMethods()) {
ParametricScalarImplementation implementation = ParametricScalarImplementation.Parser.parseImplementation(header, method, constructor);
checkPushdownSubfieldArgIndex(method, implementation.getSignature(), header.getHeader().getComplexTypeFunctionDescriptor().getPushdownSubfieldArgIndex());
if (!signatures.containsKey(implementation.getSpecializedSignature())) {
ParametricScalarImplementation.Builder builder = new ParametricScalarImplementation.Builder(
implementation.getSignature(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.facebook.presto.common.QualifiedObjectName;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.operator.scalar.ScalarHeader;
import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.ScalarOperator;
import com.facebook.presto.spi.function.SqlFunctionVisibility;
Expand All @@ -28,6 +29,7 @@

import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE;
import static com.facebook.presto.operator.annotations.FunctionsParserHelper.parseDescription;
import static com.facebook.presto.operator.annotations.FunctionsParserHelper.parseFunctionDescriptor;
import static com.facebook.presto.spi.function.SqlFunctionVisibility.HIDDEN;
import static com.google.common.base.CaseFormat.LOWER_CAMEL;
import static com.google.common.base.CaseFormat.LOWER_UNDERSCORE;
Expand Down Expand Up @@ -76,20 +78,21 @@ public static List<ScalarImplementationHeader> fromAnnotatedElement(AnnotatedEle
ScalarFunction scalarFunction = annotated.getAnnotation(ScalarFunction.class);
ScalarOperator scalarOperator = annotated.getAnnotation(ScalarOperator.class);
Optional<String> description = parseDescription(annotated);
ComplexTypeFunctionDescriptor descriptor = parseFunctionDescriptor(annotated);

ImmutableList.Builder<ScalarImplementationHeader> builder = ImmutableList.builder();

if (scalarFunction != null) {
String baseName = scalarFunction.value().isEmpty() ? camelToSnake(annotatedName(annotated)) : scalarFunction.value();
builder.add(new ScalarImplementationHeader(baseName, new ScalarHeader(description, scalarFunction.visibility(), scalarFunction.deterministic(), scalarFunction.calledOnNullInput())));
builder.add(new ScalarImplementationHeader(baseName, new ScalarHeader(description, scalarFunction.visibility(), scalarFunction.deterministic(), scalarFunction.calledOnNullInput(), descriptor)));

for (String alias : scalarFunction.alias()) {
builder.add(new ScalarImplementationHeader(alias, new ScalarHeader(description, scalarFunction.visibility(), scalarFunction.deterministic(), scalarFunction.calledOnNullInput())));
builder.add(new ScalarImplementationHeader(alias, new ScalarHeader(description, scalarFunction.visibility(), scalarFunction.deterministic(), scalarFunction.calledOnNullInput(), descriptor)));
}
}

if (scalarOperator != null) {
builder.add(new ScalarImplementationHeader(scalarOperator.value(), new ScalarHeader(description, HIDDEN, true, scalarOperator.value().isCalledOnNullInput())));
builder.add(new ScalarImplementationHeader(scalarOperator.value(), new ScalarHeader(description, HIDDEN, true, scalarOperator.value().isCalledOnNullInput(), descriptor)));
}

List<ScalarImplementationHeader> result = builder.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,16 @@ private static Optional<Subfield> toSubfield(
if (expression instanceof VariableReferenceExpression) {
return Optional.of(new Subfield(((VariableReferenceExpression) expression).getName(), elements.build().reverse()));
}
if (expression instanceof CallExpression) {
ComplexTypeFunctionDescriptor functionDescriptor = functionAndTypeManager.getFunctionMetadata(((CallExpression) expression).getFunctionHandle()).getDescriptor();
Optional<Integer> pushdownSubfieldArgIndex = functionDescriptor.getPushdownSubfieldArgIndex();
if (pushdownSubfieldArgIndex.isPresent() &&
((CallExpression) expression).getArguments().size() > pushdownSubfieldArgIndex.get() &&
((CallExpression) expression).getArguments().get(pushdownSubfieldArgIndex.get()).getType() instanceof RowType) {
expression = ((CallExpression) expression).getArguments().get(pushdownSubfieldArgIndex.get());
continue;
}
}

if (expression instanceof SpecialFormExpression && ((SpecialFormExpression) expression).getForm() == DEREFERENCE) {
SpecialFormExpression dereference = (SpecialFormExpression) expression;
Expand Down
Loading
Loading