@@ -18,135 +18,181 @@ using namespace duckdb;
1818 */
1919
2020// Collect CAST(bound_column, T) patterns where bound_column binds into given GET's index.
21- void CollectCastTypes (const Expression &expr, idx_t index, const vector<ColumnIndex> &column_ids,
22- unordered_map<column_t , LogicalType> &cast_map, unordered_set<column_t > &conflicts) {
23- auto collect_children = [&] {
24- ExpressionIterator::EnumerateChildren (
25- expr, [&](const Expression &child) { CollectCastTypes (child, index, column_ids, cast_map, conflicts); });
26- };
27-
28- if (expr.GetExpressionClass () != ExpressionClass::BOUND_CAST) {
29- return collect_children ();
30- }
31- auto &bound_cast = expr.Cast <BoundCastExpression>();
32-
33- if (bound_cast.child ->GetExpressionType () != ExpressionType::BOUND_COLUMN_REF) {
34- return collect_children ();
35- }
36- auto &bound_column = bound_cast.child ->Cast <BoundColumnRefExpression>();
37-
38- if (bound_column.depth > 0 || bound_column.binding .table_index != index) {
39- return collect_children ();
40- }
41-
42- // We are in a leaf
43- const column_t projection_id = bound_column.binding .column_index ;
44- if (IsVirtualColumn (projection_id)) {
45- return ;
46- }
47- D_ASSERT (projection_id < column_ids.size ());
48- const column_t column_id = column_ids[projection_id].GetPrimaryIndex ();
49- if (auto it = cast_map.find (column_id); it == cast_map.end ()) {
50- cast_map.emplace (column_id, bound_cast.return_type );
51- } else if (it->second != bound_cast.return_type ) {
52- conflicts.insert (column_id);
53- }
21+ // A bare bound_column ref (outside any CAST) is recorded as a conflict: the column is
22+ // consumed at its original type and its scan type must not change.
23+ static void CollectCastTypes (const Expression &expr,
24+ idx_t index,
25+ const vector<ColumnIndex> &column_ids,
26+ unordered_map<column_t , LogicalType> &cast_map,
27+ unordered_set<column_t > &conflicts) {
28+ auto collect_children = [&] {
29+ ExpressionIterator::EnumerateChildren (expr, [&](const Expression &child) {
30+ CollectCastTypes (child, index, column_ids, cast_map, conflicts);
31+ });
32+ };
33+
34+ // Bare column ref pointing to this GET: the column is used at its original type.
35+ if (expr.GetExpressionClass () == ExpressionClass::BOUND_COLUMN_REF) {
36+ auto &colref = expr.Cast <BoundColumnRefExpression>();
37+ if (colref.depth == 0 && colref.binding .table_index == index) {
38+ const column_t proj_id = colref.binding .column_index ;
39+ if (!IsVirtualColumn (proj_id) && proj_id < column_ids.size ()) {
40+ conflicts.insert (column_ids[proj_id].GetPrimaryIndex ());
41+ }
42+ }
43+ return ;
44+ }
45+
46+ if (expr.GetExpressionClass () != ExpressionClass::BOUND_CAST) {
47+ return collect_children ();
48+ }
49+ auto &bound_cast = expr.Cast <BoundCastExpression>();
50+
51+ if (bound_cast.child ->GetExpressionType () != ExpressionType::BOUND_COLUMN_REF) {
52+ return collect_children ();
53+ }
54+ auto &bound_column = bound_cast.child ->Cast <BoundColumnRefExpression>();
55+
56+ if (bound_column.depth > 0 || bound_column.binding .table_index != index) {
57+ return collect_children ();
58+ }
59+
60+ // We are in a leaf: CAST(colref, T) where colref binds into this GET.
61+ const column_t projection_id = bound_column.binding .column_index ;
62+ if (IsVirtualColumn (projection_id)) {
63+ return ;
64+ }
65+ D_ASSERT (projection_id < column_ids.size ());
66+ const column_t column_id = column_ids[projection_id].GetPrimaryIndex ();
67+ if (auto it = cast_map.find (column_id); it == cast_map.end ()) {
68+ cast_map.emplace (column_id, bound_cast.return_type );
69+ } else if (it->second != bound_cast.return_type ) {
70+ conflicts.insert (column_id);
71+ }
5472}
5573
5674// Replace every CAST(bound_column, T) with a bare bound_column at type T when T
5775// is listed in projection_cast.
58- static void ReplaceCastTypes (unique_ptr<Expression> &expr, idx_t index,
76+ static void ReplaceCastTypes (unique_ptr<Expression> &expr,
77+ idx_t index,
5978 const unordered_map<column_t , LogicalType> &projection_cast) {
60- auto replace_children = [&] {
61- ExpressionIterator::EnumerateChildren (
62- *expr, [&](unique_ptr<Expression> &child) { ReplaceCastTypes (child, index, projection_cast); });
63- };
64-
65- if (expr->GetExpressionClass () != ExpressionClass::BOUND_CAST) {
66- return replace_children ();
67- }
68- auto &bound_cast = expr->Cast <BoundCastExpression>();
69-
70- if (bound_cast.child ->GetExpressionType () != ExpressionType::BOUND_COLUMN_REF) {
71- return replace_children ();
72- }
73- auto &bound_column = bound_cast.child ->Cast <BoundColumnRefExpression>();
74-
75- if (bound_column.depth > 0 || bound_column.binding .table_index != index) {
76- return replace_children ();
77- }
78-
79- const column_t projection_id = bound_column.binding .column_index ;
80- auto it = projection_cast.find (projection_id);
81- if (it == projection_cast.end () || it->second != bound_cast.return_type ) {
82- return replace_children ();
83- }
84-
85- expr = make_uniq<BoundColumnRefExpression>(it->second , bound_column.binding );
79+ auto replace_children = [&] {
80+ ExpressionIterator::EnumerateChildren (*expr, [&](unique_ptr<Expression> &child) {
81+ ReplaceCastTypes (child, index, projection_cast);
82+ });
83+ };
84+
85+ if (expr->GetExpressionClass () != ExpressionClass::BOUND_CAST) {
86+ return replace_children ();
87+ }
88+ auto &bound_cast = expr->Cast <BoundCastExpression>();
89+
90+ if (bound_cast.child ->GetExpressionType () != ExpressionType::BOUND_COLUMN_REF) {
91+ return replace_children ();
92+ }
93+ auto &bound_column = bound_cast.child ->Cast <BoundColumnRefExpression>();
94+
95+ if (bound_column.depth > 0 || bound_column.binding .table_index != index) {
96+ return replace_children ();
97+ }
98+
99+ const column_t projection_id = bound_column.binding .column_index ;
100+ auto it = projection_cast.find (projection_id);
101+ if (it == projection_cast.end () || it->second != bound_cast.return_type ) {
102+ return replace_children ();
103+ }
104+
105+ expr = make_uniq<BoundColumnRefExpression>(it->second , bound_column.binding );
106+ }
107+
108+ // Collect cast-type candidates from every operator in the plan tree.
109+ static void CollectFromPlan (LogicalOperator &op,
110+ idx_t index,
111+ const vector<ColumnIndex> &column_ids,
112+ unordered_map<column_t , LogicalType> &cast_map,
113+ unordered_set<column_t > &conflicts) {
114+ LogicalOperatorVisitor::EnumerateExpressions (op, [&](unique_ptr<Expression> *expr_ptr) {
115+ CollectCastTypes (**expr_ptr, index, column_ids, cast_map, conflicts);
116+ });
117+ for (auto &child : op.children ) {
118+ CollectFromPlan (*child, index, column_ids, cast_map, conflicts);
119+ }
86120}
87121
88- // Walk the plan bottom-up and, for each node whose direct child is a GET that
89- // supports type_pushdown, push every CAST(colref, T) found in that node's
90- // expressions into the GET so the scan produces T directly.
91- unique_ptr<LogicalOperator> TryPushdownCastTypes (ClientContext& context, unique_ptr<LogicalOperator> op) {
92- for (auto &child : op->children ) {
93- child = TryPushdownCastTypes (context, std::move (child));
94- }
95-
96- for (const auto &child : op->children ) {
97- if (child->type != LogicalOperatorType::LOGICAL_GET) {
98- continue ;
99- }
100- auto &get = child->Cast <LogicalGet>();
101- if (!get.function .type_pushdown ) {
102- continue ;
103- }
104-
105- const vector<ColumnIndex> &column_ids = get.GetColumnIds ();
106- const idx_t index = get.table_index ;
107- unordered_map<column_t , LogicalType> cast_map;
108- unordered_set<column_t > conflicts;
109-
110- LogicalOperatorVisitor::EnumerateExpressions (*op, [&](unique_ptr<Expression> *expr_ptr) {
111- CollectCastTypes (**expr_ptr, index, column_ids, cast_map, conflicts);
112- });
113-
114- for (column_t col_id : conflicts) {
115- cast_map.erase (col_id);
116- }
117- if (cast_map.empty ()) {
118- continue ;
119- }
120-
121- get.function .type_pushdown (context, get.bind_data , cast_map);
122- for (const auto &[col_id, new_type] : cast_map) {
123- get.returned_types [col_id] = new_type;
124- }
125-
126- unordered_map<idx_t , LogicalType> proj_to_type;
127- for (idx_t i = 0 ; i < column_ids.size (); i++) {
128- const column_t col_idx = column_ids[i].GetPrimaryIndex ();
129- if (auto it = cast_map.find (col_idx); it != cast_map.end ()) {
130- proj_to_type[i] = it->second ;
131- }
132- }
133-
134- LogicalOperatorVisitor::EnumerateExpressions (
135- *op, [&](unique_ptr<Expression> *expr_ptr) { ReplaceCastTypes (*expr_ptr, get.table_index , proj_to_type); });
136- }
137-
138- return op;
122+ // Replace cast expressions in every operator in the plan tree.
123+ static void
124+ ReplaceInPlan (LogicalOperator &op, idx_t index, const unordered_map<column_t , LogicalType> &proj_to_type) {
125+ LogicalOperatorVisitor::EnumerateExpressions (op, [&](unique_ptr<Expression> *expr_ptr) {
126+ ReplaceCastTypes (*expr_ptr, index, proj_to_type);
127+ });
128+ for (auto &child : op.children ) {
129+ ReplaceInPlan (*child, index, proj_to_type);
130+ }
131+ }
132+
133+ static void FindGetWithTypePushdown (LogicalOperator &op, vector<LogicalGet *> &gets) {
134+ if (op.type == LogicalOperatorType::LOGICAL_GET) {
135+ auto &get = op.Cast <LogicalGet>();
136+ if (get.function .type_pushdown ) {
137+ gets.push_back (&get);
138+ }
139+ }
140+ for (auto &child : op.children ) {
141+ FindGetWithTypePushdown (*child, gets);
142+ }
143+ }
144+
145+ // For each GET that supports type_pushdown, collect CAST(col, T) patterns from
146+ // the *entire* plan. Columns that appear bare (outside any cast) or are cast to
147+ // multiple conflicting types are excluded. The surviving types are pushed into
148+ // the GET's bind_data and returned_types, and the redundant CASTs are stripped
149+ // from all operator expressions throughout the plan.
150+ static unique_ptr<LogicalOperator> TryPushdownCastTypes (ClientContext &context,
151+ unique_ptr<LogicalOperator> plan) {
152+ vector<LogicalGet *> gets;
153+ FindGetWithTypePushdown (*plan, gets);
154+
155+ for (LogicalGet *get : gets) {
156+ const vector<ColumnIndex> &column_ids = get->GetColumnIds ();
157+ const idx_t index = get->table_index ;
158+ unordered_map<column_t , LogicalType> cast_map;
159+ unordered_set<column_t > conflicts;
160+
161+ CollectFromPlan (*plan, index, column_ids, cast_map, conflicts);
162+
163+ for (column_t col_id : conflicts) {
164+ cast_map.erase (col_id);
165+ }
166+ if (cast_map.empty ()) {
167+ continue ;
168+ }
169+
170+ get->function .type_pushdown (context, get->bind_data , cast_map);
171+ for (const auto &[col_id, new_type] : cast_map) {
172+ get->returned_types [col_id] = new_type;
173+ }
174+
175+ unordered_map<idx_t , LogicalType> proj_to_type;
176+ for (idx_t i = 0 ; i < column_ids.size (); i++) {
177+ const column_t col_idx = column_ids[i].GetPrimaryIndex ();
178+ if (auto it = cast_map.find (col_idx); it != cast_map.end ()) {
179+ proj_to_type[i] = it->second ;
180+ }
181+ }
182+
183+ ReplaceInPlan (*plan, index, proj_to_type);
184+ }
185+
186+ return plan;
139187}
140188
141189static void VortexOptimizeFunction (OptimizerExtensionInput &input, unique_ptr<LogicalOperator> &plan) {
142190 plan = TryPushdownCastTypes (input.context , std::move (plan));
143191}
144192
145- class VortexOptimizerExtension final : public OptimizerExtension {
146- public:
147- VortexOptimizerExtension () {
148- optimize_function = VortexOptimizeFunction;
149- }
193+ struct VortexOptimizerExtension final : OptimizerExtension {
194+ VortexOptimizerExtension () : OptimizerExtension(VortexOptimizeFunction, nullptr , {}) {
195+ }
150196};
151197
152198extern " C" duckdb_state duckdb_vx_optimizer_extension_register (duckdb_database ffi_db) {
0 commit comments