Skip to content

Commit d1a58be

Browse files
Orri Erlingmeta-codesync[bot]
authored andcommitted
Whole function subfield tracking
Differential Revision: D89635456
1 parent 5f87d83 commit d1a58be

File tree

8 files changed

+281
-33
lines changed

8 files changed

+281
-33
lines changed

axiom/optimizer/FunctionRegistry.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,8 @@ class Call;
121121
struct FunctionMetadata {
122122
bool processSubfields() const {
123123
return subfieldArg.has_value() || !fieldIndexForArg.empty() ||
124-
isArrayConstructor || isMapConstructor || valuePathToArgPath;
124+
isArrayConstructor || isMapConstructor || valuePathToArgPath ||
125+
explode || expandFunction || !lambdas.empty();
125126
}
126127

127128
const LambdaInfo* lambdaInfo(int32_t index) const {
@@ -186,6 +187,16 @@ struct FunctionMetadata {
186187
std::vector<PathCP>& paths)>
187188
explode;
188189

190+
/// Hook for rewriting a call to a function. In the case of a
191+
/// complex type function with subfield related metadata,
192+
/// 'logicalExplode' is used if there are only getters over the
193+
/// function. For functions with subfield related metadata options,
194+
/// 'expandFunction is used only if the function is accessed as a
195+
/// whole. If returns non-nullptr, the returned expression is used
196+
/// in the place of the function.
197+
std::function<logical_plan::ExprPtr(const logical_plan::CallExpr*)>
198+
expandFunction;
199+
189200
/// Function to compute derived constraints for function calls.
190201
std::function<std::optional<Value>(ExprCP, PlanState& state)>
191202
functionConstraint;

axiom/optimizer/SubfieldTracker.cpp

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,25 @@ namespace lp = facebook::axiom::logical_plan;
2424

2525
namespace facebook::axiom::optimizer {
2626

27+
lp::ExprPtr ExpandFunctionCache::getOrExpand(
28+
const lp::Expr* expr,
29+
const std::function<lp::ExprPtr(const lp::CallExpr*)>& expandFunction) {
30+
auto it = cache_.find(expr);
31+
if (it != cache_.end()) {
32+
return it->second;
33+
}
34+
35+
auto* call = dynamic_cast<const lp::CallExpr*>(expr);
36+
VELOX_CHECK_NOT_NULL(call, "Expression must be a CallExpr");
37+
38+
auto result = expandFunction(call);
39+
// Only cache non-null results to avoid caching failures
40+
if (result) {
41+
cache_[expr] = result;
42+
}
43+
return result;
44+
}
45+
2746
SubfieldTracker::SubfieldTracker(
2847
std::function<logical_plan::ConstantExprPtr(const logical_plan::ExprPtr&)>
2948
tryFoldConstant)
@@ -428,7 +447,17 @@ void SubfieldTracker::markSubfields(
428447
// If the function is some kind of constructor, like
429448
// make_row_from_map or make_named_row, then a path over it
430449
// selects one argument. If there is no path, all arguments are
431-
// implicitly accessed.
450+
// implicitly accessed. If the whole value of make_row_from_map is accessed,
451+
// this does not yet access the whole map but just the keys listed in
452+
// make_row_from_map.
453+
if (metadata->expandFunction && steps.empty()) {
454+
auto newExpr =
455+
expandFunctionCache_.getOrExpand(call, metadata->expandFunction);
456+
if (newExpr) {
457+
markSubfields(newExpr, steps, isControl, context);
458+
return;
459+
}
460+
}
432461
if (metadata->valuePathToArgPath && !steps.empty()) {
433462
auto pair = metadata->valuePathToArgPath(steps, *call);
434463
markSubfields(expr->inputAt(pair.second), pair.first, isControl, context);

axiom/optimizer/SubfieldTracker.h

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,42 @@
2121

2222
namespace facebook::axiom::optimizer {
2323

24+
/// Cache for expandFunction results. Maps from lp::Expr* to expanded ExprPtr.
25+
/// Uses pointer-based hashing and comparison.
26+
struct ExpandFunctionCache {
27+
/// Checks if the expression is in the cache and returns the cached result.
28+
/// If not in cache, calls expandFunction, caches the result, and returns it.
29+
/// @param expr The expression to expand (must be a CallExpr)
30+
/// @param expandFunction The expansion function to call if not cached
31+
/// @return The expanded expression, or nullptr if expansion returns nullptr
32+
logical_plan::ExprPtr getOrExpand(
33+
const logical_plan::Expr* expr,
34+
const std::function<logical_plan::ExprPtr(const logical_plan::CallExpr*)>&
35+
expandFunction);
36+
37+
private:
38+
struct PointerHash {
39+
size_t operator()(const logical_plan::Expr* ptr) const {
40+
return std::hash<const logical_plan::Expr*>()(ptr);
41+
}
42+
};
43+
44+
struct PointerEqual {
45+
bool operator()(
46+
const logical_plan::Expr* lhs,
47+
const logical_plan::Expr* rhs) const {
48+
return lhs == rhs;
49+
}
50+
};
51+
52+
folly::F14FastMap<
53+
const logical_plan::Expr*,
54+
logical_plan::ExprPtr,
55+
PointerHash,
56+
PointerEqual>
57+
cache_;
58+
};
59+
2460
/// Set of accessed subfields given ordinal of output column or function
2561
/// argument.
2662
struct ResultAccess {
@@ -74,10 +110,14 @@ class SubfieldTracker {
74110

75111
/// Goes over the local plan and collects all accessed columns and subfields.
76112
/// Reports 'control' and 'payload' columns and subfields separately.
77-
std::pair<PlanSubfields, PlanSubfields> markAll(
113+
/// Returns the subfields and the expandFunction cache that can be reused.
114+
std::tuple<PlanSubfields, PlanSubfields, ExpandFunctionCache> markAll(
78115
const logical_plan::LogicalPlanNode& node) && {
79116
markAllSubfields(node, {});
80-
return {controlSubfields_, payloadSubfields_};
117+
return {
118+
std::move(controlSubfields_),
119+
std::move(payloadSubfields_),
120+
std::move(expandFunctionCache_)};
81121
}
82122

83123
// if 'step' applied to result of the function of 'metadata'
@@ -147,6 +187,8 @@ class SubfieldTracker {
147187

148188
PlanSubfields controlSubfields_;
149189
PlanSubfields payloadSubfields_;
190+
191+
ExpandFunctionCache expandFunctionCache_;
150192
};
151193

152194
} // namespace facebook::axiom::optimizer

axiom/optimizer/ToGraph.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,9 +1024,10 @@ ExprCP ToGraph::translateExpr(const lp::ExprPtr& expr) {
10241024

10251025
const auto* call = expr->isCall() ? expr->as<lp::CallExpr>() : nullptr;
10261026
std::string callName;
1027+
const FunctionMetadata* metadata = nullptr;
10271028
if (call) {
10281029
callName = velox::exec::sanitizeName(call->name());
1029-
auto* metadata = functionMetadata(callName);
1030+
metadata = functionMetadata(callName);
10301031
if (metadata && metadata->processSubfields()) {
10311032
auto translated = translateSubfieldFunction(call, metadata);
10321033
if (translated.has_value()) {
@@ -1035,6 +1036,14 @@ ExprCP ToGraph::translateExpr(const lp::ExprPtr& expr) {
10351036
}
10361037
}
10371038

1039+
if (metadata && metadata->expandFunction) {
1040+
auto newExpr =
1041+
expandFunctionCache_.getOrExpand(call, metadata->expandFunction);
1042+
if (newExpr) {
1043+
return translateExpr(newExpr);
1044+
}
1045+
}
1046+
10381047
const auto* specialForm =
10391048
expr->isSpecialForm() ? expr->as<lp::SpecialFormExpr>() : nullptr;
10401049

@@ -2492,7 +2501,7 @@ void ToGraph::translateUnion(const lp::SetNode& set) {
24922501
}
24932502

24942503
DerivedTableP ToGraph::makeQueryGraph(const lp::LogicalPlanNode& logicalPlan) {
2495-
std::tie(controlSubfields_, payloadSubfields_) =
2504+
std::tie(controlSubfields_, payloadSubfields_, expandFunctionCache_) =
24962505
SubfieldTracker([&](const auto& expr) {
24972506
return tryFoldConstant(expr);
24982507
}).markAll(logicalPlan);

axiom/optimizer/ToGraph.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,9 @@ class ToGraph {
419419
// Column and subfield info for items that only affect column values.
420420
PlanSubfields payloadSubfields_;
421421

422+
// Cache for expandFunction results, populated during SubfieldTracker::markAll
423+
ExpandFunctionCache expandFunctionCache_;
424+
422425
/// Expressions corresponding to skyline paths over a subfield decomposable
423426
/// function.
424427
folly::F14FastMap<const logical_plan::CallExpr*, SubfieldProjections>

axiom/optimizer/tests/Genies.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,46 @@ VELOX_DECLARE_VECTOR_FUNCTION_WITH_METADATA(
6161
exec::VectorFunctionMetadataBuilder().defaultNullBehavior(false).build(),
6262
std::make_unique<GenieFunction>());
6363

64+
class IdentityFunction : public exec::VectorFunction {
65+
public:
66+
void apply(
67+
const SelectivityVector& rows,
68+
std::vector<VectorPtr>& args,
69+
const TypePtr& /* outputType */,
70+
exec::EvalCtx& context,
71+
VectorPtr& result) const override {
72+
VELOX_CHECK_EQ(args.size(), 1);
73+
74+
auto& input = args[0];
75+
if (!result) {
76+
// If no result vector provided, just return the input vector directly
77+
result = input;
78+
} else {
79+
// If result vector is provided, copy input to result
80+
BaseVector::ensureWritable(rows, input->type(), context.pool(), result);
81+
result->copy(input.get(), rows, nullptr);
82+
}
83+
}
84+
85+
static std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
86+
// T -> T (accepts any type and returns the same type)
87+
return {exec::FunctionSignatureBuilder()
88+
.typeVariable("T")
89+
.returnType("T")
90+
.argumentType("T")
91+
.build()};
92+
}
93+
};
94+
95+
VELOX_DECLARE_VECTOR_FUNCTION(
96+
udf_identity,
97+
IdentityFunction::signatures(),
98+
std::make_unique<IdentityFunction>());
99+
64100
void registerGenieUdfs() {
65101
VELOX_REGISTER_VECTOR_FUNCTION(udf_genie, "genie");
66102
VELOX_REGISTER_VECTOR_FUNCTION(udf_genie, "exploding_genie");
103+
VELOX_REGISTER_VECTOR_FUNCTION(udf_identity, "identity");
67104
}
68105

69106
} // namespace facebook::axiom::optimizer::test

axiom/optimizer/tests/SubfieldTest.cpp

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,10 @@ TEST_P(SubfieldTest, overAggregation) {
674674
}
675675

676676
TEST_P(SubfieldTest, blackbox) {
677+
registerGenieUdfs();
678+
if (GetParam() == 3) {
679+
optimizerOptions_.allMapsAsStruct = true;
680+
}
677681
auto data = makeRowVector(
678682
{"id", "m"},
679683
{
@@ -686,7 +690,8 @@ TEST_P(SubfieldTest, blackbox) {
686690

687691
lp::PlanBuilder::Context ctx(kHiveConnectorId);
688692
ctx.hook = [](const auto& name, const auto& args) -> lp::ExprPtr {
689-
if (name == "map_row_from_map") {
693+
if (name == "map_row_from_map" || name == "make_row_from_map" ||
694+
name == "padded_make_row_from_map") {
690695
VELOX_CHECK(args.at(2)->isConstant());
691696
auto names = args.at(2)
692697
->template as<lp::ConstantExpr>()
@@ -723,6 +728,82 @@ TEST_P(SubfieldTest, blackbox) {
723728
.build();
724729

725730
ASSERT_NO_THROW(toSingleNodePlan(logicalPlan));
731+
732+
logicalPlan =
733+
lp::PlanBuilder(ctx)
734+
.tableScan("t")
735+
.project(
736+
{"make_row_from_map(m, array[1, 2, 3], array['f1', 'f2', 'f3']) as m"})
737+
.build();
738+
739+
auto plan = toSingleNodePlan(logicalPlan);
740+
741+
verifyRequiredSubfields(
742+
plan, {{"m", {subfield("1"), subfield("2"), subfield("3")}}});
743+
744+
if (GetParam() == 1) {
745+
auto matcher =
746+
core::PlanMatcherBuilder()
747+
.tableScan()
748+
.project(
749+
{"row_constructor(subscript(m_4,1),subscript(m_4,2),subscript(m_4,3))"})
750+
.build();
751+
752+
ASSERT_TRUE(matcher->match(plan));
753+
} else {
754+
auto matcher =
755+
core::PlanMatcherBuilder().tableScan().project().project().build();
756+
757+
ASSERT_TRUE(matcher->match(plan));
758+
}
759+
760+
logicalPlan =
761+
lp::PlanBuilder(ctx, true)
762+
.tableScan("t")
763+
.project(
764+
{"identity(make_row_from_map(m, array[1, 2, 3], array['f1', 'f2', 'f3'])) as m"})
765+
.project(
766+
{"if (m.f1 < 0, ceil(m.f1), floor(m.f1)) as f1b",
767+
"if (m.f2 < 0, 1 + floor(m.f2), ceil(m.f2) + 1) as f2b"})
768+
.project({"f1b + 1 as f1b1", "f1b * 2 as f1b2b", "f2b * 3 as f2b3"})
769+
.build();
770+
771+
// Enable parallel project for the remainder of this test. This is reset in
772+
// SetUp().
773+
optimizerOptions_.parallelProjectWidth = 2;
774+
plan = toSingleNodePlan(logicalPlan);
775+
776+
verifyRequiredSubfields(
777+
plan, {{"m", {subfield("1"), subfield("2"), subfield("3")}}});
778+
779+
if (GetParam() == 1) {
780+
auto matcher =
781+
core::PlanMatcherBuilder()
782+
.tableScan()
783+
.parallelProject()
784+
.parallelProject(
785+
{"identity(row_constructor(subscript(m_7,1),subscript(m_7,2),subscript(m_7,3)))"})
786+
.parallelProject()
787+
.parallelProject()
788+
.project()
789+
.build();
790+
791+
ASSERT_TRUE(matcher->match(plan));
792+
}
793+
794+
logicalPlan =
795+
lp::PlanBuilder(ctx, true)
796+
.tableScan("t")
797+
.project(
798+
{"make_row_from_map(m, array[1, 2, 3], array['f1', 'f2', 'f3']) as m"})
799+
.project(
800+
{"if (coalesce(m.f1, 1::REAL) < 0, ceil(coalesce(m.f1, 1::REAL)), floor(coalesce(m.f1, 1::REAL))) as f1b",
801+
"if (coalesce(m.f2, 1::REAL) < 0, 1 + floor(coalesce(m.f2, 2::real)), ceil(coalesce(m.f2, 3::REAL)) + 1) as f2b"})
802+
.project({"f1b + 1 as f1b1", "f1b * 2 as f1b2b", "f2b * 3 as f2b3"})
803+
.build();
804+
805+
plan = toSingleNodePlan(logicalPlan);
806+
std::cout << plan->toString(true, true);
726807
}
727808

728809
VELOX_INSTANTIATE_TEST_SUITE_P(

0 commit comments

Comments
 (0)