|
24 | 24 | #include "duckdb/planner/expression/bound_constant_expression.hpp" |
25 | 25 | #include "duckdb/planner/expression/bound_comparison_expression.hpp" |
26 | 26 | #include "duckdb/planner/filter/conjunction_filter.hpp" |
| 27 | +#include "duckdb/common/types/value_map.hpp" |
27 | 28 |
|
28 | 29 | namespace duckdb { |
29 | 30 |
|
@@ -361,56 +362,137 @@ unique_ptr<GlobalTableFunctionState> DuckIndexScanInitGlobal(ClientContext &cont |
361 | 362 | return std::move(g_state); |
362 | 363 | } |
363 | 364 |
|
364 | | -void ExtractInFilter(unique_ptr<TableFilter> &filter, BoundColumnRefExpression &bound_ref, |
365 | | - unique_ptr<vector<unique_ptr<Expression>>> &filter_expressions) { |
366 | | - // Special-handling of IN filters. |
367 | | - // They are part of a CONJUNCTION_AND. |
368 | | - if (filter->filter_type != TableFilterType::CONJUNCTION_AND) { |
369 | | - return; |
| 365 | +void ExtractExpressionsFromValues(value_set_t &unique_values, BoundColumnRefExpression &bound_ref, |
| 366 | + vector<unique_ptr<Expression>> &expressions) { |
| 367 | + for (const auto &value : unique_values) { |
| 368 | + auto bound_constant = make_uniq<BoundConstantExpression>(value); |
| 369 | + auto filter_expr = make_uniq<BoundComparisonExpression>(ExpressionType::COMPARE_EQUAL, bound_ref.Copy(), |
| 370 | + std::move(bound_constant)); |
| 371 | + expressions.push_back(std::move(filter_expr)); |
370 | 372 | } |
| 373 | +} |
371 | 374 |
|
372 | | - auto &and_filter = filter->Cast<ConjunctionAndFilter>(); |
373 | | - auto &children = and_filter.child_filters; |
374 | | - if (children.empty()) { |
375 | | - return; |
| 375 | +void ExtractIn(InFilter &filter, BoundColumnRefExpression &bound_ref, vector<unique_ptr<Expression>> &expressions) { |
| 376 | + // Eliminate any duplicates. |
| 377 | + value_set_t unique_values; |
| 378 | + for (const auto &value : filter.values) { |
| 379 | + if (unique_values.find(value) == unique_values.end()) { |
| 380 | + unique_values.insert(value); |
| 381 | + } |
376 | 382 | } |
377 | | - if (children[0]->filter_type != TableFilterType::OPTIONAL_FILTER) { |
| 383 | + ExtractExpressionsFromValues(unique_values, bound_ref, expressions); |
| 384 | +} |
| 385 | + |
| 386 | +void ExtractConjunctionAnd(ConjunctionAndFilter &filter, BoundColumnRefExpression &bound_ref, |
| 387 | + vector<unique_ptr<Expression>> &expressions) { |
| 388 | + if (filter.child_filters.empty()) { |
378 | 389 | return; |
379 | 390 | } |
380 | 391 |
|
381 | | - auto &optional_filter = children[0]->Cast<OptionalFilter>(); |
382 | | - auto &child = optional_filter.child_filter; |
383 | | - if (child->filter_type != TableFilterType::IN_FILTER) { |
| 392 | + // Extract the CONSTANT_COMPARISON and IN_FILTER children. |
| 393 | + vector<reference<ConstantFilter>> comparisons; |
| 394 | + vector<reference<InFilter>> in_filters; |
| 395 | + |
| 396 | + for (idx_t i = 0; i < filter.child_filters.size(); i++) { |
| 397 | + if (filter.child_filters[i]->filter_type == TableFilterType::CONSTANT_COMPARISON) { |
| 398 | + auto &comparison = filter.child_filters[i]->Cast<ConstantFilter>(); |
| 399 | + comparisons.push_back(comparison); |
| 400 | + continue; |
| 401 | + } |
| 402 | + |
| 403 | + if (filter.child_filters[i]->filter_type == TableFilterType::OPTIONAL_FILTER) { |
| 404 | + auto &optional_filter = filter.child_filters[i]->Cast<OptionalFilter>(); |
| 405 | + if (!optional_filter.child_filter) { |
| 406 | + return; |
| 407 | + } |
| 408 | + if (optional_filter.child_filter->filter_type != TableFilterType::IN_FILTER) { |
| 409 | + // No support for other optional filter types yet. |
| 410 | + return; |
| 411 | + } |
| 412 | + auto &in_filter = optional_filter.child_filter->Cast<InFilter>(); |
| 413 | + in_filters.push_back(in_filter); |
| 414 | + continue; |
| 415 | + } |
| 416 | + |
| 417 | + // No support for other filter types than CONSTANT_COMPARISON and IN_FILTER in CONJUNCTION_AND yet. |
384 | 418 | return; |
385 | 419 | } |
386 | 420 |
|
387 | | - auto &in_filter = child->Cast<InFilter>(); |
388 | | - if (!in_filter.origin_is_hash_join) { |
| 421 | + // No support for other CONJUNCTION_AND cases yet. |
| 422 | + if (in_filters.empty()) { |
389 | 423 | return; |
390 | 424 | } |
391 | 425 |
|
392 | | - // They are all on the same column, so we can split them. |
393 | | - for (const auto &value : in_filter.values) { |
394 | | - auto bound_constant = make_uniq<BoundConstantExpression>(value); |
395 | | - auto filter_expr = make_uniq<BoundComparisonExpression>(ExpressionType::COMPARE_EQUAL, bound_ref.Copy(), |
396 | | - std::move(bound_constant)); |
397 | | - filter_expressions->push_back(std::move(filter_expr)); |
| 426 | + // Get the combined unique values of the IN filters. |
| 427 | + value_set_t unique_values; |
| 428 | + for (idx_t filter_idx = 0; filter_idx < in_filters.size(); filter_idx++) { |
| 429 | + auto &in_filter = in_filters[filter_idx].get(); |
| 430 | + for (idx_t value_idx = 0; value_idx < in_filter.values.size(); value_idx++) { |
| 431 | + auto &value = in_filter.values[value_idx]; |
| 432 | + if (unique_values.find(value) != unique_values.end()) { |
| 433 | + continue; |
| 434 | + } |
| 435 | + unique_values.insert(value); |
| 436 | + } |
| 437 | + } |
| 438 | + |
| 439 | + // Extract all qualifying values. |
| 440 | + for (auto value_it = unique_values.begin(); value_it != unique_values.end();) { |
| 441 | + bool qualifies = true; |
| 442 | + for (idx_t comp_idx = 0; comp_idx < comparisons.size(); comp_idx++) { |
| 443 | + if (!comparisons[comp_idx].get().Compare(*value_it)) { |
| 444 | + qualifies = false; |
| 445 | + value_it = unique_values.erase(value_it); |
| 446 | + break; |
| 447 | + } |
| 448 | + } |
| 449 | + if (qualifies) { |
| 450 | + value_it++; |
| 451 | + } |
| 452 | + } |
| 453 | + |
| 454 | + ExtractExpressionsFromValues(unique_values, bound_ref, expressions); |
| 455 | +} |
| 456 | + |
| 457 | +void ExtractFilter(TableFilter &filter, BoundColumnRefExpression &bound_ref, |
| 458 | + vector<unique_ptr<Expression>> &expressions) { |
| 459 | + switch (filter.filter_type) { |
| 460 | + case TableFilterType::OPTIONAL_FILTER: { |
| 461 | + auto &optional_filter = filter.Cast<OptionalFilter>(); |
| 462 | + if (!optional_filter.child_filter) { |
| 463 | + return; |
| 464 | + } |
| 465 | + return ExtractFilter(*optional_filter.child_filter, bound_ref, expressions); |
| 466 | + } |
| 467 | + case TableFilterType::IN_FILTER: { |
| 468 | + auto &in_filter = filter.Cast<InFilter>(); |
| 469 | + ExtractIn(in_filter, bound_ref, expressions); |
| 470 | + return; |
| 471 | + } |
| 472 | + case TableFilterType::CONJUNCTION_AND: { |
| 473 | + auto &conjunction_and = filter.Cast<ConjunctionAndFilter>(); |
| 474 | + ExtractConjunctionAnd(conjunction_and, bound_ref, expressions); |
| 475 | + return; |
| 476 | + } |
| 477 | + default: |
| 478 | + return; |
398 | 479 | } |
399 | 480 | } |
400 | 481 |
|
401 | | -unique_ptr<vector<unique_ptr<Expression>>> ExtractFilters(const ColumnDefinition &col, unique_ptr<TableFilter> &filter, |
402 | | - idx_t storage_idx) { |
| 482 | +vector<unique_ptr<Expression>> ExtractFilterExpressions(const ColumnDefinition &col, unique_ptr<TableFilter> &filter, |
| 483 | + idx_t storage_idx) { |
403 | 484 | ColumnBinding binding(0, storage_idx); |
404 | 485 | auto bound_ref = make_uniq<BoundColumnRefExpression>(col.Name(), col.Type(), binding); |
405 | 486 |
|
406 | | - auto filter_expressions = make_uniq<vector<unique_ptr<Expression>>>(); |
407 | | - ExtractInFilter(filter, *bound_ref, filter_expressions); |
| 487 | + vector<unique_ptr<Expression>> expressions; |
| 488 | + ExtractFilter(*filter, *bound_ref, expressions); |
408 | 489 |
|
409 | | - if (filter_expressions->empty()) { |
| 490 | + // Attempt matching the top-level filter to the index expression. |
| 491 | + if (expressions.empty()) { |
410 | 492 | auto filter_expr = filter->ToExpression(*bound_ref); |
411 | | - filter_expressions->push_back(std::move(filter_expr)); |
| 493 | + expressions.push_back(std::move(filter_expr)); |
412 | 494 | } |
413 | | - return filter_expressions; |
| 495 | + return expressions; |
414 | 496 | } |
415 | 497 |
|
416 | 498 | bool TryScanIndex(ART &art, const ColumnList &column_list, TableFunctionInitInput &input, TableFilterSet &filter_set, |
@@ -453,8 +535,8 @@ bool TryScanIndex(ART &art, const ColumnList &column_list, TableFunctionInitInpu |
453 | 535 | return false; |
454 | 536 | } |
455 | 537 |
|
456 | | - auto filter_expressions = ExtractFilters(col, filter->second, storage_index.GetIndex()); |
457 | | - for (const auto &filter_expr : *filter_expressions) { |
| 538 | + auto expressions = ExtractFilterExpressions(col, filter->second, storage_index.GetIndex()); |
| 539 | + for (const auto &filter_expr : expressions) { |
458 | 540 | auto scan_state = art.TryInitializeScan(*index_expr, *filter_expr); |
459 | 541 | if (!scan_state) { |
460 | 542 | return false; |
|
0 commit comments