11package connector
22
33import (
4+ "fmt"
45 "strings"
56
67 "github.com/hasura/ndc-elasticsearch/internal"
@@ -11,10 +12,16 @@ import (
1112// prepareFilterQuery prepares a filter query based on the given expression.
1213func prepareFilterQuery (expression schema.Expression , state * types.State , collection string ) (map [string ]interface {}, error ) {
1314 filter := make (map [string ]interface {})
14- switch expr := expression .Interface ().(type ) {
15+ columnPath , predicate := getPredicate (expression )
16+
17+ switch expr := predicate .Interface ().(type ) {
1518 case * schema.ExpressionUnaryComparisonOperator :
19+ fieldPath := strings .Split (columnPath , "." )
20+ expr .Column .FieldPath = fieldPath
1621 return handleExpressionUnaryComparisonOperator (expr , state , collection )
1722 case * schema.ExpressionBinaryComparisonOperator :
23+ fieldPath := strings .Split (columnPath , "." )
24+ expr .Column .FieldPath = fieldPath
1825 return handleExpressionBinaryComparisonOperator (expr , state , collection )
1926 case * schema.ExpressionAnd :
2027 queries := make ([]map [string ]interface {}, 0 )
@@ -59,10 +66,52 @@ func prepareFilterQuery(expression schema.Expression, state *types.State, collec
5966 }
6067}
6168
69+ // getPredicate checks if a schema.Expression has nested filtering
70+ // if it does, it traverses the schema.Expression recursively until it finds a non-nested query predicate
71+ func getPredicate (expression schema.Expression ) (string , schema.Expression ) {
72+ if nested , fieldName := requiresNestedFiltering (expression ); nested {
73+ expressionPredicate , ok := expression ["predicate" ].(schema.Expression )
74+ if ! ok {
75+ return "" , nil
76+ }
77+
78+ columnPathPostfix , predicate := getPredicate (expressionPredicate )
79+ return fmt .Sprintf ("%s.%s" , fieldName , columnPathPostfix ), predicate
80+ }
81+ switch expr := expression .Interface ().(type ) {
82+ case * schema.ExpressionUnaryComparisonOperator :
83+ return expr .Column .Name , expression
84+ case * schema.ExpressionBinaryComparisonOperator :
85+ return expr .Column .Name , expression
86+ }
87+
88+ return "" , expression
89+ }
90+
91+ func requiresNestedFiltering (predicate schema.Expression ) (requiresNestedFiltering bool , nestedFieldName string ) {
92+ inCollection , ok := predicate ["in_collection" ].(schema.ExistsInCollection )
93+ if ! ok {
94+ return false , ""
95+ }
96+ collection , err := inCollection .AsNestedCollection ()
97+ if err != nil {
98+ return false , ""
99+ }
100+ if collection .Type == "nested_collection" {
101+ return true , collection .ColumnName
102+ }
103+ return false , ""
104+ }
105+
62106// handleExpressionUnaryComparisonOperator processes the unary comparison operator expression.
63107func handleExpressionUnaryComparisonOperator (expr * schema.ExpressionUnaryComparisonOperator , state * types.State , collection string ) (map [string ]interface {}, error ) {
64108 if expr .Operator == "is_null" {
65- fieldName , _ := joinFieldPath (state , expr .Column .FieldPath , expr .Column .Name , collection )
109+ if len (expr .Column .FieldPath ) == 0 || expr .Column .FieldPath [len (expr .Column .FieldPath )- 1 ] != expr .Column .Name {
110+ // if the column name is not the last element in fieldPath, we'll add it so that the fieldpath is complete
111+ expr .Column .FieldPath = append (expr .Column .FieldPath , expr .Column .Name )
112+ }
113+
114+ fieldName := strings .Join (expr .Column .FieldPath , "." )
66115 value := map [string ]interface {}{
67116 "field" : fieldName ,
68117 }
@@ -85,7 +134,12 @@ func handleExpressionBinaryComparisonOperator(
85134 state * types.State ,
86135 collection string ,
87136) (map [string ]interface {}, error ) {
88- fieldPath , nestedPath := joinFieldPath (state , expr .Column .FieldPath , expr .Column .Name , collection )
137+ if len (expr .Column .FieldPath ) == 0 || expr .Column .FieldPath [len (expr .Column .FieldPath )- 1 ] != expr .Column .Name {
138+ // if the column name is not the last element in fieldPath, we'll add it so that the fieldpath is complete
139+ expr .Column .FieldPath = append (expr .Column .FieldPath , expr .Column .Name )
140+ }
141+
142+ fieldPath := strings .Join (expr .Column .FieldPath , "." )
89143 fieldType , fieldSubTypes , _ , err := state .Configuration .GetFieldProperties (collection , fieldPath )
90144 if err != nil {
91145 return nil , schema .UnprocessableContentError ("unable to get field types" , map [string ]any {
@@ -110,9 +164,10 @@ func handleExpressionBinaryComparisonOperator(
110164 expr .Operator : value ,
111165 }
112166
113- if nestedPath != "" {
114- filter = prepareNestedQuery (state , expr .Operator , value , fieldPath , len (expr .Column .FieldPath ), collection )
115- }
167+ // TOOD: re-enable
168+ // if nestedPath != "" {
169+ // filter = prepareNestedQuery(state, expr.Operator, value, fieldPath, len(expr.Column.FieldPath), collection)
170+ // }
116171
117172 return filter , nil
118173}
0 commit comments