Skip to content

Commit 827fe46

Browse files
committed
Make expression testing clearer by using QueryBuilder
Also adds ast.Set and ast.Tuple handling methods
1 parent cfe2c11 commit 827fe46

File tree

2 files changed

+58
-32
lines changed

2 files changed

+58
-32
lines changed

python/arcticdb/version_store/processing.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -377,8 +377,14 @@ def _(a: ast.BinOp, function_map) -> Any:
377377
@_ast_to_expression.register(ast.Compare)
378378
def _(a: ast.Compare, function_map) -> Any:
379379
# Compares in pyarrow Expression contain exactly one comparison (i.e. 1 < field("asdf") < 3 is not supported)
380-
assert len(a.ops) == 1
381-
assert len(a.comparators) == 1
380+
check(
381+
len(a.ops) == 1,
382+
f"Received a series of {len(a.ops)} comparisons, but only series of 1 comparison is supported. "
383+
"Use `(a < b) & (b < c)` instead of `a < b < c`.")
384+
check(
385+
len(a.comparators) == 1,
386+
f"Received a series of {len(a.comparators)} comparators, but only series of 1 comparison is supported. "
387+
"Use `(a < b) & (b < c)` instead of `a < b < c`.")
382388
op = a.ops[0]
383389
left = a.left
384390
right = a.comparators[0]
@@ -405,6 +411,16 @@ def _(a: ast.List, function_map) -> Any:
405411
return [_ast_to_expression(e, function_map) for e in a.elts]
406412

407413

414+
@_ast_to_expression.register(ast.Set)
415+
def _(a: ast.Set, function_map) -> Any:
416+
return set([_ast_to_expression(e, function_map) for e in a.elts])
417+
418+
419+
@_ast_to_expression.register(ast.Tuple)
420+
def _(a: ast.Tuple, function_map) -> Any:
421+
return tuple([_ast_to_expression(e, function_map) for e in a.elts])
422+
423+
408424
def is_supported_sequence(obj):
409425
return isinstance(obj, (list, set, frozenset, tuple, np.ndarray))
410426

python/tests/unit/arcticdb/version_store/test_query_builder_parse_pyarrow.py

+40-30
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@ def df_with_all_column_types(num_rows=100):
2121
return pd.DataFrame(data=data, index=index)
2222

2323

24-
def compare_against_pyarrow(pyarrow_expr_str, expected_adb_expr, lib, function_map = None, expect_equal=True):
24+
def compare_against_pyarrow(pyarrow_expr_str, expected_adb_qb, lib, function_map = None, expect_equal=True):
2525
adb_expr = ExpressionNode.from_pyarrow_expression_str(pyarrow_expr_str, function_map)
26-
assert str(adb_expr) == str(expected_adb_expr)
26+
q = QueryBuilder()
27+
q = q[adb_expr]
28+
assert q == expected_adb_qb
2729
pa_expr = eval(pyarrow_expr_str)
2830

2931
# Setup
@@ -33,8 +35,6 @@ def compare_against_pyarrow(pyarrow_expr_str, expected_adb_expr, lib, function_m
3335
pa_table = pa.Table.from_pandas(df)
3436

3537
# Apply filter to adb
36-
q = QueryBuilder()
37-
q = q[adb_expr]
3838
adb_result = lib.read(sym, query_builder=q).data
3939

4040
# Apply filter to pyarrow
@@ -48,76 +48,86 @@ def compare_against_pyarrow(pyarrow_expr_str, expected_adb_expr, lib, function_m
4848

4949
def test_basic_filters(lmdb_version_store_v1):
5050
lib = lmdb_version_store_v1
51+
q = QueryBuilder()
5152

5253
# Filter by boolean column
5354
expr = f"pc.field('bool_col')"
54-
expected_expr = ExpressionNode.column_ref('bool_col')
55-
compare_against_pyarrow(expr, expected_expr, lib)
55+
expected_q = q[q['bool_col']]
56+
compare_against_pyarrow(expr, expected_q, lib)
5657

5758
# Filter by comparison
5859
for op in ["<", "<=", "==", ">=", ">"]:
5960
expr = f"pc.field('int_col') {op} 50"
60-
expected_expr = eval(f"ExpressionNode.column_ref('int_col') {op} 50")
61-
compare_against_pyarrow(expr, expected_expr, lib)
61+
expected_q = q[eval(f"q['int_col'] {op} 50")]
62+
compare_against_pyarrow(expr, expected_q, lib)
6263

6364
# Filter with unary operators
6465
expr = "~pc.field('bool_col')"
65-
expected_expr = ~ExpressionNode.column_ref('bool_col')
66-
compare_against_pyarrow(expr, expected_expr, lib)
66+
expected_q = q[~q['bool_col']]
67+
compare_against_pyarrow(expr, expected_q, lib)
6768

6869
# Filter with binary operators
6970
for op in ["+", "-", "*", "/"]:
7071
expr = f"pc.field('float_col') {op} 5.0 < 50.0"
71-
expected_expr = eval(f"ExpressionNode.column_ref('float_col') {op} 5.0 < 50.0")
72-
compare_against_pyarrow(expr, expected_expr, lib)
72+
expected_q = q[eval(f"q['float_col'] {op} 5.0 < 50.0")]
73+
compare_against_pyarrow(expr, expected_q, lib)
7374

7475
for op in ["&", "|"]:
7576
expr = f"pc.field('bool_col') {op} (pc.field('int_col') < 50)"
76-
expected_expr = eval(f"ExpressionNode.column_ref('bool_col') {op} (ExpressionNode.column_ref('int_col') < 50)")
77-
compare_against_pyarrow(expr, expected_expr, lib)
77+
expected_q = q[eval(f"q['bool_col'] {op} (q['int_col'] < 50)")]
78+
compare_against_pyarrow(expr, expected_q, lib)
7879

7980
# Filter with expression method calls
8081
expr = "pc.field('str_col').isin(['str_0', 'str_10', 'str_20'])"
81-
expected_expr = ExpressionNode.column_ref('str_col').isin(['str_0', 'str_10', 'str_20'])
82-
compare_against_pyarrow(expr, expected_expr, lib)
82+
expected_q = q[q['str_col'].isin(['str_0', 'str_10', 'str_20'])]
83+
compare_against_pyarrow(expr, expected_q, lib)
84+
85+
expr = "pc.field('str_col').isin(('str_0', 'str_10', 'str_20'))"
86+
expected_q = q[q['str_col'].isin(('str_0', 'str_10', 'str_20'))]
87+
compare_against_pyarrow(expr, expected_q, lib)
88+
89+
expr = "pc.field('str_col').isin({'str_0', 'str_10', 'str_20'})"
90+
expected_q = q[q['str_col'].isin({'str_0', 'str_10', 'str_20'})]
91+
compare_against_pyarrow(expr, expected_q, lib)
8392

8493
expr = "pc.field('float_col').is_nan()"
85-
expected_expr = ExpressionNode.column_ref('float_col').isnull()
94+
expected_q = q[q['float_col'].isnull()]
8695
# We expect a different result between adb and pyarrow because of the different nan/null handling
87-
compare_against_pyarrow(expr, expected_expr, lib, expect_equal=False)
96+
compare_against_pyarrow(expr, expected_q, lib, expect_equal=False)
8897

8998
expr = "pc.field('float_col').is_null()"
90-
expected_expr = ExpressionNode.column_ref('float_col').isnull()
91-
compare_against_pyarrow(expr, expected_expr, lib)
99+
expected_q = q[q['float_col'].isnull()]
100+
compare_against_pyarrow(expr, expected_q, lib)
92101

93102
expr = "pc.field('float_col').is_valid()"
94-
expected_expr = ExpressionNode.column_ref('float_col').notnull()
95-
compare_against_pyarrow(expr, expected_expr, lib)
103+
expected_q = q[q['float_col'].notnull()]
104+
compare_against_pyarrow(expr, expected_q, lib)
96105

97106
def test_complex_filters(lmdb_version_store_v1):
98107
lib = lmdb_version_store_v1
108+
q = QueryBuilder()
99109

100110
# Nested complex filters
101111
expr = "((pc.field('float_col') * 2) > 20.0) & (pc.field('int_col') <= pc.scalar(60)) | pc.field('bool_col')"
102-
expected_expr = (ExpressionNode.column_ref('float_col') * 2 > 20.0) & (ExpressionNode.column_ref('int_col') <= 60) | ExpressionNode.column_ref('bool_col')
103-
compare_against_pyarrow(expr, expected_expr, lib)
112+
expected_q = q[(q['float_col'] * 2 > 20.0) & (q['int_col'] <= 60) | q['bool_col']]
113+
compare_against_pyarrow(expr, expected_q, lib)
104114

105115
expr = "((pc.field('float_col') / 2) > 20.0) & (pc.field('float_col') <= pc.scalar(60)) & pc.field('str_col').isin(['str_30', 'str_41', 'str_42', 'str_53', 'str_99'])"
106-
expected_expr = (ExpressionNode.column_ref('float_col') / 2 > 20.0) & (ExpressionNode.column_ref('float_col') <= 60) & ExpressionNode.column_ref('str_col').isin(['str_30', 'str_41', 'str_42', 'str_53', 'str_99'])
107-
compare_against_pyarrow(expr, expected_expr, lib)
116+
expected_q = q[(q['float_col'] / 2 > 20.0) & (q['float_col'] <= 60) & q['str_col'].isin(['str_30', 'str_41', 'str_42', 'str_53', 'str_99'])]
117+
compare_against_pyarrow(expr, expected_q, lib)
108118

109119
# Filters with function calls
110120
function_map = {
111121
"datetime.datetime": datetime.datetime,
112122
"abs": abs,
113123
}
114124
expr = "pc.field('datetime_col') < datetime.datetime(2025, 1, 20)"
115-
expected_expr = ExpressionNode.column_ref('datetime_col') < datetime.datetime(2025, 1, 20)
116-
compare_against_pyarrow(expr, expected_expr, lib, function_map)
125+
expected_q = q[q['datetime_col'] < datetime.datetime(2025, 1, 20)]
126+
compare_against_pyarrow(expr, expected_q, lib, function_map)
117127

118128
expr = "(pc.field('datetime_col') < datetime.datetime(2025, 1, abs(-20))) & (pc.field('int_col') >= abs(-5))"
119-
expected_expr = (ExpressionNode.column_ref('datetime_col') < datetime.datetime(2025, 1, abs(-20))) & (ExpressionNode.column_ref('int_col') >= abs(-5))
120-
compare_against_pyarrow(expr, expected_expr, lib, function_map)
129+
expected_q = q[(q['datetime_col'] < datetime.datetime(2025, 1, abs(-20))) & (q['int_col'] >= abs(-5))]
130+
compare_against_pyarrow(expr, expected_q, lib, function_map)
121131

122132
def test_broken_filters():
123133
# ill-formated filter

0 commit comments

Comments
 (0)