Skip to content

Push the runtime filter to table am. #1032

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
365 changes: 365 additions & 0 deletions contrib/pax_storage/expected/runtime_filter.out

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions contrib/pax_storage/pax_schedule
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@ test: dictionary_encoding

test: cluster
test: db_size_functions
test: runtime_filter

test: teardown
86 changes: 86 additions & 0 deletions contrib/pax_storage/sql/runtime_filter.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
SET optimizer TO on;

-- Test Suit 1: runtime filter main case
DROP TABLE IF EXISTS fact_rf, dim_rf;
CREATE TABLE fact_rf (fid int, did int, val int) using pax WITH(minmax_columns='fid,did,val');
CREATE TABLE dim_rf (did int, proj_id int, filter_val int) using pax WITH(minmax_columns='did,proj_id,filter_val');

-- Generating data, fact_rd.did and dim_rf.did is 80% matched
INSERT INTO fact_rf SELECT i, i % 8000 + 1, i FROM generate_series(1, 100000) s(i);
INSERT INTO dim_rf SELECT i, i % 10, i FROM generate_series(1, 10000) s(i);
ANALYZE fact_rf, dim_rf;

SET gp_enable_runtime_filter_pushdown TO off;
EXPLAIN ANALYZE SELECT COUNT(*) FROM fact_rf, dim_rf
WHERE fact_rf.did = dim_rf.did AND proj_id < 2 AND filter_val <= 1000;

SET gp_enable_runtime_filter_pushdown TO on;
EXPLAIN ANALYZE SELECT COUNT(*) FROM fact_rf, dim_rf
WHERE fact_rf.did = dim_rf.did AND proj_id < 2 AND filter_val <= 1000;

-- Test bad filter rate
EXPLAIN ANALYZE SELECT COUNT(*) FROM fact_rf, dim_rf
WHERE fact_rf.did = dim_rf.did AND proj_id < 7;

-- Test outer join
-- LeftJoin (eliminated and applicatable)
EXPLAIN ANALYZE SELECT COUNT(*) FROM
fact_rf LEFT JOIN dim_rf ON fact_rf.did = dim_rf.did
WHERE proj_id < 2 AND filter_val <= 1000;

-- LeftJoin
EXPLAIN ANALYZE SELECT COUNT(*) FROM
fact_rf LEFT JOIN dim_rf ON fact_rf.did = dim_rf.did
WHERE proj_id IS NULL OR proj_id < 2 AND filter_val <= 1000;

-- RightJoin (applicatable)
EXPLAIN ANALYZE SELECT COUNT(*) FROM
fact_rf RIGHT JOIN dim_rf ON fact_rf.did = dim_rf.did
WHERE proj_id < 2 AND filter_val <= 1000;

-- SemiJoin
EXPLAIN ANALYZE SELECT COUNT(*) FROM fact_rf
WHERE fact_rf.did IN (SELECT did FROM dim_rf WHERE proj_id < 2 AND filter_val <= 1000);

-- SemiJoin -> InnerJoin and deduplicate
EXPLAIN ANALYZE SELECT COUNT(*) FROM dim_rf
WHERE dim_rf.did IN (SELECT did FROM fact_rf) AND proj_id < 2 AND filter_val <= 1000;

-- Test correctness
SELECT * FROM fact_rf, dim_rf
WHERE fact_rf.did = dim_rf.did AND dim_rf.filter_val = 1
ORDER BY fid;

SELECT * FROM
fact_rf LEFT JOIN dim_rf ON fact_rf.did = dim_rf.did
WHERE dim_rf.filter_val = 1
ORDER BY fid;

SELECT COUNT(*) FROM
fact_rf LEFT JOIN dim_rf ON fact_rf.did = dim_rf.did
WHERE proj_id < 2 AND filter_val <= 1000;

SELECT COUNT(*) FROM
fact_rf LEFT JOIN dim_rf ON fact_rf.did = dim_rf.did
WHERE proj_id IS NULL OR proj_id < 2 AND filter_val <= 1000;

SELECT COUNT(*) FROM
fact_rf RIGHT JOIN dim_rf ON fact_rf.did = dim_rf.did
WHERE proj_id < 2 AND filter_val <= 1000;

SELECT COUNT(*) FROM fact_rf
WHERE fact_rf.did IN (SELECT did FROM dim_rf WHERE proj_id < 2 AND filter_val <= 1000);

SELECT COUNT(*) FROM dim_rf
WHERE dim_rf.did IN (SELECT did FROM fact_rf) AND proj_id < 2 AND filter_val <= 1000;

-- Test contain null values
INSERT INTO dim_rf VALUES (NULL,1, 1);
EXPLAIN ANALYZE SELECT COUNT(*) FROM fact_rf, dim_rf
WHERE fact_rf.did = dim_rf.did AND proj_id < 2 AND filter_val <= 1000;
SELECT COUNT(*) FROM fact_rf, dim_rf
WHERE fact_rf.did = dim_rf.did AND proj_id < 2 AND filter_val <= 1000;

-- Clean up: reset guc
SET gp_enable_runtime_filter_pushdown TO off;
RESET optimizer;
5 changes: 5 additions & 0 deletions contrib/pax_storage/src/cpp/access/pax_access_handle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ const TupleTableSlotOps *PaxAccessMethod::SlotCallbacks(

uint32 PaxAccessMethod::ScanFlags(Relation relation) {
uint32 flags = 0;
std::vector<int> minmax_columns;
#ifdef VEC_BUILD
flags |= SCAN_SUPPORT_VECTORIZATION | SCAN_SUPPORT_COLUMN_ORIENTED_SCAN;
#else
Expand All @@ -452,6 +453,10 @@ uint32 PaxAccessMethod::ScanFlags(Relation relation) {
#if defined(USE_MANIFEST_API) && !defined(USE_PAX_CATALOG)
flags |= SCAN_FORCE_BIG_WRITE_LOCK;
#endif
minmax_columns = cbdb::GetMinMaxColumnIndexes(relation);
if (!minmax_columns.empty()) {
flags |= SCAN_SUPPORT_RUNTIME_FILTER;
}

return flags;
}
Expand Down
47 changes: 39 additions & 8 deletions src/backend/executor/nodeHash.c
Original file line number Diff line number Diff line change
Expand Up @@ -2584,6 +2584,7 @@ ExecHashTableExplainEnd(PlanState *planstate, struct StringInfoData *buf)
Instrumentation *jinstrument = hjstate->js.ps.instrument;
int total_buckets;
int i;
HashState *hashState = (HashState *) innerPlanState(hjstate);

if (!hashtable ||
!hashtable->stats ||
Expand All @@ -2598,11 +2599,13 @@ ExecHashTableExplainEnd(PlanState *planstate, struct StringInfoData *buf)

if (!hashtable->eagerlyReleased)
{
HashState *hashState = (HashState *) innerPlanState(hjstate);

/* Report on batch in progress, in case the join is being ended early. */
ExecHashTableExplainBatchEnd(hashState, hashtable);
}
if (gp_enable_runtime_filter_pushdown && hashState->filters)
{
ExecRFExplainEnd(hashState, buf);
}

/* Report actual work_mem high water mark. */
jinstrument->workmemused = Max(jinstrument->workmemused, stats->workmem_max);
Expand Down Expand Up @@ -4161,37 +4164,59 @@ PushdownRuntimeFilter(HashState *node)
scankeys = NIL;

attr_filter = lfirst(lc);
if (!IsA(attr_filter->target, SeqScanState) || attr_filter->empty)
if (!IsA(attr_filter->target, SeqScanState)
|| attr_filter->empty || attr_filter->hasnulls)
continue;

SeqScanState *sss = castNode(SeqScanState, attr_filter->target);
/* bloom filter */
sk = (ScanKey)palloc0(sizeof(ScanKeyData));
sk->sk_flags = SK_BLOOM_FILTER;
sk->sk_attno = attr_filter->lattno;
sk->sk_subtype = INT8OID;
sk->sk_argument = PointerGetDatum(attr_filter->blm_filter);
sk->sk_collation = attr_filter->collation;
scankeys = lappend(scankeys, sk);

if (attr_filter->n_distinct > 0)
{
int64 range = attr_filter->max - attr_filter->min + 1;
if ((range / attr_filter->n_distinct) > gp_runtime_filter_selectivity_threshold)
{
/* push previous scankeys */
sss->filters = list_concat(sss->filters, scankeys);
continue;
}
}
/* range filter */
sk = (ScanKey)palloc0(sizeof(ScanKeyData));
sk->sk_flags = 0;
sk->sk_attno = attr_filter->lattno;
sk->sk_strategy = BTGreaterEqualStrategyNumber;
sk->sk_subtype = INT8OID;
sk->sk_subtype = attr_filter->vartype;
sk->sk_argument = attr_filter->min;
sk->sk_collation = attr_filter->collation;
scankeys = lappend(scankeys, sk);

sk = (ScanKey)palloc0(sizeof(ScanKeyData));
sk->sk_flags = 0;
sk->sk_attno = attr_filter->lattno;
sk->sk_strategy = BTLessEqualStrategyNumber;
sk->sk_subtype = INT8OID;
sk->sk_subtype = attr_filter->vartype;
sk->sk_argument = attr_filter->max;
sk->sk_collation = attr_filter->collation;
scankeys = lappend(scankeys, sk);

/* append new runtime filters to target node */
SeqScanState *sss = castNode(SeqScanState, attr_filter->target);
sss->filters = list_concat(sss->filters, scankeys);
if (sss->ss.ss_currentScanDesc != NULL)
{
/* if seqscan is started, we can't pushdown the runtime filter */
list_free_deep(scankeys);
}
else
{
sss->filters = list_concat(sss->filters, scankeys);
}
}
}

Expand All @@ -4206,10 +4231,15 @@ AddTupleValuesIntoRF(HashState *node, TupleTableSlot *slot)
foreach (lc, node->filters)
{
attr_filter = (AttrFilter *) lfirst(lc);
if (attr_filter->hasnulls)
continue;

val = slot_getattr(slot, attr_filter->rattno, &isnull);
if (isnull)
{
attr_filter->hasnulls = true;
continue;
}

attr_filter->empty = false;

Expand Down Expand Up @@ -4258,6 +4288,7 @@ ResetRuntimeFilter(HashState *node)
{
attr_filter = lfirst(lc);
attr_filter->empty = true;
attr_filter->hasnulls = false;

if (IsA(attr_filter->target, SeqScanState))
{
Expand All @@ -4274,7 +4305,7 @@ ResetRuntimeFilter(HashState *node)

attr_filter->blm_filter = bloom_create_aggresive(node->ps.plan->plan_rows,
work_mem,
random());
gp_session_id);

StaticAssertDecl(sizeof(LONG_MAX) == sizeof(Datum), "sizeof(LONG_MAX) should be equal to sizeof(Datum)");
StaticAssertDecl(sizeof(LONG_MIN) == sizeof(Datum), "sizeof(LONG_MIN) should be equal to sizeof(Datum)");
Expand Down
40 changes: 32 additions & 8 deletions src/backend/executor/nodeHashjoin.c
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@

#include "access/htup_details.h"
#include "access/parallel.h"
#include "catalog/pg_statistic.h"
#include "catalog/pg_namespace.h"
#include "executor/executor.h"
#include "executor/hashjoin.h"
#include "executor/instrument.h" /* Instrumentation */
Expand All @@ -118,9 +120,12 @@
#include "executor/nodeRuntimeFilter.h"
#include "miscadmin.h"
#include "pgstat.h"
#include "utils/datum.h"
#include "utils/guc.h"
#include "utils/fmgroids.h"
#include "utils/lsyscache.h"
#include "utils/memutils.h"
#include "utils/rel.h"
#include "utils/sharedtuplestore.h"

#include "cdb/cdbvars.h"
Expand Down Expand Up @@ -168,10 +173,10 @@ static bool IsEqualOp(Expr *expr);
static bool CheckEqualArgs(Expr *expr, AttrNumber *lattno, AttrNumber *rattno);
static bool CheckTargetNode(PlanState *node,
AttrNumber attno,
AttrNumber *lattno);
AttrNumber *lattno, Oid *collation, Oid *var_type);
static List *FindTargetNodes(HashJoinState *hjstate,
AttrNumber attno,
AttrNumber *lattno);
AttrNumber *lattno, Oid *collation, Oid *var_type);
static AttrFilter *CreateAttrFilter(PlanState *target,
AttrNumber lattno,
AttrNumber rattno,
Expand Down Expand Up @@ -2192,6 +2197,8 @@ CreateRuntimeFilter(HashJoinState* hjstate)
AttrFilter *attr_filter;
ListCell *lc;
List *targets;
Oid var_type;
Oid collation;

/*
* A build-side Bloom filter tells us if a row is definitely not in the build
Expand Down Expand Up @@ -2232,7 +2239,7 @@ CreateRuntimeFilter(HashJoinState* hjstate)
if (lattno < 1 || rattno < 1)
continue;

targets = FindTargetNodes(hjstate, lattno, &lattno);
targets = FindTargetNodes(hjstate, lattno, &lattno, &collation, &var_type);
if (lattno == -1 || targets == NULL)
continue;

Expand All @@ -2243,6 +2250,8 @@ CreateRuntimeFilter(HashJoinState* hjstate)

attr_filter = CreateAttrFilter(target, lattno, rattno,
hstate->ps.plan->plan_rows);
attr_filter->vartype = var_type;
attr_filter->collation = collation;
if (attr_filter->blm_filter)
hstate->filters = lappend(hstate->filters, attr_filter);
else
Expand Down Expand Up @@ -2329,7 +2338,7 @@ CheckEqualArgs(Expr *expr, AttrNumber *lattno, AttrNumber *rattno)
}

static bool
CheckTargetNode(PlanState *node, AttrNumber attno, AttrNumber *lattno)
CheckTargetNode(PlanState *node, AttrNumber attno, AttrNumber *lattno, Oid *collation, Oid *var_type)
{
Var *var;
TargetEntry *te;
Expand All @@ -2348,6 +2357,8 @@ CheckTargetNode(PlanState *node, AttrNumber attno, AttrNumber *lattno)
return false;

*lattno = var->varattno;
*collation = var->varcollid;
*var_type = var->vartype;

return true;
}
Expand All @@ -2360,7 +2371,7 @@ CheckTargetNode(PlanState *node, AttrNumber attno, AttrNumber *lattno)
* SeqScan <- target
*/
static List *
FindTargetNodes(HashJoinState *hjstate, AttrNumber attno, AttrNumber *lattno)
FindTargetNodes(HashJoinState *hjstate, AttrNumber attno, AttrNumber *lattno, Oid *collation, Oid *var_type)
{
Var *var;
PlanState *child, *parent;
Expand All @@ -2386,7 +2397,7 @@ FindTargetNodes(HashJoinState *hjstate, AttrNumber attno, AttrNumber *lattno)
* result
* seqscan
*/
if (!CheckTargetNode(child, attno, lattno))
if (!CheckTargetNode(child, attno, lattno, collation, var_type))
return NULL;

targetNodes = lappend(targetNodes, child);
Expand All @@ -2404,7 +2415,7 @@ FindTargetNodes(HashJoinState *hjstate, AttrNumber attno, AttrNumber *lattno)
for (int i = 0; i < as->as_nplans; i++)
{
child = as->appendplans[i];
if (!CheckTargetNode(child, attno, lattno))
if (!CheckTargetNode(child, attno, lattno, collation, var_type))
return NULL;

targetNodes = lappend(targetNodes, child);
Expand Down Expand Up @@ -2452,12 +2463,25 @@ CreateAttrFilter(PlanState *target, AttrNumber lattno, AttrNumber rattno,
{
AttrFilter *attr_filter = palloc0(sizeof(AttrFilter));
attr_filter->empty = true;
attr_filter->hasnulls = false;
attr_filter->target = target;

attr_filter->lattno = lattno;
attr_filter->rattno = rattno;
attr_filter->n_distinct = 0.0;

attr_filter->blm_filter = bloom_create_aggresive(plan_rows, work_mem, random());
attr_filter->blm_filter = bloom_create_aggresive(plan_rows, work_mem, gp_session_id);

if (target && IsA(target, SeqScanState))
{
HeapTuple statstuple;
SeqScanState *scan = (SeqScanState *)target;
statstuple = get_att_stats(RelationGetRelid(scan->ss.ss_currentRelation), lattno);
if (HeapTupleIsValid(statstuple))
{
attr_filter->n_distinct = ((Form_pg_statistic) GETSTRUCT(statstuple))->stadistinct;
}
}

StaticAssertDecl(sizeof(LONG_MAX) == sizeof(Datum), "sizeof(LONG_MAX) should be equal to sizeof(Datum)");
StaticAssertDecl(sizeof(LONG_MIN) == sizeof(Datum), "sizeof(LONG_MIN) should be equal to sizeof(Datum)");
Expand Down
Loading
Loading