Conversion from pyarrow Expressions to QueryBuilder expressions #2202

Open · wants to merge 3 commits into base: master
173 changes: 172 additions & 1 deletion python/arcticdb/version_store/processing.py
@@ -7,6 +7,7 @@
"""
from collections import namedtuple
import copy
from collections.abc import Callable
from dataclasses import dataclass
import datetime
from math import inf
@@ -15,7 +16,12 @@
import pandas as pd
from pandas.tseries.frequencies import to_offset

from typing import Dict, NamedTuple, Optional, Tuple, Union
from typing import Dict, NamedTuple, Optional, Tuple, Union, Any

import sys
import ast

from functools import singledispatch

from arcticdb.exceptions import ArcticDbNotYetImplemented, ArcticNativeException, UserInputException
from arcticdb.version_store._normalization import normalize_dt_range_to_ts
@@ -250,6 +256,171 @@ def get_name(self):
return self.name


@classmethod
def _from_pyarrow_expression_str(cls, expression_str: str, function_map: Optional[Dict[str, Callable]] = None) -> "ExpressionNode":
"""
Builds an ExpressionNode from a pyarrow expression string.

This is required for integrating with polars predicate pushdown. We receive the pyarrow expression as a string
because pyarrow does not provide an API for traversing the expression tree.

pyarrow's `is_null` and `is_nan` are both converted to ArcticDB's `isnull`, and `is_valid` to `notnull`;
ArcticDB does not differentiate between nulls and NaNs.

The optional `function_map` maps dotted names of functions appearing in the expression string
(e.g. "datetime.datetime" or "abs") to the callables used to evaluate them.
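
Example (illustrative sketch; the numeric column "col" below is a hypothetical name, not part of this change):

>>> q = QueryBuilder()
>>> q = q[ExpressionNode._from_pyarrow_expression_str("pc.field('col') > 5")]  # "col" is a hypothetical column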
"""
if function_map is None:
function_map = {}
try:
expression_ast = ast.parse(expression_str, mode="eval").body
return _ast_to_expression(expression_ast, function_map)
except Exception as e:
msg = f"Could not parse pyarrow expression as an arcticdb expression: {e}"
raise ValueError(msg) from e


@singledispatch
def _ast_to_expression(a: Any, function_map) -> Any:
"""Walks the AST to convert the PyArrow expression to an ArcticDB expression."""
raise ValueError(f"Unexpected symbol: {a}")


@_ast_to_expression.register(ast.Constant)
def _(a: ast.Constant, function_map) -> Any:
return a.value


if sys.version_info < (3, 8):
@_ast_to_expression.register(ast.Str)
def _(a: ast.Str, function_map) -> Any:
return a.s

@_ast_to_expression.register(ast.Num)
def _(a: ast.Num, function_map) -> Any:
return a.n

@_ast_to_expression.register(ast.Name)
def _(a: ast.Name, function_map) -> Any:
return a.id


@_ast_to_expression.register(ast.UnaryOp)
def _(a: ast.UnaryOp, function_map) -> Any:
operand = _ast_to_expression(a.operand, function_map)
if isinstance(a.op, ast.Invert):
return ~operand
if isinstance(a.op, ast.USub):
# pyarrow expressions don't support unary subtract, so this branch will not be reached.
# Leaving as future-proofing in case they ever introduce it.
return -operand
raise ValueError(f"Unexpected UnaryOp: {a.op}")


@_ast_to_expression.register(ast.Call)
def _(a: ast.Call, function_map) -> Any:
f = _ast_to_expression(a.func, function_map)
args = [_ast_to_expression(arg, function_map) for arg in a.args]
if callable(f):
return f(*args)
if isinstance(f, str):
if f in function_map:
return function_map[f](*args)
raise ValueError(f"Unexpected function call: {f}")


@_ast_to_expression.register(ast.Attribute)
def _(a: ast.Attribute, function_map) -> Any:
value = _ast_to_expression(a.value, function_map)
attr = a.attr
if isinstance(value, ExpressionNode):
# Handles expression function attributes like (<some expression>).isin([1, 2, 3])
if attr == "isin":
return value.isin
if attr == "is_null" or attr == "is_nan":
return value.isnull
if attr == "is_valid":
return value.notnull
if isinstance(value, str):
# Handles attributes like "pa.compute.field" or "pc.field"
if attr == "field":
return ExpressionNode.column_ref
if attr == "scalar":
return lambda x: x
return f"{value}.{attr}"
raise ValueError(f"Unexpected attribute {attr} of {value}")


@_ast_to_expression.register(ast.BinOp)
def _(a: ast.BinOp, function_map) -> Any:
lhs = _ast_to_expression(a.left, function_map)
rhs = _ast_to_expression(a.right, function_map)

op = a.op
if isinstance(op, ast.BitAnd):
return lhs & rhs
if isinstance(op, ast.BitOr):
return lhs | rhs
if isinstance(op, ast.BitXor):
# pyarrow expressions don't support BitXor, so this branch will not be reached.
# Leaving as future-proofing in case they ever introduce it.
return lhs ^ rhs

if isinstance(op, ast.Add):
return lhs + rhs
if isinstance(op, ast.Sub):
return lhs - rhs
if isinstance(op, ast.Mult):
return lhs * rhs
if isinstance(op, ast.Div):
return lhs / rhs
raise ValueError(f"Unexpected BinOp: {op}")


@_ast_to_expression.register(ast.Compare)
def _(a: ast.Compare, function_map) -> Any:
# Comparison nodes in pyarrow Expressions contain exactly one operator (e.g. 1 < field("asdf") < 3 is not supported)
check(
len(a.ops) == 1,
f"Received a series of {len(a.ops)} comparisons, but only series of 1 comparison is supported. "
"Use `(a < b) & (b < c)` instead of `a < b < c`.")
check(
len(a.comparators) == 1,
f"Received a series of {len(a.comparators)} comparators, but only series of 1 comparison is supported. "
"Use `(a < b) & (b < c)` instead of `a < b < c`.")
op = a.ops[0]
left = a.left
right = a.comparators[0]
lhs = _ast_to_expression(left, function_map)
rhs = _ast_to_expression(right, function_map)

if isinstance(op, ast.Gt):
return lhs > rhs
if isinstance(op, ast.GtE):
return lhs >= rhs
if isinstance(op, ast.Eq):
return lhs == rhs
if isinstance(op, ast.NotEq):
return lhs != rhs
if isinstance(op, ast.Lt):
return lhs < rhs
if isinstance(op, ast.LtE):
return lhs <= rhs
raise ValueError(f"Unknown comparison: {op}")


@_ast_to_expression.register(ast.List)
def _(a: ast.List, function_map) -> Any:
return [_ast_to_expression(e, function_map) for e in a.elts]


@_ast_to_expression.register(ast.Set)
def _(a: ast.Set, function_map) -> Any:
return set([_ast_to_expression(e, function_map) for e in a.elts])


@_ast_to_expression.register(ast.Tuple)
def _(a: ast.Tuple, function_map) -> Any:
return tuple([_ast_to_expression(e, function_map) for e in a.elts])


def is_supported_sequence(obj):
return isinstance(obj, (list, set, frozenset, tuple, np.ndarray))

@@ -0,0 +1,146 @@
import datetime
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.compute as pc
import pytest

from arcticdb.version_store.processing import QueryBuilder, ExpressionNode
from arcticdb.util.test import assert_frame_equal


def df_with_all_column_types(num_rows=100):
data = {
"int_col": np.arange(num_rows, dtype=np.int64),
"float_col": [np.nan if i%20==5 else i for i in range(num_rows)],
"str_col": [f"str_{i}" for i in range(num_rows)],
"bool_col": [i%2 == 0 for i in range(num_rows)],
"datetime_col": pd.date_range(start=pd.Timestamp(2025, 1, 1), periods=num_rows)
}
index = pd.date_range(start=pd.Timestamp(2025, 1, 1), periods=num_rows)
return pd.DataFrame(data=data, index=index)


def compare_against_pyarrow(pyarrow_expr_str, expected_adb_qb, lib, function_map = None, expect_equal=True):
adb_expr = ExpressionNode._from_pyarrow_expression_str(pyarrow_expr_str, function_map)
q = QueryBuilder()
q = q[adb_expr]
assert q == expected_adb_qb
pa_expr = eval(pyarrow_expr_str)

# Setup
sym = "sym"
df = df_with_all_column_types()
lib.write(sym, df)
pa_table = pa.Table.from_pandas(df)

# Apply filter to adb
adb_result = lib.read(sym, query_builder=q).data

# Apply filter to pyarrow
pa_result = pa_table.filter(pa_expr).to_pandas()

if expect_equal:
assert_frame_equal(adb_result, pa_result)
else:
assert len(adb_result) != len(pa_result)


def test_basic_filters(lmdb_version_store_v1):
lib = lmdb_version_store_v1
q = QueryBuilder()

# Filter by boolean column
expr = f"pc.field('bool_col')"
expected_q = q[q['bool_col']]
compare_against_pyarrow(expr, expected_q, lib)

# Filter by comparison
for op in ["<", "<=", "==", ">=", ">"]:
expr = f"pc.field('int_col') {op} 50"
expected_q = q[eval(f"q['int_col'] {op} 50")]
compare_against_pyarrow(expr, expected_q, lib)

# Filter with unary operators
expr = "~pc.field('bool_col')"
expected_q = q[~q['bool_col']]
compare_against_pyarrow(expr, expected_q, lib)

# Filter with binary operators
for op in ["+", "-", "*", "/"]:
expr = f"pc.field('float_col') {op} 5.0 < 50.0"
expected_q = q[eval(f"q['float_col'] {op} 5.0 < 50.0")]
compare_against_pyarrow(expr, expected_q, lib)

for op in ["&", "|"]:
expr = f"pc.field('bool_col') {op} (pc.field('int_col') < 50)"
expected_q = q[eval(f"q['bool_col'] {op} (q['int_col'] < 50)")]
compare_against_pyarrow(expr, expected_q, lib)

# Filter with expression method calls
expr = "pc.field('str_col').isin(['str_0', 'str_10', 'str_20'])"
expected_q = q[q['str_col'].isin(['str_0', 'str_10', 'str_20'])]
compare_against_pyarrow(expr, expected_q, lib)

expr = "pc.field('str_col').isin(('str_0', 'str_10', 'str_20'))"
expected_q = q[q['str_col'].isin(('str_0', 'str_10', 'str_20'))]
compare_against_pyarrow(expr, expected_q, lib)

expr = "pc.field('str_col').isin({'str_0', 'str_10', 'str_20'})"
expected_q = q[q['str_col'].isin({'str_0', 'str_10', 'str_20'})]
compare_against_pyarrow(expr, expected_q, lib)

expr = "pc.field('float_col').is_nan()"
expected_q = q[q['float_col'].isnull()]
# We expect ArcticDB and pyarrow to return different results here because of their different NaN/null handling
compare_against_pyarrow(expr, expected_q, lib, expect_equal=False)

expr = "pc.field('float_col').is_null()"
expected_q = q[q['float_col'].isnull()]
compare_against_pyarrow(expr, expected_q, lib)

expr = "pc.field('float_col').is_valid()"
expected_q = q[q['float_col'].notnull()]
compare_against_pyarrow(expr, expected_q, lib)

def test_complex_filters(lmdb_version_store_v1):
lib = lmdb_version_store_v1
q = QueryBuilder()

# Nested complex filters
expr = "((pc.field('float_col') * 2) > 20.0) & (pc.field('int_col') <= pc.scalar(60)) | pc.field('bool_col')"
expected_q = q[(q['float_col'] * 2 > 20.0) & (q['int_col'] <= 60) | q['bool_col']]
compare_against_pyarrow(expr, expected_q, lib)

expr = "((pc.field('float_col') / 2) > 20.0) & (pc.field('float_col') <= pc.scalar(60)) & pc.field('str_col').isin(['str_30', 'str_41', 'str_42', 'str_53', 'str_99'])"
expected_q = q[(q['float_col'] / 2 > 20.0) & (q['float_col'] <= 60) & q['str_col'].isin(['str_30', 'str_41', 'str_42', 'str_53', 'str_99'])]
compare_against_pyarrow(expr, expected_q, lib)

# Filters with function calls
function_map = {
"datetime.datetime": datetime.datetime,
"abs": abs,
}
expr = "pc.field('datetime_col') < datetime.datetime(2025, 1, 20)"
expected_q = q[q['datetime_col'] < datetime.datetime(2025, 1, 20)]
compare_against_pyarrow(expr, expected_q, lib, function_map)

expr = "(pc.field('datetime_col') < datetime.datetime(2025, 1, abs(-20))) & (pc.field('int_col') >= abs(-5))"
expected_q = q[(q['datetime_col'] < datetime.datetime(2025, 1, abs(-20))) & (q['int_col'] >= abs(-5))]
compare_against_pyarrow(expr, expected_q, lib, function_map)

def test_broken_filters():
# ill-formed filter
expr = "pc.field('float_col'"
with pytest.raises(ValueError):
ExpressionNode._from_pyarrow_expression_str(expr)

# pyarrow expressions only support single comparisons
expr = "1 < pc.field('int_col') < 10"
with pytest.raises(ValueError):
ExpressionNode._from_pyarrow_expression_str(expr)

# calling a missing function
expr = "some.missing.function(5)"
with pytest.raises(ValueError):
ExpressionNode._from_pyarrow_expression_str(expr)