diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml new file mode 100644 index 0000000..1e16d6e --- /dev/null +++ b/.github/workflows/ruff.yml @@ -0,0 +1,34 @@ +name: Run linter and formatter + +on: + pull_request: + push: + branches: [ main ] + tags: [ 'v*.*.*' ] + +permissions: + contents: read + +jobs: + test: + name: Lint and Format + strategy: + matrix: + os: [ubuntu-latest] + python: ["3.9"] + runs-on: ${{ matrix.os }} + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + submodules: recursive + - name: Install uv with python + uses: astral-sh/setup-uv@v6 + with: + python-version: ${{ matrix.python }} + - name: Run ruff linter + run: | + uvx ruff@0.11.11 check + - name: Run ruff formatter + run: | + uvx ruff@0.11.11 format \ No newline at end of file diff --git a/examples/adbc_example.py b/examples/adbc_example.py index 530f9f3..10994f6 100644 --- a/examples/adbc_example.py +++ b/examples/adbc_example.py @@ -25,11 +25,15 @@ names=["ints", "strs"], ) + def read_adbc_named_table(name: str, conn): pa_schema = conn.adbc_get_table_schema(name) - substrait_schema = pa_substrait.serialize_schema(pa_schema).to_pysubstrait().base_schema + substrait_schema = ( + pa_substrait.serialize_schema(pa_schema).to_pysubstrait().base_schema + ) return read_named_table(name, substrait_schema) + with adbc_driver_duckdb.dbapi.connect(":memory:") as conn: with conn.cursor() as cur: cur.adbc_ingest("AnswerToEverything", data) @@ -38,7 +42,14 @@ def read_adbc_named_table(name: str, conn): cur.executescript("LOAD substrait;") table = read_adbc_named_table("AnswerToEverything", conn) - table = filter(table, expression=scalar_function('functions_comparison.yaml', 'gte', column('ints'), literal(3, i64()))) + table = filter( + table, + expression=scalar_function( + "functions_comparison.yaml", + "gte", + expressions=[column("ints"), literal(3, i64())], + ), + ) cur.execute(table(registry).SerializeToString()) - print(cur.fetch_arrow_table()) \ No newline at end of file + print(cur.fetch_arrow_table()) diff --git a/examples/builder_example.py b/examples/builder_example.py index b0c9ed6..0a9dd0d 100644 --- a/examples/builder_example.py +++ b/examples/builder_example.py @@ -6,19 +6,20 @@ registry = ExtensionRegistry(load_default_extensions=True) ns = named_struct( - names=["id", "is_applicable"], - struct=struct( - types=[ - i64(nullable=False), - boolean() - ] - ) + names=["id", "is_applicable"], struct=struct(types=[i64(nullable=False), boolean()]) ) -table = read_named_table('example_table', ns) -table = filter(table, expression=column('is_applicable')) -table = filter(table, expression=scalar_function('functions_comparison.yaml', 'lt', column('id'), literal(100, i64()))) -table = project(table, expressions=[column('id')]) +table = read_named_table("example_table", ns) +table = filter(table, expression=column("is_applicable")) +table = filter( + table, + expression=scalar_function( + "functions_comparison.yaml", + "lt", + expressions=[column("id"), literal(100, i64())], + ), +) +table = project(table, expressions=[column("id")]) print(table(registry)) diff --git a/examples/duckdb_example.py b/examples/duckdb_example.py index bfdde7d..8aacaec 100644 --- a/examples/duckdb_example.py +++ b/examples/duckdb_example.py @@ -18,7 +18,7 @@ try: duckdb.install_extension("substrait") -except: +except duckdb.duckdb.HTTPException: duckdb.install_extension("substrait", repository="community") duckdb.load_extension("substrait") @@ -29,14 +29,27 @@ registry = ExtensionRegistry(load_default_extensions=True) + def read_duckdb_named_table(name: str, conn): pa_schema = conn.sql(f"SELECT * FROM {name} LIMIT 0").arrow().schema - substrait_schema = pa_substrait.serialize_schema(pa_schema).to_pysubstrait().base_schema + substrait_schema = ( + pa_substrait.serialize_schema(pa_schema).to_pysubstrait().base_schema + ) return read_named_table(name, substrait_schema) + table = read_duckdb_named_table("customer", duckdb) -table = filter(table, expression=scalar_function('functions_comparison.yaml', 'equal', column('c_nationkey'), literal(3, i32()))) -table = project(table, expressions=[column('c_name'), column('c_address'), column('c_nationkey')]) +table = filter( + table, + expression=scalar_function( + "functions_comparison.yaml", + "equal", + expressions=[column("c_nationkey"), literal(3, i32())], + ), +) +table = project( + table, expressions=[column("c_name"), column("c_address"), column("c_nationkey")] +) sql = f"CALL from_substrait_json('{dump_json(table(registry))}')" print(duckdb.sql(sql)) diff --git a/examples/pyarrow_example.py b/examples/pyarrow_example.py index d4c0d36..f6fcaa6 100644 --- a/examples/pyarrow_example.py +++ b/examples/pyarrow_example.py @@ -11,21 +11,18 @@ import substrait from substrait.builders.plan import project, read_named_table -arrow_schema = pa.schema([ - pa.field("x", pa.int32()), - pa.field("y", pa.int32()) -]) +arrow_schema = pa.schema([pa.field("x", pa.int32()), pa.field("y", pa.int32())]) -substrait_schema = pa_substrait.serialize_schema(arrow_schema).to_pysubstrait().base_schema +substrait_schema = ( + pa_substrait.serialize_schema(arrow_schema).to_pysubstrait().base_schema +) substrait_expr = pa_substrait.serialize_expressions( - exprs=[pc.field("x") + pc.field("y")], - names=["total"], - schema=arrow_schema + exprs=[pc.field("x") + pc.field("y")], names=["total"], schema=arrow_schema ) pysubstrait_expr = substrait.proto.ExtendedExpression.FromString(bytes(substrait_expr)) table = read_named_table("example", substrait_schema) table = project(table, expressions=[pysubstrait_expr])(None) -print(table) \ No newline at end of file +print(table) diff --git a/pyproject.toml b/pyproject.toml index 4c4ab62..bba7c97 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,8 +30,3 @@ respect-gitignore = true target-version = "py39" # never autoformat upstream or generated code exclude = ["third_party/", "src/substrait/gen"] -# do not autofix the following (will still get flagged in lint) -lint.unfixable = [ - "F401", # unused imports - "T201", # print statements -] diff --git a/src/substrait/builders/extended_expression.py b/src/substrait/builders/extended_expression.py index 39bcb66..a7c8aad 100644 --- a/src/substrait/builders/extended_expression.py +++ b/src/substrait/builders/extended_expression.py @@ -5,69 +5,126 @@ import substrait.gen.proto.extended_expression_pb2 as stee import substrait.gen.proto.extensions.extensions_pb2 as ste from substrait.extension_registry import ExtensionRegistry -from substrait.utils import type_num_names, merge_extension_uris, merge_extension_declarations +from substrait.utils import ( + type_num_names, + merge_extension_uris, + merge_extension_declarations, +) from substrait.type_inference import infer_extended_expression_schema from typing import Callable, Any, Union, Iterable -UnboundExtendedExpression = Callable[[stp.NamedStruct, ExtensionRegistry], stee.ExtendedExpression] +UnboundExtendedExpression = Callable[ + [stp.NamedStruct, ExtensionRegistry], stee.ExtendedExpression +] ExtendedExpressionOrUnbound = Union[stee.ExtendedExpression, UnboundExtendedExpression] + def _alias_or_inferred( - alias: Union[Iterable[str], str], - op: str, - args: Iterable[str], - ): + alias: Union[Iterable[str], str], + op: str, + args: Iterable[str], +): if alias: return [alias] if isinstance(alias, str) else alias else: - return [f'{op}({",".join(args)})'] + return [f"{op}({','.join(args)})"] -def resolve_expression( - expression: ExtendedExpressionOrUnbound, - base_schema: stp.NamedStruct, - registry: ExtensionRegistry - ) -> stee.ExtendedExpression: - return expression if isinstance(expression, stee.ExtendedExpression) else expression(base_schema, registry) -def literal(value: Any, type: stp.Type, alias: Union[Iterable[str], str] = None) -> UnboundExtendedExpression: +def resolve_expression( + expression: ExtendedExpressionOrUnbound, + base_schema: stp.NamedStruct, + registry: ExtensionRegistry, +) -> stee.ExtendedExpression: + return ( + expression + if isinstance(expression, stee.ExtendedExpression) + else expression(base_schema, registry) + ) + + +def literal( + value: Any, type: stp.Type, alias: Union[Iterable[str], str] = None +) -> UnboundExtendedExpression: """Builds a resolver for ExtendedExpression containing a literal expression""" - def resolve(base_schema: stp.NamedStruct, registry: ExtensionRegistry) -> stee.ExtendedExpression: - kind = type.WhichOneof('kind') + + def resolve( + base_schema: stp.NamedStruct, registry: ExtensionRegistry + ) -> stee.ExtendedExpression: + kind = type.WhichOneof("kind") if kind == "bool": - literal = stalg.Expression.Literal(boolean=value, nullable=type.bool.nullability == stp.Type.NULLABILITY_NULLABLE) + literal = stalg.Expression.Literal( + boolean=value, + nullable=type.bool.nullability == stp.Type.NULLABILITY_NULLABLE, + ) elif kind == "i8": - literal = stalg.Expression.Literal(i8=value, nullable=type.i8.nullability == stp.Type.NULLABILITY_NULLABLE) + literal = stalg.Expression.Literal( + i8=value, nullable=type.i8.nullability == stp.Type.NULLABILITY_NULLABLE + ) elif kind == "i16": - literal = stalg.Expression.Literal(i16=value, nullable=type.i16.nullability == stp.Type.NULLABILITY_NULLABLE) + literal = stalg.Expression.Literal( + i16=value, + nullable=type.i16.nullability == stp.Type.NULLABILITY_NULLABLE, + ) elif kind == "i32": - literal = stalg.Expression.Literal(i32=value, nullable=type.i32.nullability == stp.Type.NULLABILITY_NULLABLE) + literal = stalg.Expression.Literal( + i32=value, + nullable=type.i32.nullability == stp.Type.NULLABILITY_NULLABLE, + ) elif kind == "i64": - literal = stalg.Expression.Literal(i64=value, nullable=type.i64.nullability == stp.Type.NULLABILITY_NULLABLE) + literal = stalg.Expression.Literal( + i64=value, + nullable=type.i64.nullability == stp.Type.NULLABILITY_NULLABLE, + ) elif kind == "fp32": - literal = stalg.Expression.Literal(fp32=value, nullable=type.fp32.nullability == stp.Type.NULLABILITY_NULLABLE) + literal = stalg.Expression.Literal( + fp32=value, + nullable=type.fp32.nullability == stp.Type.NULLABILITY_NULLABLE, + ) elif kind == "fp64": - literal = stalg.Expression.Literal(fp64=value, nullable=type.fp64.nullability == stp.Type.NULLABILITY_NULLABLE) + literal = stalg.Expression.Literal( + fp64=value, + nullable=type.fp64.nullability == stp.Type.NULLABILITY_NULLABLE, + ) elif kind == "string": - literal = stalg.Expression.Literal(string=value, nullable=type.string.nullability == stp.Type.NULLABILITY_NULLABLE) + literal = stalg.Expression.Literal( + string=value, + nullable=type.string.nullability == stp.Type.NULLABILITY_NULLABLE, + ) elif kind == "binary": - literal = stalg.Expression.Literal(binary=value, nullable=type.binary.nullability == stp.Type.NULLABILITY_NULLABLE) + literal = stalg.Expression.Literal( + binary=value, + nullable=type.binary.nullability == stp.Type.NULLABILITY_NULLABLE, + ) elif kind == "date": - date_value = (value - date(1970,1,1)).days if isinstance(value, date) else value - literal = stalg.Expression.Literal(date=date_value, nullable=type.date.nullability == stp.Type.NULLABILITY_NULLABLE) + date_value = ( + (value - date(1970, 1, 1)).days if isinstance(value, date) else value + ) + literal = stalg.Expression.Literal( + date=date_value, + nullable=type.date.nullability == stp.Type.NULLABILITY_NULLABLE, + ) # TODO # IntervalYearToMonth interval_year_to_month = 19; # IntervalDayToSecond interval_day_to_second = 20; # IntervalCompound interval_compound = 36; elif kind == "fixed_char": - literal = stalg.Expression.Literal(fixed_char=value, nullable=type.fixed_char.nullability == stp.Type.NULLABILITY_NULLABLE) + literal = stalg.Expression.Literal( + fixed_char=value, + nullable=type.fixed_char.nullability == stp.Type.NULLABILITY_NULLABLE, + ) elif kind == "varchar": literal = stalg.Expression.Literal( - var_char=stalg.Expression.Literal.VarChar(value=value, length=type.varchar.length), - nullable=type.varchar.nullability == stp.Type.NULLABILITY_NULLABLE + var_char=stalg.Expression.Literal.VarChar( + value=value, length=type.varchar.length + ), + nullable=type.varchar.nullability == stp.Type.NULLABILITY_NULLABLE, ) elif kind == "fixed_binary": - literal = stalg.Expression.Literal(fixed_binary=value, nullable=type.fixed_binary.nullability == stp.Type.NULLABILITY_NULLABLE) + literal = stalg.Expression.Literal( + fixed_binary=value, + nullable=type.fixed_binary.nullability == stp.Type.NULLABILITY_NULLABLE, + ) # TODO # Decimal decimal = 24; # PrecisionTime precision_time = 37; // Time in precision units past midnight. @@ -86,10 +143,8 @@ def resolve(base_schema: stp.NamedStruct, registry: ExtensionRegistry) -> stee.E return stee.ExtendedExpression( referred_expr=[ stee.ExpressionReference( - expression=stalg.Expression( - literal=literal - ), - output_names=_alias_or_inferred(alias, 'Literal', [str(value)]) + expression=stalg.Expression(literal=literal), + output_names=_alias_or_inferred(alias, "Literal", [str(value)]), ) ], base_schema=base_schema, @@ -97,6 +152,7 @@ def resolve(base_schema: stp.NamedStruct, registry: ExtensionRegistry) -> stee.E return resolve + def column(field: Union[str, int], alias: Union[Iterable[str], str] = None): """Builds a resolver for ExtendedExpression containing a FieldReference expression @@ -146,10 +202,15 @@ def resolve( return resolve + def scalar_function( - uri: str, function: str, expressions: Iterable[ExtendedExpressionOrUnbound], alias: Union[Iterable[str], str] = None + uri: str, + function: str, + expressions: Iterable[ExtendedExpressionOrUnbound], + alias: Union[Iterable[str], str] = None, ): """Builds a resolver for ExtendedExpression containing a ScalarFunction expression""" + def resolve( base_schema: stp.NamedStruct, registry: ExtensionRegistry ) -> stee.ExtendedExpression: @@ -207,7 +268,11 @@ def resolve( output_type=func[1], ) ), - output_names=_alias_or_inferred(alias, function, [e.referred_expr[0].output_names[0] for e in bound_expressions]), + output_names=_alias_or_inferred( + alias, + function, + [e.referred_expr[0].output_names[0] for e in bound_expressions], + ), ) ], base_schema=base_schema, @@ -217,10 +282,15 @@ def resolve( return resolve + def aggregate_function( - uri: str, function: str, expressions: Iterable[ExtendedExpressionOrUnbound], alias: Union[Iterable[str], str] = None + uri: str, + function: str, + expressions: Iterable[ExtendedExpressionOrUnbound], + alias: Union[Iterable[str], str] = None, ): """Builds a resolver for ExtendedExpression containing a AggregateFunction measure""" + def resolve( base_schema: stp.NamedStruct, registry: ExtensionRegistry ) -> stee.ExtendedExpression: @@ -274,7 +344,11 @@ def resolve( ], output_type=func[1], ), - output_names=_alias_or_inferred(alias, 'IfThen', [e.referred_expr[0].output_names[0] for e in bound_expressions]), + output_names=_alias_or_inferred( + alias, + "IfThen", + [e.referred_expr[0].output_names[0] for e in bound_expressions], + ), ) ], base_schema=base_schema, @@ -291,9 +365,10 @@ def window_function( function: str, expressions: Iterable[ExtendedExpressionOrUnbound], partitions: Iterable[ExtendedExpressionOrUnbound] = [], - alias: Union[Iterable[str], str] = None + alias: Union[Iterable[str], str] = None, ): """Builds a resolver for ExtendedExpression containing a WindowFunction expression""" + def resolve( base_schema: stp.NamedStruct, registry: ExtensionRegistry ) -> stee.ExtendedExpression: @@ -301,7 +376,9 @@ def resolve( resolve_expression(e, base_schema, registry) for e in expressions ] - bound_partitions = [resolve_expression(e, base_schema, registry) for e in partitions] + bound_partitions = [ + resolve_expression(e, base_schema, registry) for e in partitions + ] expression_schemas = [ infer_extended_expression_schema(b) for b in bound_expressions @@ -360,7 +437,11 @@ def resolve( ], ) ), - output_names=_alias_or_inferred(alias, function, [e.referred_expr[0].output_names[0] for e in bound_expressions]), + output_names=_alias_or_inferred( + alias, + function, + [e.referred_expr[0].output_names[0] for e in bound_expressions], + ), ) ], base_schema=base_schema, @@ -371,13 +452,21 @@ def resolve( return resolve -def if_then(ifs: Iterable[tuple[ExtendedExpressionOrUnbound, ExtendedExpressionOrUnbound]], _else: ExtendedExpressionOrUnbound, alias: Union[Iterable[str], str] = None): +def if_then( + ifs: Iterable[tuple[ExtendedExpressionOrUnbound, ExtendedExpressionOrUnbound]], + _else: ExtendedExpressionOrUnbound, + alias: Union[Iterable[str], str] = None, +): """Builds a resolver for ExtendedExpression containing an IfThen expression""" + def resolve( base_schema: stp.NamedStruct, registry: ExtensionRegistry ) -> stee.ExtendedExpression: bound_ifs = [ - (resolve_expression(if_clause[0], base_schema, registry), resolve_expression(if_clause[1], base_schema, registry)) + ( + resolve_expression(if_clause[0], base_schema, registry), + resolve_expression(if_clause[1], base_schema, registry), + ) for if_clause in ifs ] @@ -386,33 +475,51 @@ def resolve( extension_uris = merge_extension_uris( *[b[0].extension_uris for b in bound_ifs], *[b[1].extension_uris for b in bound_ifs], - bound_else.extension_uris + bound_else.extension_uris, ) extensions = merge_extension_declarations( *[b[0].extensions for b in bound_ifs], *[b[1].extensions for b in bound_ifs], - bound_else.extensions + bound_else.extensions, ) return stee.ExtendedExpression( referred_expr=[ stee.ExpressionReference( expression=stalg.Expression( - if_then=stalg.Expression.IfThen(**{ - 'ifs': [ - stalg.Expression.IfThen.IfClause(**{ - 'if': if_clause[0].referred_expr[0].expression, - 'then': if_clause[1].referred_expr[0].expression, - }) - for if_clause in bound_ifs - ], - 'else': bound_else.referred_expr[0].expression - }) + if_then=stalg.Expression.IfThen( + **{ + "ifs": [ + stalg.Expression.IfThen.IfClause( + **{ + "if": if_clause[0] + .referred_expr[0] + .expression, + "then": if_clause[1] + .referred_expr[0] + .expression, + } + ) + for if_clause in bound_ifs + ], + "else": bound_else.referred_expr[0].expression, + } + ) + ), + output_names=_alias_or_inferred( + alias, + "IfThen", + [ + a + for e in bound_ifs + for a in [ + e[0].referred_expr[0].output_names[0], + e[1].referred_expr[0].output_names[0], + ] + ] + + [bound_else.referred_expr[0].output_names[0]], ), - output_names=_alias_or_inferred(alias, 'IfThen', [a for e in bound_ifs for a in [e[0].referred_expr[0].output_names[0], e[1].referred_expr[0].output_names[0]]] - + [bound_else.referred_expr[0].output_names[0]] - ), ) ], base_schema=base_schema, @@ -422,10 +529,14 @@ def resolve( return resolve -def switch(match: ExtendedExpressionOrUnbound, - ifs: Iterable[tuple[ExtendedExpressionOrUnbound, ExtendedExpressionOrUnbound]], - _else: ExtendedExpressionOrUnbound): + +def switch( + match: ExtendedExpressionOrUnbound, + ifs: Iterable[tuple[ExtendedExpressionOrUnbound, ExtendedExpressionOrUnbound]], + _else: ExtendedExpressionOrUnbound, +): """Builds a resolver for ExtendedExpression containing a switch expression""" + def resolve( base_schema: stp.NamedStruct, registry: ExtensionRegistry ) -> stee.ExtendedExpression: @@ -433,20 +544,22 @@ def resolve( bound_ifs = [ ( resolve_expression(a, base_schema, registry), - resolve_expression(b, base_schema, registry) - ) for a, b in ifs] + resolve_expression(b, base_schema, registry), + ) + for a, b in ifs + ] bound_else = resolve_expression(_else, base_schema, registry) extension_uris = merge_extension_uris( bound_match.extension_uris, *[b.extension_uris for _, b in bound_ifs], - bound_else.extension_uris + bound_else.extension_uris, ) extensions = merge_extension_declarations( bound_match.extensions, *[b.extensions for _, b in bound_ifs], - bound_else.extensions + bound_else.extensions, ) return stee.ExtendedExpression( @@ -456,29 +569,33 @@ def resolve( switch_expression=stalg.Expression.SwitchExpression( match=bound_match.referred_expr[0].expression, ifs=[ - stalg.Expression.SwitchExpression.IfValue(**{ - 'if': i.referred_expr[0].expression.literal, - 'then': t.referred_expr[0].expression - }) + stalg.Expression.SwitchExpression.IfValue( + **{ + "if": i.referred_expr[0].expression.literal, + "then": t.referred_expr[0].expression, + } + ) for i, t in bound_ifs ], - **{ - 'else': bound_else.referred_expr[0].expression - } + **{"else": bound_else.referred_expr[0].expression}, ) ), - output_names=['switch'] #TODO construct name from inputs + output_names=["switch"], # TODO construct name from inputs ) ], base_schema=base_schema, extension_uris=extension_uris, extensions=extensions, ) - + return resolve -def singular_or_list(value: ExtendedExpressionOrUnbound, options: Iterable[ExtendedExpressionOrUnbound]): + +def singular_or_list( + value: ExtendedExpressionOrUnbound, options: Iterable[ExtendedExpressionOrUnbound] +): """Builds a resolver for ExtendedExpression containing a SingularOrList expression""" + def resolve( base_schema: stp.NamedStruct, registry: ExtensionRegistry ) -> stee.ExtendedExpression: @@ -486,13 +603,11 @@ def resolve( bound_options = [resolve_expression(o, base_schema, registry) for o in options] extension_uris = merge_extension_uris( - bound_value.extension_uris, - *[b.extension_uris for b in bound_options] + bound_value.extension_uris, *[b.extension_uris for b in bound_options] ) extensions = merge_extension_declarations( - bound_value.extensions, - *[b.extensions for b in bound_options] + bound_value.extensions, *[b.extensions for b in bound_options] ) return stee.ExtendedExpression( @@ -502,23 +617,29 @@ def resolve( singular_or_list=stalg.Expression.SingularOrList( value=bound_value.referred_expr[0].expression, options=[ - o.referred_expr[0].expression - for o in bound_options - ] + o.referred_expr[0].expression for o in bound_options + ], ) ), - output_names=['singular_or_list'] #TODO construct name from inputs + output_names=[ + "singular_or_list" + ], # TODO construct name from inputs ) ], base_schema=base_schema, extension_uris=extension_uris, extensions=extensions, ) - + return resolve -def multi_or_list(value: Iterable[ExtendedExpressionOrUnbound], options: Iterable[Iterable[ExtendedExpressionOrUnbound]]): + +def multi_or_list( + value: Iterable[ExtendedExpressionOrUnbound], + options: Iterable[Iterable[ExtendedExpressionOrUnbound]], +): """Builds a resolver for ExtendedExpression containing a MultiOrList expression""" + def resolve( base_schema: stp.NamedStruct, registry: ExtensionRegistry ) -> stee.ExtendedExpression: @@ -545,24 +666,28 @@ def resolve( value=[e.referred_expr[0].expression for e in bound_value], options=[ stalg.Expression.MultiOrList.Record( - fields=[e.referred_expr[0].expression for e in option] + fields=[ + e.referred_expr[0].expression for e in option + ] ) for option in bound_options - ] + ], ) ), - output_names=['multi_or_list'] #TODO construct name from inputs + output_names=["multi_or_list"], # TODO construct name from inputs ) ], base_schema=base_schema, extension_uris=extension_uris, extensions=extensions, ) - + return resolve + def cast(input: ExtendedExpressionOrUnbound, type: stp.Type): """Builds a resolver for ExtendedExpression containing a cast expression""" + def resolve( base_schema: stp.NamedStruct, registry: ExtensionRegistry ) -> stee.ExtendedExpression: @@ -575,15 +700,15 @@ def resolve( cast=stalg.Expression.Cast( input=bound_input.referred_expr[0].expression, type=type, - failure_behavior=stalg.Expression.Cast.FAILURE_BEHAVIOR_RETURN_NULL + failure_behavior=stalg.Expression.Cast.FAILURE_BEHAVIOR_RETURN_NULL, ) ), - output_names=['cast'] #TODO construct name from inputs + output_names=["cast"], # TODO construct name from inputs ) ], base_schema=base_schema, extension_uris=bound_input.extension_uris, extensions=bound_input.extensions, ) - + return resolve diff --git a/src/substrait/builders/plan.py b/src/substrait/builders/plan.py index 29fc193..51cc64a 100644 --- a/src/substrait/builders/plan.py +++ b/src/substrait/builders/plan.py @@ -12,7 +12,10 @@ import substrait.gen.proto.type_pb2 as stt import substrait.gen.proto.extended_expression_pb2 as stee from substrait.extension_registry import ExtensionRegistry -from substrait.builders.extended_expression import ExtendedExpressionOrUnbound, resolve_expression +from substrait.builders.extended_expression import ( + ExtendedExpressionOrUnbound, + resolve_expression, +) from substrait.type_inference import infer_plan_schema from substrait.utils import merge_extension_declarations, merge_extension_uris @@ -20,6 +23,7 @@ PlanOrUnbound = Union[stp.Plan, UnboundPlan] + def _merge_extensions(*objs): return { "extension_uris": merge_extension_uris(*[b.extension_uris for b in objs]), @@ -27,7 +31,9 @@ def _merge_extensions(*objs): } -def read_named_table(names: Union[str, Iterable[str]], named_struct: stt.NamedStruct) -> UnboundPlan: +def read_named_table( + names: Union[str, Iterable[str]], named_struct: stt.NamedStruct +) -> UnboundPlan: def resolve(registry: ExtensionRegistry) -> stp.Plan: _names = [names] if isinstance(names, str) else names @@ -40,10 +46,11 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan: ) return stp.Plan( - relations=[stp.PlanRel(root=stalg.RelRoot( - input=rel, names=named_struct.names))] + relations=[ + stp.PlanRel(root=stalg.RelRoot(input=rel, names=named_struct.names)) + ] ) - + return resolve @@ -53,11 +60,15 @@ def project( def resolve(registry: ExtensionRegistry) -> stp.Plan: _plan = plan if isinstance(plan, stp.Plan) else plan(registry) ns = infer_plan_schema(_plan) - bound_expressions: Iterable[stee.ExtendedExpression] = [resolve_expression(e, ns, registry) for e in expressions] + bound_expressions: Iterable[stee.ExtendedExpression] = [ + resolve_expression(e, ns, registry) for e in expressions + ] start_index = len(_plan.relations[-1].root.names) - names = [e.output_names[0] for ee in bound_expressions for e in ee.referred_expr] + names = [ + e.output_names[0] for ee in bound_expressions for e in ee.referred_expr + ] rel = stalg.Rel( project=stalg.ProjectRel( @@ -68,7 +79,8 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan: ), input=_plan.relations[-1].root.input, expressions=[ - e.expression for ee in bound_expressions for e in ee.referred_expr], + e.expression for ee in bound_expressions for e in ee.referred_expr + ], ) ) @@ -76,17 +88,17 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan: relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=names))], **_merge_extensions(_plan, *bound_expressions), ) - + return resolve -def filter( - plan: PlanOrUnbound, expression: ExtendedExpressionOrUnbound -) -> UnboundPlan: +def filter(plan: PlanOrUnbound, expression: ExtendedExpressionOrUnbound) -> UnboundPlan: def resolve(registry: ExtensionRegistry) -> stp.Plan: bound_plan = plan if isinstance(plan, stp.Plan) else plan(registry) ns = infer_plan_schema(bound_plan) - bound_expression: stee.ExtendedExpression = resolve_expression(expression, ns, registry) + bound_expression: stee.ExtendedExpression = resolve_expression( + expression, ns, registry + ) rel = stalg.Rel( filter=stalg.FilterRel( @@ -101,20 +113,32 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan: relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=names))], **_merge_extensions(bound_plan, bound_expression), ) - + return resolve def sort( plan: PlanOrUnbound, - expressions: Iterable[Union[ExtendedExpressionOrUnbound, tuple[ExtendedExpressionOrUnbound, stalg.SortField.SortDirection.ValueType]]] -) -> UnboundPlan: + expressions: Iterable[ + Union[ + ExtendedExpressionOrUnbound, + tuple[ExtendedExpressionOrUnbound, stalg.SortField.SortDirection.ValueType], + ] + ], +) -> UnboundPlan: def resolve(registry: ExtensionRegistry) -> stp.Plan: bound_plan = plan if isinstance(plan, stp.Plan) else plan(registry) ns = infer_plan_schema(bound_plan) - - bound_expressions = [(e, stalg.SortField.SORT_DIRECTION_ASC_NULLS_LAST) if not isinstance(e, tuple) else e for e in expressions] - bound_expressions = [(resolve_expression(e[0], ns, registry), e[1]) for e in bound_expressions] + + bound_expressions = [ + (e, stalg.SortField.SORT_DIRECTION_ASC_NULLS_LAST) + if not isinstance(e, tuple) + else e + for e in expressions + ] + bound_expressions = [ + (resolve_expression(e[0], ns, registry), e[1]) for e in bound_expressions + ] rel = stalg.Rel( sort=stalg.SortRel( @@ -133,7 +157,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan: relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=ns.names))], **_merge_extensions(bound_plan, *[e[0] for e in bound_expressions]), ) - + return resolve @@ -149,21 +173,26 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan: return stp.Plan( relations=[ stp.PlanRel( - root=stalg.RelRoot(input=rel, names=bound_inputs[0].relations[-1].root.names) + root=stalg.RelRoot( + input=rel, names=bound_inputs[0].relations[-1].root.names + ) ) ], **_merge_extensions(*bound_inputs), ) - + return resolve -def fetch(plan: PlanOrUnbound, - offset: ExtendedExpressionOrUnbound, - count: ExtendedExpressionOrUnbound) -> UnboundPlan: + +def fetch( + plan: PlanOrUnbound, + offset: ExtendedExpressionOrUnbound, + count: ExtendedExpressionOrUnbound, +) -> UnboundPlan: def resolve(registry: ExtensionRegistry) -> stp.Plan: bound_plan = plan if isinstance(plan, stp.Plan) else plan(registry) ns = infer_plan_schema(bound_plan) - + bound_offset = resolve_expression(offset, ns, registry) bound_count = resolve_expression(count, ns, registry) @@ -171,19 +200,21 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan: fetch=stalg.FetchRel( input=bound_plan.relations[-1].root.input, offset_expr=bound_offset.referred_expr[0].expression, - count_expr=bound_count.referred_expr[0].expression + count_expr=bound_count.referred_expr[0].expression, ) ) return stp.Plan( relations=[ stp.PlanRel( - root=stalg.RelRoot(input=rel, names=bound_plan.relations[-1].root.names) + root=stalg.RelRoot( + input=rel, names=bound_plan.relations[-1].root.names + ) ) ], **_merge_extensions(bound_plan, bound_offset, bound_count), ) - + return resolve @@ -206,7 +237,9 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan: ), names=list(left_ns.names) + list(right_ns.names), ) - bound_expression: stee.ExtendedExpression = resolve_expression(expression, ns, registry) + bound_expression: stee.ExtendedExpression = resolve_expression( + expression, ns, registry + ) rel = stalg.Rel( join=stalg.JoinRel( @@ -221,9 +254,10 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan: relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=ns.names))], **_merge_extensions(bound_left, bound_right, bound_expression), ) - + return resolve + def cross( left: PlanOrUnbound, right: PlanOrUnbound, @@ -233,7 +267,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan: bound_right = right if isinstance(right, stp.Plan) else right(registry) left_ns = infer_plan_schema(bound_left) right_ns = infer_plan_schema(bound_right) - + ns = stt.NamedStruct( struct=stt.Type.Struct( types=list(left_ns.struct.types) + list(right_ns.struct.types), @@ -245,17 +279,18 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan: rel = stalg.Rel( cross=stalg.CrossRel( left=bound_left.relations[-1].root.input, - right=bound_right.relations[-1].root.input - ) + right=bound_right.relations[-1].root.input, + ) ) return stp.Plan( relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=ns.names))], **_merge_extensions(bound_left, bound_right), ) - + return resolve + # TODO grouping sets def aggregate( input: PlanOrUnbound, @@ -266,7 +301,9 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan: bound_input = input if isinstance(input, stp.Plan) else input(registry) ns = infer_plan_schema(bound_input) - bound_grouping_expressions = [resolve_expression(e, ns, registry) for e in grouping_expressions] + bound_grouping_expressions = [ + resolve_expression(e, ns, registry) for e in grouping_expressions + ] bound_measures = [resolve_expression(e, ns, registry) for e in measures] rel = stalg.Rel( @@ -279,7 +316,8 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan: stalg.AggregateRel.Grouping( expression_references=range(len(bound_grouping_expressions)), grouping_expressions=[ - e.referred_expr[0].expression for e in bound_grouping_expressions + e.referred_expr[0].expression + for e in bound_grouping_expressions ], ) ], @@ -290,13 +328,15 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan: ) ) - names = [e.referred_expr[0].output_names[0] for e in bound_grouping_expressions] + [ - e.referred_expr[0].output_names[0] for e in bound_measures - ] + names = [ + e.referred_expr[0].output_names[0] for e in bound_grouping_expressions + ] + [e.referred_expr[0].output_names[0] for e in bound_measures] return stp.Plan( relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=names))], - **_merge_extensions(bound_input, *bound_grouping_expressions, *bound_measures), + **_merge_extensions( + bound_input, *bound_grouping_expressions, *bound_measures + ), ) - + return resolve diff --git a/src/substrait/builders/type.py b/src/substrait/builders/type.py index f0a51a1..8405595 100644 --- a/src/substrait/builders/type.py +++ b/src/substrait/builders/type.py @@ -1,77 +1,260 @@ from typing import Iterable import substrait.gen.proto.type_pb2 as stt + def boolean(nullable=True) -> stt.Type: - return stt.Type(bool=stt.Type.Boolean(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + return stt.Type( + bool=stt.Type.Boolean( + nullability=stt.Type.NULLABILITY_NULLABLE + if nullable + else stt.Type.NULLABILITY_REQUIRED + ) + ) + def i8(nullable=True) -> stt.Type: - return stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + return stt.Type( + i8=stt.Type.I8( + nullability=stt.Type.NULLABILITY_NULLABLE + if nullable + else stt.Type.NULLABILITY_REQUIRED + ) + ) + def i16(nullable=True) -> stt.Type: - return stt.Type(i16=stt.Type.I16(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + return stt.Type( + i16=stt.Type.I16( + nullability=stt.Type.NULLABILITY_NULLABLE + if nullable + else stt.Type.NULLABILITY_REQUIRED + ) + ) + def i32(nullable=True) -> stt.Type: - return stt.Type(i32=stt.Type.I32(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + return stt.Type( + i32=stt.Type.I32( + nullability=stt.Type.NULLABILITY_NULLABLE + if nullable + else stt.Type.NULLABILITY_REQUIRED + ) + ) + def i64(nullable=True) -> stt.Type: - return stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + return stt.Type( + i64=stt.Type.I64( + nullability=stt.Type.NULLABILITY_NULLABLE + if nullable + else stt.Type.NULLABILITY_REQUIRED + ) + ) + def fp32(nullable=True) -> stt.Type: - return stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + return stt.Type( + fp32=stt.Type.FP32( + nullability=stt.Type.NULLABILITY_NULLABLE + if nullable + else stt.Type.NULLABILITY_REQUIRED + ) + ) + def fp64(nullable=True) -> stt.Type: - return stt.Type(fp64=stt.Type.FP64(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + return stt.Type( + fp64=stt.Type.FP64( + nullability=stt.Type.NULLABILITY_NULLABLE + if nullable + else stt.Type.NULLABILITY_REQUIRED + ) + ) + def string(nullable=True) -> stt.Type: - return stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + return stt.Type( + string=stt.Type.String( + nullability=stt.Type.NULLABILITY_NULLABLE + if nullable + else stt.Type.NULLABILITY_REQUIRED + ) + ) + def binary(nullable=True) -> stt.Type: - return stt.Type(binary=stt.Type.Binary(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + return stt.Type( + binary=stt.Type.Binary( + nullability=stt.Type.NULLABILITY_NULLABLE + if nullable + else stt.Type.NULLABILITY_REQUIRED + ) + ) + def date(nullable=True) -> stt.Type: - return stt.Type(date=stt.Type.Date(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + return stt.Type( + date=stt.Type.Date( + nullability=stt.Type.NULLABILITY_NULLABLE + if nullable + else stt.Type.NULLABILITY_REQUIRED + ) + ) + def interval_year(nullable=True) -> stt.Type: - return stt.Type(interval_year=stt.Type.IntervalYear(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + return stt.Type( + interval_year=stt.Type.IntervalYear( + nullability=stt.Type.NULLABILITY_NULLABLE + if nullable + else stt.Type.NULLABILITY_REQUIRED + ) + ) + def interval_day(precision: int, nullable=True) -> stt.Type: - return stt.Type(interval_day=stt.Type.IntervalDay(precision=precision, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + return stt.Type( + interval_day=stt.Type.IntervalDay( + precision=precision, + nullability=stt.Type.NULLABILITY_NULLABLE + if nullable + else stt.Type.NULLABILITY_REQUIRED, + ) + ) + def interval_compound(precision: int, nullable=True) -> stt.Type: - return stt.Type(interval_compound=stt.Type.IntervalCompound(precision=precision, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + return stt.Type( + interval_compound=stt.Type.IntervalCompound( + precision=precision, + nullability=stt.Type.NULLABILITY_NULLABLE + if nullable + else stt.Type.NULLABILITY_REQUIRED, + ) + ) + def uuid(nullable=True) -> stt.Type: - return stt.Type(uuid=stt.Type.UUID(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + return stt.Type( + uuid=stt.Type.UUID( + nullability=stt.Type.NULLABILITY_NULLABLE + if nullable + else stt.Type.NULLABILITY_REQUIRED + ) + ) + def fixed_char(length: int, nullable=True) -> stt.Type: - return stt.Type(fixed_char=stt.Type.FixedChar(length=length, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + return stt.Type( + fixed_char=stt.Type.FixedChar( + length=length, + nullability=stt.Type.NULLABILITY_NULLABLE + if nullable + else stt.Type.NULLABILITY_REQUIRED, + ) + ) + def var_char(length: int, nullable=True) -> stt.Type: - return stt.Type(varchar=stt.Type.VarChar(length=length, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + return stt.Type( + varchar=stt.Type.VarChar( + length=length, + nullability=stt.Type.NULLABILITY_NULLABLE + if nullable + else stt.Type.NULLABILITY_REQUIRED, + ) + ) + def fixed_binary(length: int, nullable=True) -> stt.Type: - return stt.Type(fixed_binary=stt.Type.FixedBinary(length=length, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + return stt.Type( + fixed_binary=stt.Type.FixedBinary( + length=length, + nullability=stt.Type.NULLABILITY_NULLABLE + if nullable + else stt.Type.NULLABILITY_REQUIRED, + ) + ) + def decimal(scale: int, precision: int, nullable=True) -> stt.Type: - return stt.Type(decimal=stt.Type.Decimal(scale=scale, precision=precision, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + return stt.Type( + decimal=stt.Type.Decimal( + scale=scale, + precision=precision, + nullability=stt.Type.NULLABILITY_NULLABLE + if nullable + else stt.Type.NULLABILITY_REQUIRED, + ) + ) + def precision_time(precision: int, nullable=True) -> stt.Type: - return stt.Type(precision_time=stt.Type.PrecisionTime(precision=precision, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + return stt.Type( + precision_time=stt.Type.PrecisionTime( + precision=precision, + nullability=stt.Type.NULLABILITY_NULLABLE + if nullable + else stt.Type.NULLABILITY_REQUIRED, + ) + ) + def precision_timestamp(precision: int, nullable=True) -> stt.Type: - return stt.Type(precision_timestamp=stt.Type.PrecisionTimestamp(precision=precision, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + return stt.Type( + precision_timestamp=stt.Type.PrecisionTimestamp( + precision=precision, + nullability=stt.Type.NULLABILITY_NULLABLE + if nullable + else stt.Type.NULLABILITY_REQUIRED, + ) + ) + def precision_timestamp_tz(precision: int, nullable=True) -> stt.Type: - return stt.Type(precision_timestamp_tz=stt.Type.PrecisionTimestampTZ(precision=precision, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + return stt.Type( + precision_timestamp_tz=stt.Type.PrecisionTimestampTZ( + precision=precision, + nullability=stt.Type.NULLABILITY_NULLABLE + if nullable + else stt.Type.NULLABILITY_REQUIRED, + ) + ) + def struct(types: Iterable[stt.Type], nullable=True) -> stt.Type: - return stt.Type(struct=stt.Type.Struct(types=types, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + return stt.Type( + struct=stt.Type.Struct( + types=types, + nullability=stt.Type.NULLABILITY_NULLABLE + if nullable + else stt.Type.NULLABILITY_REQUIRED, + ) + ) + def list(type: stt.Type, nullable=True) -> stt.Type: - return stt.Type(list=stt.Type.List(type=type, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + return stt.Type( + list=stt.Type.List( + type=type, + nullability=stt.Type.NULLABILITY_NULLABLE + if nullable + else stt.Type.NULLABILITY_REQUIRED, + ) + ) + def map(key: stt.Type, value: stt.Type, nullable=True) -> stt.Type: - return stt.Type(map=stt.Type.Map(key=key, value=value, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + return stt.Type( + map=stt.Type.Map( + key=key, + value=value, + nullability=stt.Type.NULLABILITY_NULLABLE + if nullable + else stt.Type.NULLABILITY_REQUIRED, + ) + ) + def named_struct(names: Iterable[str], struct: stt.Type) -> stt.NamedStruct: return stt.NamedStruct(names=names, struct=struct.struct) diff --git a/src/substrait/derivation_expression.py b/src/substrait/derivation_expression.py index a46c985..f4d18d7 100644 --- a/src/substrait/derivation_expression.py +++ b/src/substrait/derivation_expression.py @@ -6,7 +6,7 @@ def _evaluate(x, values: dict): - if type(x) == SubstraitTypeParser.BinaryExprContext: + if isinstance(x, SubstraitTypeParser.BinaryExprContext): left = _evaluate(x.left, values) right = _evaluate(x.right, values) @@ -26,15 +26,15 @@ def _evaluate(x, values: dict): return left <= right else: raise Exception(f"Unknown binary op {x.op.text}") - elif type(x) == SubstraitTypeParser.LiteralNumberContext: + elif isinstance(x, SubstraitTypeParser.LiteralNumberContext): return int(x.Number().symbol.text) - elif type(x) == SubstraitTypeParser.ParameterNameContext: + elif isinstance(x, SubstraitTypeParser.ParameterNameContext): return values[x.Identifier().symbol.text] - elif type(x) == SubstraitTypeParser.NumericParameterNameContext: + elif isinstance(x, SubstraitTypeParser.NumericParameterNameContext): return values[x.Identifier().symbol.text] - elif type(x) == SubstraitTypeParser.ParenExpressionContext: + elif isinstance(x, SubstraitTypeParser.ParenExpressionContext): return _evaluate(x.expr(), values) - elif type(x) == SubstraitTypeParser.FunctionCallContext: + elif isinstance(x, SubstraitTypeParser.FunctionCallContext): exprs = [_evaluate(e, values) for e in x.expr()] func = x.Identifier().symbol.text if func == "min": @@ -43,7 +43,7 @@ def _evaluate(x, values: dict): return max(*exprs) else: raise Exception(f"Unknown function {func}") - elif type(x) == SubstraitTypeParser.TypeDefContext: + elif isinstance(x, SubstraitTypeParser.TypeDefContext): scalar_type = x.scalarType() parametrized_type = x.parameterizedType() any_type = x.anyType() @@ -89,16 +89,18 @@ def _evaluate(x, values: dict): else: raise Exception() else: - raise Exception(f"either scalar_type, parametrized_type or any_type is required") - elif type(x) == SubstraitTypeParser.NumericExpressionContext: + raise Exception( + "either scalar_type, parametrized_type or any_type is required" + ) + elif isinstance(x, SubstraitTypeParser.NumericExpressionContext): return _evaluate(x.expr(), values) - elif type(x) == SubstraitTypeParser.TernaryContext: + elif isinstance(x, SubstraitTypeParser.TernaryContext): ifExpr = _evaluate(x.ifExpr, values) thenExpr = _evaluate(x.thenExpr, values) elseExpr = _evaluate(x.elseExpr, values) return thenExpr if ifExpr else elseExpr - elif type(x) == SubstraitTypeParser.MultilineDefinitionContext: + elif isinstance(x, SubstraitTypeParser.MultilineDefinitionContext): lines = zip(x.Identifier(), x.expr()) for i, e in lines: @@ -107,9 +109,9 @@ def _evaluate(x, values: dict): values[identifier] = expr_eval return _evaluate(x.finalType, values) - elif type(x) == SubstraitTypeParser.TypeLiteralContext: + elif isinstance(x, SubstraitTypeParser.TypeLiteralContext): return _evaluate(x.typeDef(), values) - elif type(x) == SubstraitTypeParser.NumericLiteralContext: + elif isinstance(x, SubstraitTypeParser.NumericLiteralContext): return int(str(x.Number())) else: raise Exception(f"Unknown token type {type(x)}") diff --git a/src/substrait/extension_registry.py b/src/substrait/extension_registry.py index f271bcb..d311fab 100644 --- a/src/substrait/extension_registry.py +++ b/src/substrait/extension_registry.py @@ -8,7 +8,7 @@ from .derivation_expression import evaluate, _evaluate, _parse import yaml - +from substrait.gen.antlr.SubstraitTypeParser import SubstraitTypeParser DEFAULT_URI_PREFIX = "https://github.com/substrait-io/substrait/blob/main/extensions" @@ -78,9 +78,6 @@ def violates_integer_option(actual: int, option, parameters: dict): return False -from substrait.gen.antlr.SubstraitTypeParser import SubstraitTypeParser - - def types_equal(type1: Type, type2: Type, check_nullability=False): if check_nullability: return type1 == type2 @@ -96,7 +93,10 @@ def types_equal(type1: Type, type2: Type, check_nullability=False): ).nullability = Type.Nullability.NULLABILITY_UNSPECIFIED return x == y -def handle_parameter_cover(covered: Type, parameter_name: str, parameters: dict, check_nullability: bool): + +def handle_parameter_cover( + covered: Type, parameter_name: str, parameters: dict, check_nullability: bool +): if parameter_name in parameters: covering = parameters[parameter_name] return types_equal(covering, covered, check_nullability) @@ -104,22 +104,27 @@ def handle_parameter_cover(covered: Type, parameter_name: str, parameters: dict, parameters[parameter_name] = covered return True + def covers( covered: Type, covering: SubstraitTypeParser.TypeLiteralContext, parameters: dict, check_nullability=False, -): +): if isinstance(covering, SubstraitTypeParser.ParameterNameContext): parameter_name = str(covering.Identifier()) - return handle_parameter_cover(covered, parameter_name, parameters, check_nullability) + return handle_parameter_cover( + covered, parameter_name, parameters, check_nullability + ) covering: SubstraitTypeParser.TypeDefContext = covering.typeDef() any_type: SubstraitTypeParser.AnyTypeContext = covering.anyType() if any_type: if any_type.AnyVar(): - return handle_parameter_cover(covered, any_type.AnyVar().symbol.text, parameters, check_nullability) + return handle_parameter_cover( + covered, any_type.AnyVar().symbol.text, parameters, check_nullability + ) else: return True @@ -176,7 +181,7 @@ def __init__( if typ := val.get("value"): self.arguments.append(_parse(typ)) self.normalized_inputs.append(normalize_substrait_type_names(typ)) - elif arg_name := val.get("name", None): + elif _ := val.get("name", None): self.arguments.append(val.get("options")) self.normalized_inputs.append("req") @@ -199,7 +204,7 @@ def satisfies_signature(self, signature: tuple) -> Optional[str]: parameters = {} for x, y in zipped_args: - if type(y) == str: + if isinstance(y, str): if y not in x: return None else: @@ -216,7 +221,7 @@ def satisfies_signature(self, signature: tuple) -> Optional[str]: p.__getattribute__(p.WhichOneof("kind")).nullability == Type.NULLABILITY_NULLABLE for p in signature - if type(p) == Type + if isinstance(p, Type) ] ) output_type.__getattribute__(output_type.WhichOneof("kind")).nullability = ( @@ -259,7 +264,7 @@ def register_extension_yaml( def register_extension_dict(self, definitions: dict, uri: str) -> None: self._uri_mapping[uri] = next(self._uri_id_generator) - + for named_functions in definitions.values(): for function in named_functions: for impl in function.get("impls", []): @@ -293,7 +298,7 @@ def lookup_function( return (f, rtn) return None - + def lookup_uri(self, uri: str) -> Optional[int]: uri = self._uri_aliases.get(uri, uri) return self._uri_mapping.get(uri, None) diff --git a/src/substrait/type_inference.py b/src/substrait/type_inference.py index 5ed16fa..d6a68e8 100644 --- a/src/substrait/type_inference.py +++ b/src/substrait/type_inference.py @@ -348,7 +348,8 @@ def infer_rel_schema(rel: stalg.Rel) -> stt.Type.Struct: nullability=struct.nullability, ) + def infer_plan_schema(plan: stp.Plan) -> stt.NamedStruct: schema = infer_rel_schema(plan.relations[-1].root.input) - return stt.NamedStruct(names=plan.relations[-1].root.names, struct=schema) \ No newline at end of file + return stt.NamedStruct(names=plan.relations[-1].root.names, struct=schema) diff --git a/src/substrait/utils.py b/src/substrait/utils.py index 0f8764b..bd84381 100644 --- a/src/substrait/utils.py +++ b/src/substrait/utils.py @@ -2,6 +2,7 @@ import substrait.gen.proto.extensions.extensions_pb2 as ste from typing import Iterable + def type_num_names(typ: stp.Type): kind = typ.WhichOneof("kind") if kind == "struct": @@ -14,6 +15,7 @@ def type_num_names(typ: stp.Type): else: return 1 + def merge_extension_uris(*extension_uris: Iterable[ste.SimpleExtensionURI]): """Merges multiple sets of SimpleExtensionURI objects into a single set. The order of extensions is kept intact, while duplicates are discarded. @@ -30,7 +32,10 @@ def merge_extension_uris(*extension_uris: Iterable[ste.SimpleExtensionURI]): return ret -def merge_extension_declarations(*extension_declarations: Iterable[ste.SimpleExtensionDeclaration]): + +def merge_extension_declarations( + *extension_declarations: Iterable[ste.SimpleExtensionDeclaration], +): """Merges multiple sets of SimpleExtensionDeclaration objects into a single set. The order of extension declarations is kept intact, while duplicates are discarded. Assumes that there are no collisions (different extension declarations having identical anchors). @@ -41,13 +46,15 @@ def merge_extension_declarations(*extension_declarations: Iterable[ste.SimpleExt for declarations in extension_declarations: for declaration in declarations: - if declaration.WhichOneof('mapping_type') == 'extension_function': - ident = (declaration.extension_function.extension_uri_reference, declaration.extension_function.name) + if declaration.WhichOneof("mapping_type") == "extension_function": + ident = ( + declaration.extension_function.extension_uri_reference, + declaration.extension_function.name, + ) if ident not in seen_extension_functions: seen_extension_functions.add(ident) ret.append(declaration) else: - raise Exception('') #TODO handle extension types + raise Exception("") # TODO handle extension types return ret - \ No newline at end of file diff --git a/tests/builders/extended_expression/test_aggregate_function.py b/tests/builders/extended_expression/test_aggregate_function.py index 304e1c7..538dc4b 100644 --- a/tests/builders/extended_expression/test_aggregate_function.py +++ b/tests/builders/extended_expression/test_aggregate_function.py @@ -38,25 +38,28 @@ registry = ExtensionRegistry(load_default_extensions=False) registry.register_extension_dict(yaml.safe_load(content), uri="test_uri") + def test_aggregate_count(): - e = aggregate_function('test_uri', 'count', - expressions=[literal(10, type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED)))], - alias='count', - )(named_struct, registry) - - expected = stee.ExtendedExpression( - extension_uris=[ - ste.SimpleExtensionURI( - extension_uri_anchor=1, - uri='test_uri' + e = aggregate_function( + "test_uri", + "count", + expressions=[ + literal( + 10, + type=stt.Type( + i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED) + ), ) ], + alias="count", + )(named_struct, registry) + + expected = stee.ExtendedExpression( + extension_uris=[ste.SimpleExtensionURI(extension_uri_anchor=1, uri="test_uri")], extensions=[ ste.SimpleExtensionDeclaration( extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction( - extension_uri_reference=1, - function_anchor=1, - name='count' + extension_uri_reference=1, function_anchor=1, name="count" ) ) ], @@ -65,9 +68,15 @@ def test_aggregate_count(): measure=stalg.AggregateFunction( function_reference=1, arguments=[ - stalg.FunctionArgument(value=stalg.Expression(literal=stalg.Expression.Literal(i8=10, nullable=False))), + stalg.FunctionArgument( + value=stalg.Expression( + literal=stalg.Expression.Literal(i8=10, nullable=False) + ) + ), ], - output_type=stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)) + output_type=stt.Type( + i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED) + ), ), output_names=["count"], ) diff --git a/tests/builders/extended_expression/test_cast.py b/tests/builders/extended_expression/test_cast.py index 2fcfca5..30defe5 100644 --- a/tests/builders/extended_expression/test_cast.py +++ b/tests/builders/extended_expression/test_cast.py @@ -1,5 +1,3 @@ -import yaml - import substrait.gen.proto.algebra_pb2 as stalg import substrait.gen.proto.type_pb2 as stt import substrait.gen.proto.extended_expression_pb2 as stee @@ -21,12 +19,10 @@ registry = ExtensionRegistry(load_default_extensions=False) + def test_cast(): - e = cast( - input=literal(3, i8()), - type=i16() - )(named_struct, registry) - + e = cast(input=literal(3, i8()), type=i16())(named_struct, registry) + expected = stee.ExtendedExpression( referred_expr=[ stee.ExpressionReference( @@ -35,8 +31,10 @@ def test_cast(): type=stt.Type( i16=stt.Type.I16(nullability=stt.Type.NULLABILITY_NULLABLE) ), - input=stalg.Expression(literal=stalg.Expression.Literal(i8=3, nullable=True)), - failure_behavior=stalg.Expression.Cast.FAILURE_BEHAVIOR_RETURN_NULL + input=stalg.Expression( + literal=stalg.Expression.Literal(i8=3, nullable=True) + ), + failure_behavior=stalg.Expression.Cast.FAILURE_BEHAVIOR_RETURN_NULL, ) ), output_names=["cast"], @@ -46,4 +44,3 @@ def test_cast(): ) assert e == expected - diff --git a/tests/builders/extended_expression/test_if_then.py b/tests/builders/extended_expression/test_if_then.py index 1adeb6d..8392b3a 100644 --- a/tests/builders/extended_expression/test_if_then.py +++ b/tests/builders/extended_expression/test_if_then.py @@ -16,30 +16,57 @@ names=["order_id", "description", "order_total"], struct=struct ) + def test_if_else(): actual = if_then( ifs=[ ( - literal(True, type=stt.Type(bool=stt.Type.Boolean(nullability=stt.Type.NULLABILITY_REQUIRED))), - literal(10, type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED))) + literal( + True, + type=stt.Type( + bool=stt.Type.Boolean(nullability=stt.Type.NULLABILITY_REQUIRED) + ), + ), + literal( + 10, + type=stt.Type( + i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED) + ), + ), ) ], - _else=literal(20, type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED))) + _else=literal( + 20, type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED)) + ), )(named_struct, None) expected = stee.ExtendedExpression( referred_expr=[ stee.ExpressionReference( expression=stalg.Expression( - if_then=stalg.Expression.IfThen(**{ - 'ifs': [ - stalg.Expression.IfThen.IfClause(**{ - 'if': stalg.Expression(literal=stalg.Expression.Literal(boolean=True, nullable=False)), - 'then': stalg.Expression(literal=stalg.Expression.Literal(i8=10, nullable=False)) - }) - ], - 'else': stalg.Expression(literal=stalg.Expression.Literal(i8=20, nullable=False)) - }) + if_then=stalg.Expression.IfThen( + **{ + "ifs": [ + stalg.Expression.IfThen.IfClause( + **{ + "if": stalg.Expression( + literal=stalg.Expression.Literal( + boolean=True, nullable=False + ) + ), + "then": stalg.Expression( + literal=stalg.Expression.Literal( + i8=10, nullable=False + ) + ), + } + ) + ], + "else": stalg.Expression( + literal=stalg.Expression.Literal(i8=20, nullable=False) + ), + } + ) ), output_names=["IfThen(Literal(True),Literal(10),Literal(20))"], ) diff --git a/tests/builders/extended_expression/test_literal.py b/tests/builders/extended_expression/test_literal.py index 45f86dd..d43f3be 100644 --- a/tests/builders/extended_expression/test_literal.py +++ b/tests/builders/extended_expression/test_literal.py @@ -1,36 +1,65 @@ from datetime import date import substrait.gen.proto.algebra_pb2 as stalg -import substrait.gen.proto.type_pb2 as stt -import substrait.gen.proto.extended_expression_pb2 as stee from substrait.builders.extended_expression import literal from substrait.builders import type as sttb + def extract_literal(builder): return builder(None, None).referred_expr[0].expression.literal + def test_boolean(): - assert extract_literal(literal(True, sttb.boolean())) == stalg.Expression.Literal(boolean=True, nullable=True) - assert extract_literal(literal(False, sttb.boolean())) == stalg.Expression.Literal(boolean=False, nullable=True) + assert extract_literal(literal(True, sttb.boolean())) == stalg.Expression.Literal( + boolean=True, nullable=True + ) + assert extract_literal(literal(False, sttb.boolean())) == stalg.Expression.Literal( + boolean=False, nullable=True + ) + def test_integer(): - assert extract_literal(literal(100, sttb.i16())) == stalg.Expression.Literal(i16=100, nullable=True) + assert extract_literal(literal(100, sttb.i16())) == stalg.Expression.Literal( + i16=100, nullable=True + ) + def test_string(): - assert extract_literal(literal("Hello", sttb.string())) == stalg.Expression.Literal(string="Hello", nullable=True) + assert extract_literal(literal("Hello", sttb.string())) == stalg.Expression.Literal( + string="Hello", nullable=True + ) + def test_binary(): - assert extract_literal(literal(b"Hello", sttb.binary())) == stalg.Expression.Literal(binary=b"Hello", nullable=True) + assert extract_literal( + literal(b"Hello", sttb.binary()) + ) == stalg.Expression.Literal(binary=b"Hello", nullable=True) + def test_date(): - assert extract_literal(literal(1000, sttb.date())) == stalg.Expression.Literal(date=1000, nullable=True) - assert extract_literal(literal(date(1970, 1, 11), sttb.date())) == stalg.Expression.Literal(date=10, nullable=True) + assert extract_literal(literal(1000, sttb.date())) == stalg.Expression.Literal( + date=1000, nullable=True + ) + assert extract_literal( + literal(date(1970, 1, 11), sttb.date()) + ) == stalg.Expression.Literal(date=10, nullable=True) + def test_fixed_char(): - assert extract_literal(literal("Hello", sttb.fixed_char(length=5))) == stalg.Expression.Literal(fixed_char="Hello", nullable=True) + assert extract_literal( + literal("Hello", sttb.fixed_char(length=5)) + ) == stalg.Expression.Literal(fixed_char="Hello", nullable=True) + def test_var_char(): - assert extract_literal(literal("Hello", sttb.var_char(length=5))) \ - == stalg.Expression.Literal(var_char=stalg.Expression.Literal.VarChar(value="Hello", length=5), nullable=True) + assert extract_literal( + literal("Hello", sttb.var_char(length=5)) + ) == stalg.Expression.Literal( + var_char=stalg.Expression.Literal.VarChar(value="Hello", length=5), + nullable=True, + ) + def test_fixed_binary(): - assert extract_literal(literal(b"Hello", sttb.fixed_binary(length=5))) == stalg.Expression.Literal(fixed_binary=b"Hello", nullable=True) + assert extract_literal( + literal(b"Hello", sttb.fixed_binary(length=5)) + ) == stalg.Expression.Literal(fixed_binary=b"Hello", nullable=True) diff --git a/tests/builders/extended_expression/test_multi_or_list.py b/tests/builders/extended_expression/test_multi_or_list.py index d2efa2d..31a2a15 100644 --- a/tests/builders/extended_expression/test_multi_or_list.py +++ b/tests/builders/extended_expression/test_multi_or_list.py @@ -1,5 +1,3 @@ -import yaml - import substrait.gen.proto.algebra_pb2 as stalg import substrait.gen.proto.type_pb2 as stt import substrait.gen.proto.extended_expression_pb2 as stee @@ -21,38 +19,59 @@ registry = ExtensionRegistry(load_default_extensions=False) + def test_singular_or_list(): e = multi_or_list( value=[literal(1, i8()), literal(2, i8())], options=[ [literal(1, i8()), literal(2, i8())], - [literal(3, i8()), literal(4, i8())] - ] + [literal(3, i8()), literal(4, i8())], + ], )(named_struct, registry) - + expected = stee.ExtendedExpression( referred_expr=[ stee.ExpressionReference( expression=stalg.Expression( multi_or_list=stalg.Expression.MultiOrList( value=[ - stalg.Expression(literal=stalg.Expression.Literal(i8=1, nullable=True)), - stalg.Expression(literal=stalg.Expression.Literal(i8=2, nullable=True)) + stalg.Expression( + literal=stalg.Expression.Literal(i8=1, nullable=True) + ), + stalg.Expression( + literal=stalg.Expression.Literal(i8=2, nullable=True) + ), ], options=[ stalg.Expression.MultiOrList.Record( fields=[ - stalg.Expression(literal=stalg.Expression.Literal(i8=1, nullable=True)), - stalg.Expression(literal=stalg.Expression.Literal(i8=2, nullable=True)) + stalg.Expression( + literal=stalg.Expression.Literal( + i8=1, nullable=True + ) + ), + stalg.Expression( + literal=stalg.Expression.Literal( + i8=2, nullable=True + ) + ), ] ), stalg.Expression.MultiOrList.Record( fields=[ - stalg.Expression(literal=stalg.Expression.Literal(i8=3, nullable=True)), - stalg.Expression(literal=stalg.Expression.Literal(i8=4, nullable=True)) + stalg.Expression( + literal=stalg.Expression.Literal( + i8=3, nullable=True + ) + ), + stalg.Expression( + literal=stalg.Expression.Literal( + i8=4, nullable=True + ) + ), ] - ) - ] + ), + ], ) ), output_names=["multi_or_list"], @@ -62,4 +81,3 @@ def test_singular_or_list(): ) assert e == expected - diff --git a/tests/builders/extended_expression/test_scalar_function.py b/tests/builders/extended_expression/test_scalar_function.py index 6c7024b..4ac2052 100644 --- a/tests/builders/extended_expression/test_scalar_function.py +++ b/tests/builders/extended_expression/test_scalar_function.py @@ -42,26 +42,33 @@ registry = ExtensionRegistry(load_default_extensions=False) registry.register_extension_dict(yaml.safe_load(content), uri="test_uri") + def test_sclar_add(): - e = scalar_function('test_uri', 'test_func', - expressions=[ - literal(10, type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED))), - literal(20, type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED))) - ])(named_struct, registry) - - expected = stee.ExtendedExpression( - extension_uris=[ - ste.SimpleExtensionURI( - extension_uri_anchor=1, - uri='test_uri' - ) + e = scalar_function( + "test_uri", + "test_func", + expressions=[ + literal( + 10, + type=stt.Type( + i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED) + ), + ), + literal( + 20, + type=stt.Type( + i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED) + ), + ), ], + )(named_struct, registry) + + expected = stee.ExtendedExpression( + extension_uris=[ste.SimpleExtensionURI(extension_uri_anchor=1, uri="test_uri")], extensions=[ ste.SimpleExtensionDeclaration( extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction( - extension_uri_reference=1, - function_anchor=1, - name='test_func' + extension_uri_reference=1, function_anchor=1, name="test_func" ) ) ], @@ -71,10 +78,24 @@ def test_sclar_add(): scalar_function=stalg.Expression.ScalarFunction( function_reference=1, arguments=[ - stalg.FunctionArgument(value=stalg.Expression(literal=stalg.Expression.Literal(i8=10, nullable=False))), - stalg.FunctionArgument(value=stalg.Expression(literal=stalg.Expression.Literal(i8=20, nullable=False))) + stalg.FunctionArgument( + value=stalg.Expression( + literal=stalg.Expression.Literal( + i8=10, nullable=False + ) + ) + ), + stalg.FunctionArgument( + value=stalg.Expression( + literal=stalg.Expression.Literal( + i8=20, nullable=False + ) + ) + ), ], - output_type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED)) + output_type=stt.Type( + i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED) + ), ) ), output_names=["test_func(Literal(10),Literal(20))"], @@ -87,40 +108,45 @@ def test_sclar_add(): def test_nested_scalar_calls(): - e = scalar_function('test_uri', 'is_positive', - expressions=[ - scalar_function('test_uri', 'test_func', - expressions=[ - literal(10, type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED))), - literal(20, type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED))) - ] - ) - ], - alias='positive' - )(named_struct, registry) - - expected = stee.ExtendedExpression( - extension_uris=[ - ste.SimpleExtensionURI( - extension_uri_anchor=1, - uri='test_uri' + e = scalar_function( + "test_uri", + "is_positive", + expressions=[ + scalar_function( + "test_uri", + "test_func", + expressions=[ + literal( + 10, + type=stt.Type( + i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED) + ), + ), + literal( + 20, + type=stt.Type( + i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED) + ), + ), + ], ) ], + alias="positive", + )(named_struct, registry) + + expected = stee.ExtendedExpression( + extension_uris=[ste.SimpleExtensionURI(extension_uri_anchor=1, uri="test_uri")], extensions=[ ste.SimpleExtensionDeclaration( extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction( - extension_uri_reference=1, - function_anchor=2, - name='is_positive' - ) + extension_uri_reference=1, function_anchor=2, name="is_positive" + ) ), ste.SimpleExtensionDeclaration( extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction( - extension_uri_reference=1, - function_anchor=1, - name='test_func' - ) - ) + extension_uri_reference=1, function_anchor=1, name="test_func" + ) + ), ], referred_expr=[ stee.ExpressionReference( @@ -133,15 +159,35 @@ def test_nested_scalar_calls(): scalar_function=stalg.Expression.ScalarFunction( function_reference=1, arguments=[ - stalg.FunctionArgument(value=stalg.Expression(literal=stalg.Expression.Literal(i8=10, nullable=False))), - stalg.FunctionArgument(value=stalg.Expression(literal=stalg.Expression.Literal(i8=20, nullable=False))) + stalg.FunctionArgument( + value=stalg.Expression( + literal=stalg.Expression.Literal( + i8=10, nullable=False + ) + ) + ), + stalg.FunctionArgument( + value=stalg.Expression( + literal=stalg.Expression.Literal( + i8=20, nullable=False + ) + ) + ), ], - output_type=stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED)) + output_type=stt.Type( + i8=stt.Type.I8( + nullability=stt.Type.NULLABILITY_REQUIRED + ) + ), ) ) ) ], - output_type=stt.Type(bool=stt.Type.Boolean(nullability=stt.Type.NULLABILITY_REQUIRED)) + output_type=stt.Type( + bool=stt.Type.Boolean( + nullability=stt.Type.NULLABILITY_REQUIRED + ) + ), ) ), output_names=["positive"], diff --git a/tests/builders/extended_expression/test_singular_or_list.py b/tests/builders/extended_expression/test_singular_or_list.py index 927a109..2408581 100644 --- a/tests/builders/extended_expression/test_singular_or_list.py +++ b/tests/builders/extended_expression/test_singular_or_list.py @@ -1,5 +1,3 @@ -import yaml - import substrait.gen.proto.algebra_pb2 as stalg import substrait.gen.proto.type_pb2 as stt import substrait.gen.proto.extended_expression_pb2 as stee @@ -21,25 +19,28 @@ registry = ExtensionRegistry(load_default_extensions=False) + def test_singular_or_list(): e = singular_or_list( - value=literal(3, i8()), - options=[ - literal(1, i8()), - literal(2, i8()) - ] + value=literal(3, i8()), options=[literal(1, i8()), literal(2, i8())] )(named_struct, registry) - + expected = stee.ExtendedExpression( referred_expr=[ stee.ExpressionReference( expression=stalg.Expression( singular_or_list=stalg.Expression.SingularOrList( - value=stalg.Expression(literal=stalg.Expression.Literal(i8=3, nullable=True)), + value=stalg.Expression( + literal=stalg.Expression.Literal(i8=3, nullable=True) + ), options=[ - stalg.Expression(literal=stalg.Expression.Literal(i8=1, nullable=True)), - stalg.Expression(literal=stalg.Expression.Literal(i8=2, nullable=True)) - ] + stalg.Expression( + literal=stalg.Expression.Literal(i8=1, nullable=True) + ), + stalg.Expression( + literal=stalg.Expression.Literal(i8=2, nullable=True) + ), + ], ) ), output_names=["singular_or_list"], @@ -49,4 +50,3 @@ def test_singular_or_list(): ) assert e == expected - diff --git a/tests/builders/extended_expression/test_switch.py b/tests/builders/extended_expression/test_switch.py index 35095f8..ee675b4 100644 --- a/tests/builders/extended_expression/test_switch.py +++ b/tests/builders/extended_expression/test_switch.py @@ -1,5 +1,3 @@ -import yaml - import substrait.gen.proto.algebra_pb2 as stalg import substrait.gen.proto.type_pb2 as stt import substrait.gen.proto.extended_expression_pb2 as stee @@ -21,35 +19,52 @@ registry = ExtensionRegistry(load_default_extensions=False) + def test_switch(): e = switch( match=literal(3, i8()), ifs=[ (literal(1, i8()), literal(1, i8())), - (literal(2, i8()), literal(4, i8())) + (literal(2, i8()), literal(4, i8())), ], - _else=literal(9, i8()) + _else=literal(9, i8()), )(named_struct, registry) - + expected = stee.ExtendedExpression( referred_expr=[ stee.ExpressionReference( expression=stalg.Expression( switch_expression=stalg.Expression.SwitchExpression( - match=stalg.Expression(literal=stalg.Expression.Literal(i8=3, nullable=True)), + match=stalg.Expression( + literal=stalg.Expression.Literal(i8=3, nullable=True) + ), ifs=[ - stalg.Expression.SwitchExpression.IfValue(**{ - 'if': stalg.Expression.Literal(i8=1, nullable=True), - 'then': stalg.Expression(literal=stalg.Expression.Literal(i8=1, nullable=True)) - }), - stalg.Expression.SwitchExpression.IfValue(**{ - 'if': stalg.Expression.Literal(i8=2, nullable=True), - 'then': stalg.Expression(literal=stalg.Expression.Literal(i8=4, nullable=True)) - }), + stalg.Expression.SwitchExpression.IfValue( + **{ + "if": stalg.Expression.Literal(i8=1, nullable=True), + "then": stalg.Expression( + literal=stalg.Expression.Literal( + i8=1, nullable=True + ) + ), + } + ), + stalg.Expression.SwitchExpression.IfValue( + **{ + "if": stalg.Expression.Literal(i8=2, nullable=True), + "then": stalg.Expression( + literal=stalg.Expression.Literal( + i8=4, nullable=True + ) + ), + } + ), ], **{ - 'else': stalg.Expression(literal=stalg.Expression.Literal(i8=9, nullable=True)) - } + "else": stalg.Expression( + literal=stalg.Expression.Literal(i8=9, nullable=True) + ) + }, ) ), output_names=["switch"], @@ -59,4 +74,3 @@ def test_switch(): ) assert e == expected - diff --git a/tests/builders/extended_expression/test_window_function.py b/tests/builders/extended_expression/test_window_function.py index f9dad38..38b9a8c 100644 --- a/tests/builders/extended_expression/test_window_function.py +++ b/tests/builders/extended_expression/test_window_function.py @@ -4,7 +4,7 @@ import substrait.gen.proto.type_pb2 as stt import substrait.gen.proto.extended_expression_pb2 as stee import substrait.gen.proto.extensions.extensions_pb2 as ste -from substrait.builders.extended_expression import window_function, literal +from substrait.builders.extended_expression import window_function from substrait.extension_registry import ExtensionRegistry struct = stt.Type.Struct( @@ -44,22 +44,18 @@ registry = ExtensionRegistry(load_default_extensions=False) registry.register_extension_dict(yaml.safe_load(content), uri="test_uri") + def test_row_number(): - e = window_function('test_uri', 'row_number', expressions=[], alias='rn')(named_struct, registry) - + e = window_function("test_uri", "row_number", expressions=[], alias="rn")( + named_struct, registry + ) + expected = stee.ExtendedExpression( - extension_uris=[ - ste.SimpleExtensionURI( - extension_uri_anchor=1, - uri='test_uri' - ) - ], + extension_uris=[ste.SimpleExtensionURI(extension_uri_anchor=1, uri="test_uri")], extensions=[ ste.SimpleExtensionDeclaration( extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction( - extension_uri_reference=1, - function_anchor=1, - name='row_number' + extension_uri_reference=1, function_anchor=1, name="row_number" ) ) ], @@ -68,7 +64,9 @@ def test_row_number(): expression=stalg.Expression( window_function=stalg.Expression.WindowFunction( function_reference=1, - output_type=stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_NULLABLE)) + output_type=stt.Type( + i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_NULLABLE) + ), ) ), output_names=["rn"], diff --git a/tests/builders/plan/test_aggregate.py b/tests/builders/plan/test_aggregate.py index bb16af8..03af3ea 100644 --- a/tests/builders/plan/test_aggregate.py +++ b/tests/builders/plan/test_aggregate.py @@ -30,35 +30,29 @@ struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) -named_struct = stt.NamedStruct( - names=["id", "is_applicable"], struct=struct -) +named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) + def test_aggregate(): - table = read_named_table('table', named_struct) + table = read_named_table("table", named_struct) + + group_expr = column("id") + measure_expr = aggregate_function( + "test_uri", "count", expressions=[column("is_applicable")], alias=["count"] + ) - group_expr = column('id') - measure_expr = aggregate_function('test_uri', 'count', expressions=[column('is_applicable')], alias=['count']) + actual = aggregate( + table, grouping_expressions=[group_expr], measures=[measure_expr] + )(registry) - actual = aggregate(table, - grouping_expressions=[group_expr], - measures=[measure_expr])(registry) - ns = infer_plan_schema(table(None)) expected = stp.Plan( - extension_uris=[ - ste.SimpleExtensionURI( - extension_uri_anchor=1, - uri='test_uri' - ) - ], + extension_uris=[ste.SimpleExtensionURI(extension_uri_anchor=1, uri="test_uri")], extensions=[ ste.SimpleExtensionDeclaration( extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction( - extension_uri_reference=1, - function_anchor=1, - name='count' + extension_uri_reference=1, function_anchor=1, name="count" ) ) ], @@ -74,23 +68,26 @@ def test_aggregate(): groupings=[ stalg.AggregateRel.Grouping( grouping_expressions=[ - group_expr(ns, registry).referred_expr[0].expression + group_expr(ns, registry) + .referred_expr[0] + .expression ], - expression_references=[0] + expression_references=[0], ) ], measures=[ stalg.AggregateRel.Measure( - measure=measure_expr(ns, registry).referred_expr[0].measure + measure=measure_expr(ns, registry) + .referred_expr[0] + .measure ) - ] - + ], ) ), - names=['id', 'count'] + names=["id", "count"], ) ) - ] + ], ) - assert actual == expected \ No newline at end of file + assert actual == expected diff --git a/tests/builders/plan/test_cross.py b/tests/builders/plan/test_cross.py index 42cfb05..9a47ba4 100644 --- a/tests/builders/plan/test_cross.py +++ b/tests/builders/plan/test_cross.py @@ -3,24 +3,23 @@ import substrait.gen.proto.algebra_pb2 as stalg from substrait.builders.type import boolean, i64, string from substrait.builders.plan import read_named_table, cross -from substrait.builders.extended_expression import literal from substrait.extension_registry import ExtensionRegistry registry = ExtensionRegistry(load_default_extensions=False) struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) -named_struct = stt.NamedStruct( - names=["id", "is_applicable"], struct=struct -) +named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) named_struct_2 = stt.NamedStruct( - names=["fk_id", "name"], struct=stt.Type.Struct(types=[i64(nullable=False), string()]) + names=["fk_id", "name"], + struct=stt.Type.Struct(types=[i64(nullable=False), string()]), ) + def test_cross_join(): - table = read_named_table('table', named_struct) - table2 = read_named_table('table2', named_struct_2) + table = read_named_table("table", named_struct) + table2 = read_named_table("table2", named_struct_2) actual = cross(table, table2)(registry) @@ -34,11 +33,10 @@ def test_cross_join(): right=table2(None).relations[-1].root.input, ) ), - names=['id', 'is_applicable', 'fk_id', 'name'] + names=["id", "is_applicable", "fk_id", "name"], ) ) ] ) assert actual == expected - diff --git a/tests/builders/plan/test_fetch.py b/tests/builders/plan/test_fetch.py index f53a9b9..ebcd372 100644 --- a/tests/builders/plan/test_fetch.py +++ b/tests/builders/plan/test_fetch.py @@ -4,19 +4,17 @@ from substrait.builders.type import boolean, i64 from substrait.builders.plan import read_named_table, fetch from substrait.builders.extended_expression import literal -from substrait.type_inference import infer_plan_schema from substrait.extension_registry import ExtensionRegistry registry = ExtensionRegistry(load_default_extensions=False) struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) -named_struct = stt.NamedStruct( - names=["id", "is_applicable"], struct=struct -) +named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) + def test_fetch(): - table = read_named_table('table', named_struct) + table = read_named_table("table", named_struct) offset = literal(10, i64()) count = literal(5, i64()) @@ -31,14 +29,13 @@ def test_fetch(): fetch=stalg.FetchRel( input=table(None).relations[-1].root.input, offset_expr=offset(None, None).referred_expr[0].expression, - count_expr=count(None, None).referred_expr[0].expression + count_expr=count(None, None).referred_expr[0].expression, ) ), - names=['id', 'is_applicable'] + names=["id", "is_applicable"], ) ) ] ) assert actual == expected - diff --git a/tests/builders/plan/test_filter.py b/tests/builders/plan/test_filter.py index e40ed22..659f402 100644 --- a/tests/builders/plan/test_filter.py +++ b/tests/builders/plan/test_filter.py @@ -3,19 +3,18 @@ import substrait.gen.proto.algebra_pb2 as stalg from substrait.builders.type import boolean, i64 from substrait.builders.plan import read_named_table, filter -from substrait.builders.extended_expression import column, literal +from substrait.builders.extended_expression import literal from substrait.extension_registry import ExtensionRegistry registry = ExtensionRegistry(load_default_extensions=False) struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) -named_struct = stt.NamedStruct( - names=["id", "is_applicable"], struct=struct -) +named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) + def test_filter(): - table = read_named_table('table', named_struct) + table = read_named_table("table", named_struct) actual = filter(table, literal(True, boolean()))(registry) @@ -28,16 +27,15 @@ def test_filter(): input=table(None).relations[-1].root.input, condition=stalg.Expression( literal=stalg.Expression.Literal( - boolean=True, - nullable=True + boolean=True, nullable=True ) - ) + ), ) ), - names=['id', 'is_applicable'] + names=["id", "is_applicable"], ) ) ] ) - assert actual == expected \ No newline at end of file + assert actual == expected diff --git a/tests/builders/plan/test_join.py b/tests/builders/plan/test_join.py index 42ad30a..8d4998a 100644 --- a/tests/builders/plan/test_join.py +++ b/tests/builders/plan/test_join.py @@ -10,19 +10,21 @@ struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) -named_struct = stt.NamedStruct( - names=["id", "is_applicable"], struct=struct -) +named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) named_struct_2 = stt.NamedStruct( - names=["fk_id", "name"], struct=stt.Type.Struct(types=[i64(nullable=False), string()]) + names=["fk_id", "name"], + struct=stt.Type.Struct(types=[i64(nullable=False), string()]), ) + def test_join(): - table = read_named_table('table', named_struct) - table2 = read_named_table('table2', named_struct_2) + table = read_named_table("table", named_struct) + table2 = read_named_table("table2", named_struct_2) - actual = join(table, table2, literal(True, boolean()), stalg.JoinRel.JOIN_TYPE_INNER)(registry) + actual = join( + table, table2, literal(True, boolean()), stalg.JoinRel.JOIN_TYPE_INNER + )(registry) expected = stp.Plan( relations=[ @@ -32,15 +34,16 @@ def test_join(): join=stalg.JoinRel( left=table(None).relations[-1].root.input, right=table2(None).relations[-1].root.input, - expression=literal(True, boolean())(None, None).referred_expr[0].expression, - type=stalg.JoinRel.JOIN_TYPE_INNER + expression=literal(True, boolean())(None, None) + .referred_expr[0] + .expression, + type=stalg.JoinRel.JOIN_TYPE_INNER, ) ), - names=['id', 'is_applicable', 'fk_id', 'name'] + names=["id", "is_applicable", "fk_id", "name"], ) ) ] ) assert actual == expected - diff --git a/tests/builders/plan/test_project.py b/tests/builders/plan/test_project.py index 9535a32..2dd9ff1 100644 --- a/tests/builders/plan/test_project.py +++ b/tests/builders/plan/test_project.py @@ -10,14 +10,13 @@ struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) -named_struct = stt.NamedStruct( - names=["id", "is_applicable"], struct=struct -) +named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) + def test_project(): - table = read_named_table('table', named_struct) + table = read_named_table("table", named_struct) - actual = project(table, [column('id')])(registry) + actual = project(table, [column("id")])(registry) expected = stp.Plan( relations=[ @@ -25,24 +24,28 @@ def test_project(): root=stalg.RelRoot( input=stalg.Rel( project=stalg.ProjectRel( - common=stalg.RelCommon(emit=stalg.RelCommon.Emit(output_mapping=[2])), + common=stalg.RelCommon( + emit=stalg.RelCommon.Emit(output_mapping=[2]) + ), input=table(None).relations[-1].root.input, expressions=[ stalg.Expression( selection=stalg.Expression.FieldReference( direct_reference=stalg.Expression.ReferenceSegment( - struct_field=stalg.Expression.ReferenceSegment.StructField(field=0) + struct_field=stalg.Expression.ReferenceSegment.StructField( + field=0 + ) ), - root_reference=stalg.Expression.FieldReference.RootReference() + root_reference=stalg.Expression.FieldReference.RootReference(), ) ) - ] + ], ) ), - names=['id'] + names=["id"], ) ) ] ) - assert actual == expected \ No newline at end of file + assert actual == expected diff --git a/tests/builders/plan/test_read.py b/tests/builders/plan/test_read.py index 7e8fdf7..6e380b8 100644 --- a/tests/builders/plan/test_read.py +++ b/tests/builders/plan/test_read.py @@ -6,12 +6,11 @@ struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) -named_struct = stt.NamedStruct( - names=["id", "is_applicable"], struct=struct -) +named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) + def test_read_rel(): - actual = read_named_table('example_table', named_struct)(None) + actual = read_named_table("example_table", named_struct)(None) expected = stp.Plan( relations=[ @@ -20,11 +19,13 @@ def test_read_rel(): input=stalg.Rel( read=stalg.ReadRel( common=stalg.RelCommon(direct=stalg.RelCommon.Direct()), - base_schema=named_struct, - named_table=stalg.ReadRel.NamedTable(names=['example_table']) + base_schema=named_struct, + named_table=stalg.ReadRel.NamedTable( + names=["example_table"] + ), ) ), - names=['id', 'is_applicable'] + names=["id", "is_applicable"], ) ) ] @@ -32,8 +33,9 @@ def test_read_rel(): assert actual == expected + def test_read_rel_db(): - actual = read_named_table(['example_db', 'example_table'], named_struct)(None) + actual = read_named_table(["example_db", "example_table"], named_struct)(None) expected = stp.Plan( relations=[ @@ -42,14 +44,16 @@ def test_read_rel_db(): input=stalg.Rel( read=stalg.ReadRel( common=stalg.RelCommon(direct=stalg.RelCommon.Direct()), - base_schema=named_struct, - named_table=stalg.ReadRel.NamedTable(names=['example_db', 'example_table']) + base_schema=named_struct, + named_table=stalg.ReadRel.NamedTable( + names=["example_db", "example_table"] + ), ) ), - names=['id', 'is_applicable'] + names=["id", "is_applicable"], ) ) ] ) - assert actual == expected \ No newline at end of file + assert actual == expected diff --git a/tests/builders/plan/test_set.py b/tests/builders/plan/test_set.py index c701024..e761707 100644 --- a/tests/builders/plan/test_set.py +++ b/tests/builders/plan/test_set.py @@ -3,25 +3,21 @@ import substrait.gen.proto.algebra_pb2 as stalg from substrait.builders.type import boolean, i64 from substrait.builders.plan import read_named_table, set -from substrait.builders.extended_expression import column -from substrait.type_inference import infer_plan_schema from substrait.extension_registry import ExtensionRegistry registry = ExtensionRegistry(load_default_extensions=False) struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) -named_struct = stt.NamedStruct( - names=["id", "is_applicable"], struct=struct -) +named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) + def test_set(): - table = read_named_table('table', named_struct) - table2 = read_named_table('table2', named_struct) + table = read_named_table("table", named_struct) + table2 = read_named_table("table2", named_struct) actual = set([table, table2], stalg.SetRel.SET_OP_UNION_ALL)(None) - expected = stp.Plan( relations=[ stp.PlanRel( @@ -32,10 +28,10 @@ def test_set(): table(None).relations[-1].root.input, table2(None).relations[-1].root.input, ], - op=stalg.SetRel.SET_OP_UNION_ALL + op=stalg.SetRel.SET_OP_UNION_ALL, ) ), - names=['id', 'is_applicable'] + names=["id", "is_applicable"], ) ) ] diff --git a/tests/builders/plan/test_sort.py b/tests/builders/plan/test_sort.py index 66cfbd5..4b4f49c 100644 --- a/tests/builders/plan/test_sort.py +++ b/tests/builders/plan/test_sort.py @@ -11,14 +11,13 @@ struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) -named_struct = stt.NamedStruct( - names=["id", "is_applicable"], struct=struct -) +named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct) + def test_sort_no_direction(): - table = read_named_table('table', named_struct) + table = read_named_table("table", named_struct) - col = column('id') + col = column("id") actual = sort(table, expressions=[col])(registry) @@ -32,12 +31,14 @@ def test_sort_no_direction(): sorts=[ stalg.SortField( direction=stalg.SortField.SORT_DIRECTION_ASC_NULLS_LAST, - expr=col(infer_plan_schema(table(None)), registry).referred_expr[0].expression + expr=col(infer_plan_schema(table(None)), registry) + .referred_expr[0] + .expression, ) - ] + ], ) ), - names=['id', 'is_applicable'] + names=["id", "is_applicable"], ) ) ] @@ -45,12 +46,15 @@ def test_sort_no_direction(): assert actual == expected + def test_sort_direction(): - table = read_named_table('table', named_struct) + table = read_named_table("table", named_struct) - col = column('id') + col = column("id") - actual = sort(table, expressions=[(col, stalg.SortField.SORT_DIRECTION_DESC_NULLS_FIRST)])(registry) + actual = sort( + table, expressions=[(col, stalg.SortField.SORT_DIRECTION_DESC_NULLS_FIRST)] + )(registry) expected = stp.Plan( relations=[ @@ -62,15 +66,17 @@ def test_sort_direction(): sorts=[ stalg.SortField( direction=stalg.SortField.SORT_DIRECTION_DESC_NULLS_FIRST, - expr=col(infer_plan_schema(table(None)), registry).referred_expr[0].expression + expr=col(infer_plan_schema(table(None)), registry) + .referred_expr[0] + .expression, ) - ] + ], ) ), - names=['id', 'is_applicable'] + names=["id", "is_applicable"], ) ) ] ) - assert actual == expected \ No newline at end of file + assert actual == expected diff --git a/tests/test_function_registry.py b/tests/test_function_registry.py index 14a227e..b4dd046 100644 --- a/tests/test_function_registry.py +++ b/tests/test_function_registry.py @@ -301,7 +301,7 @@ def test_function_with_discrete_nullability(): )[1] == i8(nullable=True) -def test_function_with_discrete_nullability(): +def test_function_with_discrete_nullability_nonexisting(): assert ( registry.lookup_function( uri="test", function_name="add_discrete", signature=[i8(), i8()] diff --git a/tests/test_json.py b/tests/test_json.py index b8651d2..741bd87 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -1,7 +1,6 @@ import os import pathlib import tempfile -import json from substrait.proto import Plan from substrait.json import load_json, parse_json, dump_json, write_json @@ -69,4 +68,4 @@ def _strip_json_comments(jsonfile): # a comment containing the SQL that matches the json plan. # As Python JSON parser doesn't support comments, # we have to strip them to make the content readable - return "\n".join(l for l in jsonfile.readlines() if l[0] != "#") + return "\n".join(line for line in jsonfile.readlines() if line[0] != "#")