diff --git a/examples/builder_example.py b/examples/builder_example.py new file mode 100644 index 0000000..b0c9ed6 --- /dev/null +++ b/examples/builder_example.py @@ -0,0 +1,138 @@ +from substrait.builders.plan import read_named_table, project, filter +from substrait.builders.extended_expression import column, scalar_function, literal +from substrait.builders.type import i64, boolean, struct, named_struct +from substrait.extension_registry import ExtensionRegistry + +registry = ExtensionRegistry(load_default_extensions=True) + +ns = named_struct( + names=["id", "is_applicable"], + struct=struct( + types=[ + i64(nullable=False), + boolean() + ] + ) +) + +table = read_named_table('example_table', ns) +table = filter(table, expression=column('is_applicable')) +table = filter(table, expression=scalar_function('functions_comparison.yaml', 'lt', column('id'), literal(100, i64()))) +table = project(table, expressions=[column('id')]) + +print(table(registry)) + +""" +extension_uris { + extension_uri_anchor: 13 + uri: "functions_comparison.yaml" +} +extensions { + extension_function { + extension_uri_reference: 13 + function_anchor: 495 + name: "lt" + } +} +relations { + root { + input { + project { + common { + emit { + output_mapping: 2 + } + } + input { + filter { + input { + filter { + input { + read { + common { + direct { + } + } + base_schema { + names: "id" + names: "is_applicable" + struct { + types { + i64 { + nullability: NULLABILITY_REQUIRED + } + } + types { + bool { + nullability: NULLABILITY_NULLABLE + } + } + nullability: NULLABILITY_NULLABLE + } + } + named_table { + names: "example_table" + } + } + } + condition { + selection { + direct_reference { + struct_field { + field: 1 + } + } + root_reference { + } + } + } + } + } + condition { + scalar_function { + function_reference: 495 + output_type { + bool { + nullability: NULLABILITY_NULLABLE + } + } + arguments { + value { + selection { + direct_reference { + struct_field { + } + } + root_reference { + } + } + } + } + arguments { + value { + literal { + i64: 100 + nullable: true + } + } + } + } + } + } + } + expressions { + selection { + direct_reference { + struct_field { + } + } + root_reference { + } + } + } + } + } + names: "id" + } +} +""" diff --git a/src/substrait/builders/plan.py b/src/substrait/builders/plan.py new file mode 100644 index 0000000..3075c4b --- /dev/null +++ b/src/substrait/builders/plan.py @@ -0,0 +1,303 @@ +""" +Plan builders take either Plan or UnboundPlan objects as input rather than plain Rels. +This is to make sure that additional information like extension types of functions are not lost. +All builders return UnboundPlan objects that can be materialized to a Plan using an ExtensionRegistry. +See `examples/builder_example.py` for usage. +""" + +from typing import Iterable, Union, Callable + +import substrait.gen.proto.algebra_pb2 as stalg +import substrait.gen.proto.plan_pb2 as stp +import substrait.gen.proto.type_pb2 as stt +import substrait.gen.proto.extended_expression_pb2 as stee +from substrait.extension_registry import ExtensionRegistry +from substrait.builders.extended_expression import UnboundExtendedExpression +from substrait.type_inference import infer_plan_schema +from substrait.utils import merge_extension_declarations, merge_extension_uris + +UnboundPlan = Callable[[ExtensionRegistry], stp.Plan] + +PlanOrUnbound = Union[stp.Plan, UnboundPlan] + +def _merge_extensions(*objs): + return { + "extension_uris": merge_extension_uris(*[b.extension_uris for b in objs]), + "extensions": merge_extension_declarations(*[b.extensions for b in objs]), + } + + +def read_named_table(names: Union[str, Iterable[str]], named_struct: stt.NamedStruct) -> UnboundPlan: + def resolve(registry: ExtensionRegistry) -> stp.Plan: + _names = [names] if isinstance(names, str) else names + + rel = stalg.Rel( + read=stalg.ReadRel( + common=stalg.RelCommon(direct=stalg.RelCommon.Direct()), + base_schema=named_struct, + named_table=stalg.ReadRel.NamedTable(names=_names), + ) + ) + + return stp.Plan( + relations=[stp.PlanRel(root=stalg.RelRoot( + input=rel, names=named_struct.names))] + ) + + return resolve + + +def project( + plan: PlanOrUnbound, expressions: Iterable[UnboundExtendedExpression] +) -> UnboundPlan: + def resolve(registry: ExtensionRegistry) -> stp.Plan: + _plan = plan if isinstance(plan, stp.Plan) else plan(registry) + ns = infer_plan_schema(_plan) + bound_expressions: Iterable[stee.ExtendedExpression] = [ + e(ns, registry) for e in expressions] + + start_index = len(_plan.relations[-1].root.names) + + names = [e.output_names[0] for ee in bound_expressions for e in ee.referred_expr] + + rel = stalg.Rel( + project=stalg.ProjectRel( + common=stalg.RelCommon( + emit=stalg.RelCommon.Emit( + output_mapping=[i + start_index for i in range(len(names))] + ) + ), + input=_plan.relations[-1].root.input, + expressions=[ + e.expression for ee in bound_expressions for e in ee.referred_expr], + ) + ) + + return stp.Plan( + relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=names))], + **_merge_extensions(_plan, *bound_expressions), + ) + + return resolve + + +def filter( + plan: PlanOrUnbound, expression: UnboundExtendedExpression +) -> UnboundPlan: + def resolve(registry: ExtensionRegistry) -> stp.Plan: + bound_plan = plan if isinstance(plan, stp.Plan) else plan(registry) + ns = infer_plan_schema(bound_plan) + bound_expression: stee.ExtendedExpression = expression(ns, registry) + + rel = stalg.Rel( + filter=stalg.FilterRel( + input=bound_plan.relations[-1].root.input, + condition=bound_expression.referred_expr[0].expression, + ) + ) + + names = ns.names + + return stp.Plan( + relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=names))], + **_merge_extensions(bound_plan, bound_expression), + ) + + return resolve + + +def sort( + plan: PlanOrUnbound, + expressions: Iterable[Union[UnboundExtendedExpression, tuple[UnboundExtendedExpression, stalg.SortField.SortDirection.ValueType]]] +) -> UnboundPlan: + def resolve(registry: ExtensionRegistry) -> stp.Plan: + bound_plan = plan if isinstance(plan, stp.Plan) else plan(registry) + ns = infer_plan_schema(bound_plan) + + bound_expressions = [(e, stalg.SortField.SORT_DIRECTION_ASC_NULLS_LAST) if not isinstance(e, tuple) else e for e in expressions] + bound_expressions = [(e[0](ns, registry), e[1]) for e in bound_expressions] + + rel = stalg.Rel( + sort=stalg.SortRel( + input=bound_plan.relations[-1].root.input, + sorts=[ + stalg.SortField( + expr=e[0].referred_expr[0].expression, + direction=e[1], + ) + for e in bound_expressions + ], + ) + ) + + return stp.Plan( + relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=ns.names))], + **_merge_extensions(bound_plan, *[e[0] for e in bound_expressions]), + ) + + return resolve + + +def set(inputs: Iterable[PlanOrUnbound], op: stalg.SetRel.SetOp) -> UnboundPlan: + def resolve(registry: ExtensionRegistry) -> stp.Plan: + bound_inputs = [i if isinstance(i, stp.Plan) else i(registry) for i in inputs] + rel = stalg.Rel( + set=stalg.SetRel( + inputs=[plan.relations[-1].root.input for plan in bound_inputs], op=op + ) + ) + + return stp.Plan( + relations=[ + stp.PlanRel( + root=stalg.RelRoot(input=rel, names=bound_inputs[0].relations[-1].root.names) + ) + ], + **_merge_extensions(*bound_inputs), + ) + + return resolve + +def fetch(plan: PlanOrUnbound, + offset: UnboundExtendedExpression, + count: UnboundExtendedExpression) -> UnboundPlan: + def resolve(registry: ExtensionRegistry) -> stp.Plan: + bound_plan = plan if isinstance(plan, stp.Plan) else plan(registry) + ns = infer_plan_schema(bound_plan) + + bound_offset = offset(ns, registry) + bound_count = count(ns, registry) + + rel = stalg.Rel( + fetch=stalg.FetchRel( + input=bound_plan.relations[-1].root.input, + offset_expr=bound_offset.referred_expr[0].expression, + count_expr=bound_count.referred_expr[0].expression + ) + ) + + return stp.Plan( + relations=[ + stp.PlanRel( + root=stalg.RelRoot(input=rel, names=bound_plan.relations[-1].root.names) + ) + ], + **_merge_extensions(bound_plan, bound_offset, bound_count), + ) + + return resolve + + +def join( + left: PlanOrUnbound, + right: PlanOrUnbound, + expression: UnboundExtendedExpression, + type: stalg.JoinRel.JoinType, +) -> UnboundPlan: + def resolve(registry: ExtensionRegistry) -> stp.Plan: + bound_left = left if isinstance(left, stp.Plan) else left(registry) + bound_right = right if isinstance(right, stp.Plan) else right(registry) + left_ns = infer_plan_schema(bound_left) + right_ns = infer_plan_schema(bound_right) + + ns = stt.NamedStruct( + struct=stt.Type.Struct( + types=list(left_ns.struct.types) + list(right_ns.struct.types), + nullability=stt.Type.Nullability.NULLABILITY_REQUIRED, + ), + names=list(left_ns.names) + list(right_ns.names), + ) + bound_expression: stee.ExtendedExpression = expression(ns, registry) + + rel = stalg.Rel( + join=stalg.JoinRel( + left=bound_left.relations[-1].root.input, + right=bound_right.relations[-1].root.input, + expression=bound_expression.referred_expr[0].expression, + type=type, + ) + ) + + return stp.Plan( + relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=ns.names))], + **_merge_extensions(bound_left, bound_right, bound_expression), + ) + + return resolve + +def cross( + left: PlanOrUnbound, + right: PlanOrUnbound, +) -> UnboundPlan: + def resolve(registry: ExtensionRegistry) -> stp.Plan: + bound_left = left if isinstance(left, stp.Plan) else left(registry) + bound_right = right if isinstance(right, stp.Plan) else right(registry) + left_ns = infer_plan_schema(bound_left) + right_ns = infer_plan_schema(bound_right) + + ns = stt.NamedStruct( + struct=stt.Type.Struct( + types=list(left_ns.struct.types) + list(right_ns.struct.types), + nullability=stt.Type.Nullability.NULLABILITY_REQUIRED, + ), + names=list(left_ns.names) + list(right_ns.names), + ) + + rel = stalg.Rel( + cross=stalg.CrossRel( + left=bound_left.relations[-1].root.input, + right=bound_right.relations[-1].root.input + ) + ) + + return stp.Plan( + relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=ns.names))], + **_merge_extensions(bound_left, bound_right), + ) + + return resolve + +# TODO grouping sets +def aggregate( + input: PlanOrUnbound, + grouping_expressions: Iterable[UnboundExtendedExpression], + measures: Iterable[UnboundExtendedExpression], +) -> UnboundPlan: + def resolve(registry: ExtensionRegistry) -> stp.Plan: + bound_input = input if isinstance(input, stp.Plan) else input(registry) + ns = infer_plan_schema(bound_input) + + bound_grouping_expressions = [e(ns, registry) for e in grouping_expressions] + bound_measures = [e(ns, registry) for e in measures] + + rel = stalg.Rel( + aggregate=stalg.AggregateRel( + input=bound_input.relations[-1].root.input, + grouping_expressions=[ + e.referred_expr[0].expression for e in bound_grouping_expressions + ], + groupings=[ + stalg.AggregateRel.Grouping( + expression_references=range(len(bound_grouping_expressions)), + grouping_expressions=[ + e.referred_expr[0].expression for e in bound_grouping_expressions + ], + ) + ], + measures=[ + stalg.AggregateRel.Measure(measure=m.referred_expr[0].measure) + for m in bound_measures + ], + ) + ) + + names = [e.referred_expr[0].output_names[0] for e in bound_grouping_expressions] + [ + e.referred_expr[0].output_names[0] for e in bound_measures + ] + + return stp.Plan( + relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=names))], + **_merge_extensions(bound_input, *bound_grouping_expressions, *bound_measures), + ) + + return resolve diff --git a/src/substrait/builders/type.py b/src/substrait/builders/type.py new file mode 100644 index 0000000..d7468f0 --- /dev/null +++ b/src/substrait/builders/type.py @@ -0,0 +1,76 @@ +from typing import Iterable +import substrait.gen.proto.type_pb2 as stt + +def boolean(nullable=True) -> stt.Type: + return stt.Type(bool=stt.Type.Boolean(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def i8(nullable=True) -> stt.Type: + return stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def i16(nullable=True) -> stt.Type: + return stt.Type(i16=stt.Type.I16(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def i32(nullable=True) -> stt.Type: + return stt.Type(i32=stt.Type.I32(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def i64(nullable=True) -> stt.Type: + return stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def fp32(nullable=True) -> stt.Type: + return stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def fp64(nullable=True) -> stt.Type: + return stt.Type(fp64=stt.Type.FP64(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def string(nullable=True) -> stt.Type: + return stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def binary(nullable=True) -> stt.Type: + return stt.Type(binary=stt.Type.Binary(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def date(nullable=True) -> stt.Type: + return stt.Type(date=stt.Type.Date(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def interval_year(nullable=True) -> stt.Type: + return stt.Type(interval_year=stt.Type.IntervalYear(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def interval_day(precision: int, nullable=True) -> stt.Type: + return stt.Type(interval_day=stt.Type.IntervalDay(precision=precision, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def interval_compound(precision: int, nullable=True) -> stt.Type: + return stt.Type(interval_compound=stt.Type.IntervalCompound(precision=precision, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def uuid(nullable=True) -> stt.Type: + return stt.Type(uuid=stt.Type.UUID(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def fixed_char(length: int, nullable=True) -> stt.Type: + return stt.Type(fixed_char=stt.Type.FixedChar(length=length, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def var_char(length: int, nullable=True) -> stt.Type: + return stt.Type(var_char=stt.Type.VarChar(length=length, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def fixed_binary(length: int, nullable=True) -> stt.Type: + return stt.Type(fixed_binary=stt.Type.FixedBinary(length=length, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def decimal(scale: int, precision: int, nullable=True) -> stt.Type: + return stt.Type(decimal=stt.Type.Decimal(scale=scale, precision=precision, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +# PrecisionTime + +def precision_timestamp(precision: int, nullable=True) -> stt.Type: + return stt.Type(precision_timestamp=stt.Type.PrecisionTimestamp(precision=precision, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def precision_timestamp_tz(precision: int, nullable=True) -> stt.Type: + return stt.Type(precision_timestamp_tz=stt.Type.PrecisionTimestampTZ(precision=precision, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def struct(types: Iterable[stt.Type], nullable=True) -> stt.Type: + return stt.Type(struct=stt.Type.Struct(types=types, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def list(type: stt.Type, nullable=True) -> stt.Type: + return stt.Type(list=stt.Type.List(type=type, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def map(key: stt.Type, value: stt.Type, nullable=True) -> stt.Type: + return stt.Type(map=stt.Type.Map(key=key, value=value, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def named_struct(names: Iterable[str], struct: stt.Type) -> stt.NamedStruct: + return stt.NamedStruct(names=names, struct=struct.struct) diff --git a/src/substrait/type_inference.py b/src/substrait/type_inference.py index 082da29..5ed16fa 100644 --- a/src/substrait/type_inference.py +++ b/src/substrait/type_inference.py @@ -1,6 +1,7 @@ import substrait.gen.proto.algebra_pb2 as stalg import substrait.gen.proto.extended_expression_pb2 as stee import substrait.gen.proto.type_pb2 as stt +import substrait.gen.proto.plan_pb2 as stp def infer_literal_type(literal: stalg.Expression.Literal) -> stt.Type: @@ -346,3 +347,8 @@ def infer_rel_schema(rel: stalg.Rel) -> stt.Type.Struct: types=[struct.types[i] for i in common.emit.output_mapping], nullability=struct.nullability, ) + +def infer_plan_schema(plan: stp.Plan) -> stt.NamedStruct: + schema = infer_rel_schema(plan.relations[-1].root.input) + + return stt.NamedStruct(names=plan.relations[-1].root.names, struct=schema) \ No newline at end of file diff --git a/tests/builders/extended_expression/test_column.py b/tests/builders/extended_expression/test_column.py index 306fe70..5287d13 100644 --- a/tests/builders/extended_expression/test_column.py +++ b/tests/builders/extended_expression/test_column.py @@ -3,7 +3,6 @@ import substrait.gen.proto.extended_expression_pb2 as stee from substrait.builders.extended_expression import column - struct = stt.Type.Struct( types=[ stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)), diff --git a/tests/builders/plan/test_aggregate.py b/tests/builders/plan/test_aggregate.py new file mode 100644 index 0000000..175519e --- /dev/null +++ b/tests/builders/plan/test_aggregate.py @@ -0,0 +1,96 @@ +import substrait.gen.proto.type_pb2 as stt +import substrait.gen.proto.plan_pb2 as stp +import substrait.gen.proto.algebra_pb2 as stalg +import substrait.gen.proto.extensions.extensions_pb2 as ste +from substrait.builders.type import boolean, i64 +from substrait.builders.plan import read_named_table, aggregate +from substrait.builders.extended_expression import column, aggregate_function +from substrait.extension_registry import ExtensionRegistry +from substrait.type_inference import infer_plan_schema +import yaml + +content = """%YAML 1.2 +--- +aggregate_functions: + - name: "count" + description: Count a set of values + impls: + - args: + - name: x + value: any + nullability: DECLARED_OUTPUT + decomposable: MANY + intermediate: i64 + return: i64 +""" + + +registry = ExtensionRegistry(load_default_extensions=False) +registry.register_extension_dict(yaml.safe_load(content), uri="test_uri") + +struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) + +named_struct = stt.NamedStruct( + names=["id", "is_applicable"], struct=struct +) + +def test_aggregate(): + table = read_named_table('table', named_struct) + + group_expr = column('id') + measure_expr = aggregate_function('test_uri', 'count', column('is_applicable'), alias=['count']) + + actual = aggregate(table, + grouping_expressions=[group_expr], + measures=[measure_expr])(registry) + + ns = infer_plan_schema(table(None)) + + expected = stp.Plan( + extension_uris=[ + ste.SimpleExtensionURI( + extension_uri_anchor=1, + uri='test_uri' + ) + ], + extensions=[ + ste.SimpleExtensionDeclaration( + extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction( + extension_uri_reference=1, + function_anchor=1, + name='count' + ) + ) + ], + relations=[ + stp.PlanRel( + root=stalg.RelRoot( + input=stalg.Rel( + aggregate=stalg.AggregateRel( + input=table(None).relations[-1].root.input, + grouping_expressions=[ + group_expr(ns, registry).referred_expr[0].expression + ], + groupings=[ + stalg.AggregateRel.Grouping( + grouping_expressions=[ + group_expr(ns, registry).referred_expr[0].expression + ], + expression_references=[0] + ) + ], + measures=[ + stalg.AggregateRel.Measure( + measure=measure_expr(ns, registry).referred_expr[0].measure + ) + ] + + ) + ), + names=['id', 'count'] + ) + ) + ] + ) + + assert actual == expected \ No newline at end of file diff --git a/tests/builders/plan/test_cross.py b/tests/builders/plan/test_cross.py new file mode 100644 index 0000000..42cfb05 --- /dev/null +++ b/tests/builders/plan/test_cross.py @@ -0,0 +1,44 @@ +import substrait.gen.proto.type_pb2 as stt +import substrait.gen.proto.plan_pb2 as stp +import substrait.gen.proto.algebra_pb2 as stalg +from substrait.builders.type import boolean, i64, string +from substrait.builders.plan import read_named_table, cross +from substrait.builders.extended_expression import literal +from substrait.extension_registry import ExtensionRegistry + +registry = ExtensionRegistry(load_default_extensions=False) + +struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) + +named_struct = stt.NamedStruct( + names=["id", "is_applicable"], struct=struct +) + +named_struct_2 = stt.NamedStruct( + names=["fk_id", "name"], struct=stt.Type.Struct(types=[i64(nullable=False), string()]) +) + +def test_cross_join(): + table = read_named_table('table', named_struct) + table2 = read_named_table('table2', named_struct_2) + + actual = cross(table, table2)(registry) + + expected = stp.Plan( + relations=[ + stp.PlanRel( + root=stalg.RelRoot( + input=stalg.Rel( + cross=stalg.CrossRel( + left=table(None).relations[-1].root.input, + right=table2(None).relations[-1].root.input, + ) + ), + names=['id', 'is_applicable', 'fk_id', 'name'] + ) + ) + ] + ) + + assert actual == expected + diff --git a/tests/builders/plan/test_fetch.py b/tests/builders/plan/test_fetch.py new file mode 100644 index 0000000..f53a9b9 --- /dev/null +++ b/tests/builders/plan/test_fetch.py @@ -0,0 +1,44 @@ +import substrait.gen.proto.type_pb2 as stt +import substrait.gen.proto.plan_pb2 as stp +import substrait.gen.proto.algebra_pb2 as stalg +from substrait.builders.type import boolean, i64 +from substrait.builders.plan import read_named_table, fetch +from substrait.builders.extended_expression import literal +from substrait.type_inference import infer_plan_schema +from substrait.extension_registry import ExtensionRegistry + +registry = ExtensionRegistry(load_default_extensions=False) + +struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) + +named_struct = stt.NamedStruct( + names=["id", "is_applicable"], struct=struct +) + +def test_fetch(): + table = read_named_table('table', named_struct) + + offset = literal(10, i64()) + count = literal(5, i64()) + + actual = fetch(table, offset=offset, count=count)(registry) + + expected = stp.Plan( + relations=[ + stp.PlanRel( + root=stalg.RelRoot( + input=stalg.Rel( + fetch=stalg.FetchRel( + input=table(None).relations[-1].root.input, + offset_expr=offset(None, None).referred_expr[0].expression, + count_expr=count(None, None).referred_expr[0].expression + ) + ), + names=['id', 'is_applicable'] + ) + ) + ] + ) + + assert actual == expected + diff --git a/tests/builders/plan/test_filter.py b/tests/builders/plan/test_filter.py new file mode 100644 index 0000000..e40ed22 --- /dev/null +++ b/tests/builders/plan/test_filter.py @@ -0,0 +1,43 @@ +import substrait.gen.proto.type_pb2 as stt +import substrait.gen.proto.plan_pb2 as stp +import substrait.gen.proto.algebra_pb2 as stalg +from substrait.builders.type import boolean, i64 +from substrait.builders.plan import read_named_table, filter +from substrait.builders.extended_expression import column, literal +from substrait.extension_registry import ExtensionRegistry + +registry = ExtensionRegistry(load_default_extensions=False) + +struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) + +named_struct = stt.NamedStruct( + names=["id", "is_applicable"], struct=struct +) + +def test_filter(): + table = read_named_table('table', named_struct) + + actual = filter(table, literal(True, boolean()))(registry) + + expected = stp.Plan( + relations=[ + stp.PlanRel( + root=stalg.RelRoot( + input=stalg.Rel( + filter=stalg.FilterRel( + input=table(None).relations[-1].root.input, + condition=stalg.Expression( + literal=stalg.Expression.Literal( + boolean=True, + nullable=True + ) + ) + ) + ), + names=['id', 'is_applicable'] + ) + ) + ] + ) + + assert actual == expected \ No newline at end of file diff --git a/tests/builders/plan/test_join.py b/tests/builders/plan/test_join.py new file mode 100644 index 0000000..42ad30a --- /dev/null +++ b/tests/builders/plan/test_join.py @@ -0,0 +1,46 @@ +import substrait.gen.proto.type_pb2 as stt +import substrait.gen.proto.plan_pb2 as stp +import substrait.gen.proto.algebra_pb2 as stalg +from substrait.builders.type import boolean, i64, string +from substrait.builders.plan import read_named_table, join +from substrait.builders.extended_expression import literal +from substrait.extension_registry import ExtensionRegistry + +registry = ExtensionRegistry(load_default_extensions=False) + +struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) + +named_struct = stt.NamedStruct( + names=["id", "is_applicable"], struct=struct +) + +named_struct_2 = stt.NamedStruct( + names=["fk_id", "name"], struct=stt.Type.Struct(types=[i64(nullable=False), string()]) +) + +def test_join(): + table = read_named_table('table', named_struct) + table2 = read_named_table('table2', named_struct_2) + + actual = join(table, table2, literal(True, boolean()), stalg.JoinRel.JOIN_TYPE_INNER)(registry) + + expected = stp.Plan( + relations=[ + stp.PlanRel( + root=stalg.RelRoot( + input=stalg.Rel( + join=stalg.JoinRel( + left=table(None).relations[-1].root.input, + right=table2(None).relations[-1].root.input, + expression=literal(True, boolean())(None, None).referred_expr[0].expression, + type=stalg.JoinRel.JOIN_TYPE_INNER + ) + ), + names=['id', 'is_applicable', 'fk_id', 'name'] + ) + ) + ] + ) + + assert actual == expected + diff --git a/tests/builders/plan/test_project.py b/tests/builders/plan/test_project.py new file mode 100644 index 0000000..9535a32 --- /dev/null +++ b/tests/builders/plan/test_project.py @@ -0,0 +1,48 @@ +import substrait.gen.proto.type_pb2 as stt +import substrait.gen.proto.plan_pb2 as stp +import substrait.gen.proto.algebra_pb2 as stalg +from substrait.builders.type import boolean, i64 +from substrait.builders.plan import read_named_table, project +from substrait.builders.extended_expression import column +from substrait.extension_registry import ExtensionRegistry + +registry = ExtensionRegistry(load_default_extensions=False) + +struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) + +named_struct = stt.NamedStruct( + names=["id", "is_applicable"], struct=struct +) + +def test_project(): + table = read_named_table('table', named_struct) + + actual = project(table, [column('id')])(registry) + + expected = stp.Plan( + relations=[ + stp.PlanRel( + root=stalg.RelRoot( + input=stalg.Rel( + project=stalg.ProjectRel( + common=stalg.RelCommon(emit=stalg.RelCommon.Emit(output_mapping=[2])), + input=table(None).relations[-1].root.input, + expressions=[ + stalg.Expression( + selection=stalg.Expression.FieldReference( + direct_reference=stalg.Expression.ReferenceSegment( + struct_field=stalg.Expression.ReferenceSegment.StructField(field=0) + ), + root_reference=stalg.Expression.FieldReference.RootReference() + ) + ) + ] + ) + ), + names=['id'] + ) + ) + ] + ) + + assert actual == expected \ No newline at end of file diff --git a/tests/builders/plan/test_read.py b/tests/builders/plan/test_read.py new file mode 100644 index 0000000..7e8fdf7 --- /dev/null +++ b/tests/builders/plan/test_read.py @@ -0,0 +1,55 @@ +import substrait.gen.proto.type_pb2 as stt +import substrait.gen.proto.plan_pb2 as stp +import substrait.gen.proto.algebra_pb2 as stalg +from substrait.builders.type import boolean, i64 +from substrait.builders.plan import read_named_table + +struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) + +named_struct = stt.NamedStruct( + names=["id", "is_applicable"], struct=struct +) + +def test_read_rel(): + actual = read_named_table('example_table', named_struct)(None) + + expected = stp.Plan( + relations=[ + stp.PlanRel( + root=stalg.RelRoot( + input=stalg.Rel( + read=stalg.ReadRel( + common=stalg.RelCommon(direct=stalg.RelCommon.Direct()), + base_schema=named_struct, + named_table=stalg.ReadRel.NamedTable(names=['example_table']) + ) + ), + names=['id', 'is_applicable'] + ) + ) + ] + ) + + assert actual == expected + +def test_read_rel_db(): + actual = read_named_table(['example_db', 'example_table'], named_struct)(None) + + expected = stp.Plan( + relations=[ + stp.PlanRel( + root=stalg.RelRoot( + input=stalg.Rel( + read=stalg.ReadRel( + common=stalg.RelCommon(direct=stalg.RelCommon.Direct()), + base_schema=named_struct, + named_table=stalg.ReadRel.NamedTable(names=['example_db', 'example_table']) + ) + ), + names=['id', 'is_applicable'] + ) + ) + ] + ) + + assert actual == expected \ No newline at end of file diff --git a/tests/builders/plan/test_set.py b/tests/builders/plan/test_set.py new file mode 100644 index 0000000..c701024 --- /dev/null +++ b/tests/builders/plan/test_set.py @@ -0,0 +1,44 @@ +import substrait.gen.proto.type_pb2 as stt +import substrait.gen.proto.plan_pb2 as stp +import substrait.gen.proto.algebra_pb2 as stalg +from substrait.builders.type import boolean, i64 +from substrait.builders.plan import read_named_table, set +from substrait.builders.extended_expression import column +from substrait.type_inference import infer_plan_schema +from substrait.extension_registry import ExtensionRegistry + +registry = ExtensionRegistry(load_default_extensions=False) + +struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) + +named_struct = stt.NamedStruct( + names=["id", "is_applicable"], struct=struct +) + +def test_set(): + table = read_named_table('table', named_struct) + table2 = read_named_table('table2', named_struct) + + actual = set([table, table2], stalg.SetRel.SET_OP_UNION_ALL)(None) + + + expected = stp.Plan( + relations=[ + stp.PlanRel( + root=stalg.RelRoot( + input=stalg.Rel( + set=stalg.SetRel( + inputs=[ + table(None).relations[-1].root.input, + table2(None).relations[-1].root.input, + ], + op=stalg.SetRel.SET_OP_UNION_ALL + ) + ), + names=['id', 'is_applicable'] + ) + ) + ] + ) + + assert actual == expected diff --git a/tests/builders/plan/test_sort.py b/tests/builders/plan/test_sort.py new file mode 100644 index 0000000..66cfbd5 --- /dev/null +++ b/tests/builders/plan/test_sort.py @@ -0,0 +1,76 @@ +import substrait.gen.proto.type_pb2 as stt +import substrait.gen.proto.plan_pb2 as stp +import substrait.gen.proto.algebra_pb2 as stalg +from substrait.builders.type import boolean, i64 +from substrait.builders.plan import read_named_table, sort +from substrait.builders.extended_expression import column +from substrait.type_inference import infer_plan_schema +from substrait.extension_registry import ExtensionRegistry + +registry = ExtensionRegistry(load_default_extensions=False) + +struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) + +named_struct = stt.NamedStruct( + names=["id", "is_applicable"], struct=struct +) + +def test_sort_no_direction(): + table = read_named_table('table', named_struct) + + col = column('id') + + actual = sort(table, expressions=[col])(registry) + + expected = stp.Plan( + relations=[ + stp.PlanRel( + root=stalg.RelRoot( + input=stalg.Rel( + sort=stalg.SortRel( + input=table(None).relations[-1].root.input, + sorts=[ + stalg.SortField( + direction=stalg.SortField.SORT_DIRECTION_ASC_NULLS_LAST, + expr=col(infer_plan_schema(table(None)), registry).referred_expr[0].expression + ) + ] + ) + ), + names=['id', 'is_applicable'] + ) + ) + ] + ) + + assert actual == expected + +def test_sort_direction(): + table = read_named_table('table', named_struct) + + col = column('id') + + actual = sort(table, expressions=[(col, stalg.SortField.SORT_DIRECTION_DESC_NULLS_FIRST)])(registry) + + expected = stp.Plan( + relations=[ + stp.PlanRel( + root=stalg.RelRoot( + input=stalg.Rel( + sort=stalg.SortRel( + input=table(None).relations[-1].root.input, + sorts=[ + stalg.SortField( + direction=stalg.SortField.SORT_DIRECTION_DESC_NULLS_FIRST, + expr=col(infer_plan_schema(table(None)), registry).referred_expr[0].expression + ) + ] + ) + ), + names=['id', 'is_applicable'] + ) + ) + ] + ) + + assert actual == expected \ No newline at end of file