2626import com .alibaba .fluss .flink .utils .FlinkConnectorOptionsUtils ;
2727import com .alibaba .fluss .flink .utils .FlinkConversions ;
2828import com .alibaba .fluss .flink .utils .PushdownUtils ;
29- import com .alibaba .fluss .flink .utils .PushdownUtils .ValueConversion ;
29+ import com .alibaba .fluss .flink .utils .PushdownUtils .FieldEqual ;
3030import com .alibaba .fluss .metadata .MergeEngineType ;
3131import com .alibaba .fluss .metadata .TablePath ;
3232import com .alibaba .fluss .types .RowType ;
7474import java .util .List ;
7575import java .util .Map ;
7676
77+ import static com .alibaba .fluss .flink .utils .PushdownUtils .ValueConversion .FLINK_INTERNAL_VALUE ;
7778import static com .alibaba .fluss .utils .Preconditions .checkNotNull ;
7879
7980/** Flink table source to scan Fluss data. */
@@ -124,6 +125,8 @@ public class FlinkTableSource
124125
125126 private long limit = -1 ;
126127
128+ private List <FieldEqual > partitionFilters = Collections .emptyList ();
129+
127130 public FlinkTableSource (
128131 TablePath tablePath ,
129132 Configuration flussConfig ,
@@ -263,7 +266,8 @@ public boolean isBounded() {
263266 offsetsInitializer ,
264267 scanPartitionDiscoveryIntervalMs ,
265268 new RowDataDeserializationSchema (),
266- streaming );
269+ streaming ,
270+ partitionFilters );
267271
268272 if (!streaming ) {
269273 // return a bounded source provide to make planner happy,
@@ -357,6 +361,7 @@ public DynamicTableSource copy() {
357361 source .projectedFields = projectedFields ;
358362 source .singleRowFilter = singleRowFilter ;
359363 source .modificationScanType = modificationScanType ;
364+ source .partitionFilters = partitionFilters ;
360365 return source ;
361366 }
362367
@@ -378,41 +383,57 @@ public void applyProjection(int[][] projectedFields, DataType producedDataType)
378383
379384 @ Override
380385 public Result applyFilters (List <ResolvedExpression > filters ) {
381- // only apply pk equal filters when all the condition satisfied:
386+ List <ResolvedExpression > acceptedFilters = new ArrayList <>();
387+ List <ResolvedExpression > remainingFilters = new ArrayList <>();
388+
389+ // primary pushdown
382390 // (1) batch execution mode,
383391 // (2) default (full) startup mode,
384392 // (3) the table is a pk table,
385393 // (4) all filters are pk field equal expression
386- if (streaming
387- || startupOptions .startupMode != FlinkConnectorOptions .ScanStartupMode .FULL
388- || !hasPrimaryKey ()
389- || filters .size () != primaryKeyIndexes .length ) {
390- return Result .of (Collections .emptyList (), filters );
391- }
392-
393- List <ResolvedExpression > acceptedFilters = new ArrayList <>();
394- List <ResolvedExpression > remainingFilters = new ArrayList <>();
395- Map <Integer , LogicalType > primaryKeyTypes = getPrimaryKeyTypes ();
396- List <PushdownUtils .FieldEqual > fieldEquals =
397- PushdownUtils .extractFieldEquals (
398- filters ,
399- primaryKeyTypes ,
400- acceptedFilters ,
401- remainingFilters ,
402- ValueConversion .FLINK_INTERNAL_VALUE );
403- int [] keyRowProjection = getKeyRowProjection ();
404- HashSet <Integer > visitedPkFields = new HashSet <>();
405- GenericRowData lookupRow = new GenericRowData (primaryKeyIndexes .length );
406- for (PushdownUtils .FieldEqual fieldEqual : fieldEquals ) {
407- lookupRow .setField (keyRowProjection [fieldEqual .fieldIndex ], fieldEqual .equalValue );
408- visitedPkFields .add (fieldEqual .fieldIndex );
409- }
410- // if not all primary key fields are in condition, we skip to pushdown
411- if (!visitedPkFields .equals (primaryKeyTypes .keySet ())) {
394+ if (!streaming
395+ && startupOptions .startupMode == FlinkConnectorOptions .ScanStartupMode .FULL
396+ && hasPrimaryKey ()
397+ && filters .size () == primaryKeyIndexes .length ) {
398+ Map <Integer , LogicalType > primaryKeyTypes = getPrimaryKeyTypes ();
399+ List <FieldEqual > fieldEquals =
400+ PushdownUtils .extractFieldEquals (
401+ filters ,
402+ primaryKeyTypes ,
403+ acceptedFilters ,
404+ remainingFilters ,
405+ FLINK_INTERNAL_VALUE );
406+ int [] keyRowProjection = getKeyRowProjection ();
407+ HashSet <Integer > visitedPkFields = new HashSet <>();
408+ GenericRowData lookupRow = new GenericRowData (primaryKeyIndexes .length );
409+ for (FieldEqual fieldEqual : fieldEquals ) {
410+ lookupRow .setField (keyRowProjection [fieldEqual .fieldIndex ], fieldEqual .equalValue );
411+ visitedPkFields .add (fieldEqual .fieldIndex );
412+ }
413+ // if not all primary key fields are in condition, we skip to pushdown
414+ if (!visitedPkFields .equals (primaryKeyTypes .keySet ())) {
415+ return Result .of (Collections .emptyList (), filters );
416+ }
417+ singleRowFilter = lookupRow ;
418+ return Result .of (acceptedFilters , remainingFilters );
419+ } else if (isPartitioned ()) {
420+ // dynamic partition pushdown
421+ Map <Integer , LogicalType > partitionKeyTypes = getPartitionKeyTypes ();
422+ List <FieldEqual > fieldEquals =
423+ PushdownUtils .extractFieldEquals (
424+ filters ,
425+ partitionKeyTypes ,
426+ acceptedFilters ,
427+ remainingFilters ,
428+ FLINK_INTERNAL_VALUE );
429+ // partitions are filtered by string representations, convert the equals to string first
430+ fieldEquals = stringifyFieldEquals (fieldEquals );
431+
432+ this .partitionFilters = fieldEquals ;
433+ return Result .of (acceptedFilters , remainingFilters );
434+ } else {
412435 return Result .of (Collections .emptyList (), filters );
413436 }
414- singleRowFilter = lookupRow ;
415- return Result .of (acceptedFilters , remainingFilters );
416437 }
417438
418439 @ Override
@@ -468,6 +489,24 @@ private Map<Integer, LogicalType> getPrimaryKeyTypes() {
468489 return pkTypes ;
469490 }
470491
492+ private Map <Integer , LogicalType > getPartitionKeyTypes () {
493+ Map <Integer , LogicalType > partitionKeyTypes = new HashMap <>();
494+ for (int index : partitionKeyIndexes ) {
495+ partitionKeyTypes .put (index , tableOutputType .getTypeAt (index ));
496+ }
497+ return partitionKeyTypes ;
498+ }
499+
500+ private List <FieldEqual > stringifyFieldEquals (List <FieldEqual > fieldEquals ) {
501+ List <FieldEqual > serialize = new ArrayList <>();
502+ for (FieldEqual fieldEqual : fieldEquals ) {
503+ // revisit this again when we support more data types for partition key
504+ serialize .add (
505+ new FieldEqual (fieldEqual .fieldIndex , (fieldEqual .equalValue ).toString ()));
506+ }
507+ return serialize ;
508+ }
509+
471510 // projection from pk_field_index to index_in_pk
472511 private int [] getKeyRowProjection () {
473512 int [] projection = new int [tableOutputType .getFieldCount ()];
0 commit comments