From 8ab23590c90f2562a3e503f124255c151015d66d Mon Sep 17 00:00:00 2001 From: tokoko Date: Wed, 16 Apr 2025 19:29:36 +0000 Subject: [PATCH 1/5] feat: add plan builders for read, project, filter --- src/substrait/builders/plan.py | 81 +++++++++++++++++++ src/substrait/builders/type.py | 73 +++++++++++++++++ src/substrait/type_inference.py | 6 ++ .../extended_expression/test_column.py | 1 - tests/builders/plan/test_filter.py | 43 ++++++++++ tests/builders/plan/test_project.py | 48 +++++++++++ tests/builders/plan/test_read.py | 55 +++++++++++++ 7 files changed, 306 insertions(+), 1 deletion(-) create mode 100644 src/substrait/builders/plan.py create mode 100644 src/substrait/builders/type.py create mode 100644 tests/builders/plan/test_filter.py create mode 100644 tests/builders/plan/test_project.py create mode 100644 tests/builders/plan/test_read.py diff --git a/src/substrait/builders/plan.py b/src/substrait/builders/plan.py new file mode 100644 index 0000000..2deafef --- /dev/null +++ b/src/substrait/builders/plan.py @@ -0,0 +1,81 @@ +from typing import Iterable, Union + +import substrait.gen.proto.algebra_pb2 as stalg +import substrait.gen.proto.plan_pb2 as stp +import substrait.gen.proto.type_pb2 as stt +import substrait.gen.proto.extended_expression_pb2 as stee +from substrait.extension_registry import ExtensionRegistry +from substrait.builders.extended_expression import UnboundExtendedExpression +from substrait.type_inference import infer_plan_schema +from substrait.utils import merge_extension_declarations, merge_extension_uris + + +def _merge_extensions(*objs): + return { + "extension_uris": merge_extension_uris(*[b.extension_uris for b in objs]), + "extensions": merge_extension_declarations(*[b.extensions for b in objs]), + } + + +def read_named_table(names: Union[str, Iterable[str]], named_struct: stt.NamedStruct): + names = [names] if isinstance(names, str) else names + + rel = stalg.Rel( + read=stalg.ReadRel( + common=stalg.RelCommon(direct=stalg.RelCommon.Direct()), + base_schema=named_struct, + named_table=stalg.ReadRel.NamedTable(names=names), + ) + ) + + return stp.Plan( + relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=named_struct.names))] + ) + + +def project( + plan: stp.Plan, expressions: Iterable[UnboundExtendedExpression], registry: ExtensionRegistry +) -> stp.Plan: + ns = infer_plan_schema(plan) + expressions: Iterable[stee.ExtendedExpression] = [e(ns, registry) for e in expressions] + + start_index = len(plan.relations[-1].root.names) + + names = [e.output_names[0] for ee in expressions for e in ee.referred_expr] + + rel = stalg.Rel( + project=stalg.ProjectRel( + common=stalg.RelCommon( + emit=stalg.RelCommon.Emit( + output_mapping=[i + start_index for i in range(len(names))] + ) + ), + input=plan.relations[-1].root.input, + expressions=[e.expression for ee in expressions for e in ee.referred_expr], + ) + ) + + return stp.Plan( + relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=names))], + **_merge_extensions(plan, *expressions), + ) + +def filter( + plan: stp.Plan, expression: UnboundExtendedExpression, registry: ExtensionRegistry +) -> stp.Plan: + ns = infer_plan_schema(plan) + expression: stee.ExtendedExpression = expression(ns, registry) + + rel = stalg.Rel( + filter=stalg.FilterRel( + input=plan.relations[-1].root.input, + condition=expression.referred_expr[0].expression, + ) + ) + + names = ns.names + + return stp.Plan( + relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=names))], + **_merge_extensions(plan, expression), + ) diff --git a/src/substrait/builders/type.py b/src/substrait/builders/type.py new file mode 100644 index 0000000..fd6bd9b --- /dev/null +++ b/src/substrait/builders/type.py @@ -0,0 +1,73 @@ +from typing import Iterable +import substrait.gen.proto.type_pb2 as stt + +def boolean(nullable=True): + return stt.Type(bool=stt.Type.Boolean(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def i8(nullable=True): + return stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def i16(nullable=True): + return stt.Type(i16=stt.Type.I16(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def i32(nullable=True): + return stt.Type(i32=stt.Type.I32(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def i64(nullable=True): + return stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def fp32(nullable=True): + return stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def fp64(nullable=True): + return stt.Type(fp64=stt.Type.FP64(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def string(nullable=True): + return stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def binary(nullable=True): + return stt.Type(binary=stt.Type.Binary(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def date(nullable=True): + return stt.Type(date=stt.Type.Date(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def interval_year(nullable=True): + return stt.Type(interval_year=stt.Type.IntervalYear(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def interval_day(precision: int, nullable=True): + return stt.Type(interval_day=stt.Type.IntervalDay(precision=precision, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def interval_compound(precision: int, nullable=True): + return stt.Type(interval_compound=stt.Type.IntervalCompound(precision=precision, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def uuid(nullable=True): + return stt.Type(uuid=stt.Type.UUID(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def fixed_char(length: int, nullable=True): + return stt.Type(fixed_char=stt.Type.FixedChar(length=length, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def var_char(length: int, nullable=True): + return stt.Type(var_char=stt.Type.VarChar(length=length, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def fixed_binary(length: int, nullable=True): + return stt.Type(fixed_binary=stt.Type.FixedBinary(length=length, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def decimal(scale: int, precision: int, nullable=True): + return stt.Type(decimal=stt.Type.Decimal(scale=scale, precision=precision, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +# PrecisionTime + +def precision_timestamp(precision: int, nullable=True): + return stt.Type(precision_timestamp=stt.Type.PrecisionTimestamp(precision=precision, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def precision_timestamp_tz(precision: int, nullable=True): + return stt.Type(precision_timestamp_tz=stt.Type.PrecisionTimestampTZ(precision=precision, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def struct(types: Iterable[stt.Type], nullable=True): + return stt.Type(struct=stt.Type.Struct(types=types, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def list(type: stt.Type, nullable=True): + return stt.Type(list=stt.Type.List(type=type, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def map(key: stt.Type, value: stt.Type, nullable=True): + return stt.Type(map=stt.Type.Map(key=key, value=value, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) diff --git a/src/substrait/type_inference.py b/src/substrait/type_inference.py index 082da29..5ed16fa 100644 --- a/src/substrait/type_inference.py +++ b/src/substrait/type_inference.py @@ -1,6 +1,7 @@ import substrait.gen.proto.algebra_pb2 as stalg import substrait.gen.proto.extended_expression_pb2 as stee import substrait.gen.proto.type_pb2 as stt +import substrait.gen.proto.plan_pb2 as stp def infer_literal_type(literal: stalg.Expression.Literal) -> stt.Type: @@ -346,3 +347,8 @@ def infer_rel_schema(rel: stalg.Rel) -> stt.Type.Struct: types=[struct.types[i] for i in common.emit.output_mapping], nullability=struct.nullability, ) + +def infer_plan_schema(plan: stp.Plan) -> stt.NamedStruct: + schema = infer_rel_schema(plan.relations[-1].root.input) + + return stt.NamedStruct(names=plan.relations[-1].root.names, struct=schema) \ No newline at end of file diff --git a/tests/builders/extended_expression/test_column.py b/tests/builders/extended_expression/test_column.py index 306fe70..5287d13 100644 --- a/tests/builders/extended_expression/test_column.py +++ b/tests/builders/extended_expression/test_column.py @@ -3,7 +3,6 @@ import substrait.gen.proto.extended_expression_pb2 as stee from substrait.builders.extended_expression import column - struct = stt.Type.Struct( types=[ stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)), diff --git a/tests/builders/plan/test_filter.py b/tests/builders/plan/test_filter.py new file mode 100644 index 0000000..9375878 --- /dev/null +++ b/tests/builders/plan/test_filter.py @@ -0,0 +1,43 @@ +import substrait.gen.proto.type_pb2 as stt +import substrait.gen.proto.plan_pb2 as stp +import substrait.gen.proto.algebra_pb2 as stalg +from substrait.builders.type import boolean, i64 +from substrait.builders.plan import read_named_table, filter +from substrait.builders.extended_expression import column, literal +from substrait.extension_registry import ExtensionRegistry + +registry = ExtensionRegistry(load_default_extensions=False) + +struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) + +named_struct = stt.NamedStruct( + names=["id", "is_applicable"], struct=struct +) + +def test_filter(): + table = read_named_table('table', named_struct) + + actual = filter(table, literal(True, boolean()), registry) + + expected = stp.Plan( + relations=[ + stp.PlanRel( + root=stalg.RelRoot( + input=stalg.Rel( + filter=stalg.FilterRel( + input=table.relations[-1].root.input, + condition=stalg.Expression( + literal=stalg.Expression.Literal( + boolean=True, + nullable=True + ) + ) + ) + ), + names=['id', 'is_applicable'] + ) + ) + ] + ) + + assert actual == expected \ No newline at end of file diff --git a/tests/builders/plan/test_project.py b/tests/builders/plan/test_project.py new file mode 100644 index 0000000..4e8cbff --- /dev/null +++ b/tests/builders/plan/test_project.py @@ -0,0 +1,48 @@ +import substrait.gen.proto.type_pb2 as stt +import substrait.gen.proto.plan_pb2 as stp +import substrait.gen.proto.algebra_pb2 as stalg +from substrait.builders.type import boolean, i64 +from substrait.builders.plan import read_named_table, project +from substrait.builders.extended_expression import column +from substrait.extension_registry import ExtensionRegistry + +registry = ExtensionRegistry(load_default_extensions=False) + +struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) + +named_struct = stt.NamedStruct( + names=["id", "is_applicable"], struct=struct +) + +def test_project(): + table = read_named_table('table', named_struct) + + actual = project(table, [column('id')], registry=registry) + + expected = stp.Plan( + relations=[ + stp.PlanRel( + root=stalg.RelRoot( + input=stalg.Rel( + project=stalg.ProjectRel( + common=stalg.RelCommon(emit=stalg.RelCommon.Emit(output_mapping=[2])), + input=table.relations[-1].root.input, + expressions=[ + stalg.Expression( + selection=stalg.Expression.FieldReference( + direct_reference=stalg.Expression.ReferenceSegment( + struct_field=stalg.Expression.ReferenceSegment.StructField(field=0) + ), + root_reference=stalg.Expression.FieldReference.RootReference() + ) + ) + ] + ) + ), + names=['id'] + ) + ) + ] + ) + + assert actual == expected \ No newline at end of file diff --git a/tests/builders/plan/test_read.py b/tests/builders/plan/test_read.py new file mode 100644 index 0000000..73cb1aa --- /dev/null +++ b/tests/builders/plan/test_read.py @@ -0,0 +1,55 @@ +import substrait.gen.proto.type_pb2 as stt +import substrait.gen.proto.plan_pb2 as stp +import substrait.gen.proto.algebra_pb2 as stalg +from substrait.builders.type import boolean, i64 +from substrait.builders.plan import read_named_table + +struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) + +named_struct = stt.NamedStruct( + names=["id", "is_applicable"], struct=struct +) + +def test_read_rel(): + actual = read_named_table('example_table', named_struct) + + expected = stp.Plan( + relations=[ + stp.PlanRel( + root=stalg.RelRoot( + input=stalg.Rel( + read=stalg.ReadRel( + common=stalg.RelCommon(direct=stalg.RelCommon.Direct()), + base_schema=named_struct, + named_table=stalg.ReadRel.NamedTable(names=['example_table']) + ) + ), + names=['id', 'is_applicable'] + ) + ) + ] + ) + + assert actual == expected + +def test_read_rel_db(): + actual = read_named_table(['example_db', 'example_table'], named_struct) + + expected = stp.Plan( + relations=[ + stp.PlanRel( + root=stalg.RelRoot( + input=stalg.Rel( + read=stalg.ReadRel( + common=stalg.RelCommon(direct=stalg.RelCommon.Direct()), + base_schema=named_struct, + named_table=stalg.ReadRel.NamedTable(names=['example_db', 'example_table']) + ) + ), + names=['id', 'is_applicable'] + ) + ) + ] + ) + + assert actual == expected \ No newline at end of file From 738f9511dab49347adfd001219936f936bdc9a9a Mon Sep 17 00:00:00 2001 From: tokoko Date: Tue, 22 Apr 2025 19:45:57 +0000 Subject: [PATCH 2/5] feat: add builders for other rels --- src/substrait/builders/plan.py | 187 +++++++++++++++++++++++++- tests/builders/plan/test_aggregate.py | 100 ++++++++++++++ tests/builders/plan/test_cross.py | 44 ++++++ tests/builders/plan/test_fetch.py | 44 ++++++ tests/builders/plan/test_join.py | 46 +++++++ tests/builders/plan/test_set.py | 44 ++++++ tests/builders/plan/test_sort.py | 76 +++++++++++ 7 files changed, 537 insertions(+), 4 deletions(-) create mode 100644 tests/builders/plan/test_aggregate.py create mode 100644 tests/builders/plan/test_cross.py create mode 100644 tests/builders/plan/test_fetch.py create mode 100644 tests/builders/plan/test_join.py create mode 100644 tests/builders/plan/test_set.py create mode 100644 tests/builders/plan/test_sort.py diff --git a/src/substrait/builders/plan.py b/src/substrait/builders/plan.py index 2deafef..5fbca7d 100644 --- a/src/substrait/builders/plan.py +++ b/src/substrait/builders/plan.py @@ -19,7 +19,7 @@ def _merge_extensions(*objs): def read_named_table(names: Union[str, Iterable[str]], named_struct: stt.NamedStruct): names = [names] if isinstance(names, str) else names - + rel = stalg.Rel( read=stalg.ReadRel( common=stalg.RelCommon(direct=stalg.RelCommon.Direct()), @@ -29,7 +29,8 @@ def read_named_table(names: Union[str, Iterable[str]], named_struct: stt.NamedSt ) return stp.Plan( - relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=named_struct.names))] + relations=[stp.PlanRel(root=stalg.RelRoot( + input=rel, names=named_struct.names))] ) @@ -37,7 +38,8 @@ def project( plan: stp.Plan, expressions: Iterable[UnboundExtendedExpression], registry: ExtensionRegistry ) -> stp.Plan: ns = infer_plan_schema(plan) - expressions: Iterable[stee.ExtendedExpression] = [e(ns, registry) for e in expressions] + expressions: Iterable[stee.ExtendedExpression] = [ + e(ns, registry) for e in expressions] start_index = len(plan.relations[-1].root.names) @@ -51,7 +53,8 @@ def project( ) ), input=plan.relations[-1].root.input, - expressions=[e.expression for ee in expressions for e in ee.referred_expr], + expressions=[ + e.expression for ee in expressions for e in ee.referred_expr], ) ) @@ -60,6 +63,7 @@ def project( **_merge_extensions(plan, *expressions), ) + def filter( plan: stp.Plan, expression: UnboundExtendedExpression, registry: ExtensionRegistry ) -> stp.Plan: @@ -79,3 +83,178 @@ def filter( relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=names))], **_merge_extensions(plan, expression), ) + + +def sort( + plan: stp.Plan, + expressions: Iterable[Union[UnboundExtendedExpression, tuple[UnboundExtendedExpression, stalg.SortField.SortDirection.ValueType]]], + registry: ExtensionRegistry +) -> stp.Plan: + ns = infer_plan_schema(plan) + + expressions = [(e, stalg.SortField.SORT_DIRECTION_ASC_NULLS_LAST) if not isinstance(e, tuple) else e for e in expressions] + expressions = [(e[0](ns, registry), e[1]) for e in expressions] + + rel = stalg.Rel( + sort=stalg.SortRel( + input=plan.relations[-1].root.input, + sorts=[ + stalg.SortField( + expr=e[0].referred_expr[0].expression, + direction=e[1], + ) + for e in expressions + ], + ) + ) + + return stp.Plan( + relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=ns.names))], + **_merge_extensions(plan, *[e[0] for e in expressions]), + ) + + +def set(inputs: Iterable[stp.Plan], op: stalg.SetRel.SetOp) -> stp.Plan: + rel = stalg.Rel( + set=stalg.SetRel( + inputs=[plan.relations[-1].root.input for plan in inputs], op=op + ) + ) + + return stp.Plan( + relations=[ + stp.PlanRel( + root=stalg.RelRoot(input=rel, names=inputs[0].relations[-1].root.names) + ) + ], + **_merge_extensions(*inputs), + ) + +def fetch(plan: stp.Plan, + offset: UnboundExtendedExpression, + count: UnboundExtendedExpression, + registry: ExtensionRegistry): + ns = infer_plan_schema(plan) + + bound_offset = offset(ns, registry) + bound_count = count(ns, registry) + + rel = stalg.Rel( + fetch=stalg.FetchRel( + input=plan.relations[-1].root.input, + offset_expr=bound_offset.referred_expr[0].expression, + count_expr=bound_count.referred_expr[0].expression + ) + ) + + return stp.Plan( + relations=[ + stp.PlanRel( + root=stalg.RelRoot(input=rel, names=plan.relations[-1].root.names) + ) + ], + **_merge_extensions(plan, bound_offset, bound_count), + ) + + +def join( + left: stp.Plan, + right: stp.Plan, + expression: UnboundExtendedExpression, + type: stalg.JoinRel.JoinType, + registry: ExtensionRegistry, +): + left_ns = infer_plan_schema(left) + right_ns = infer_plan_schema(right) + ns = stt.NamedStruct( + struct=stt.Type.Struct( + types=list(left_ns.struct.types) + list(right_ns.struct.types), + nullability=stt.Type.Nullability.NULLABILITY_REQUIRED, + ), + names=list(left_ns.names) + list(right_ns.names), + ) + expression: stee.ExtendedExpression = expression(ns, registry) + + rel = stalg.Rel( + join=stalg.JoinRel( + left=left.relations[-1].root.input, + right=right.relations[-1].root.input, + expression=expression.referred_expr[0].expression, + type=type, + ) + ) + + return stp.Plan( + relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=ns.names))], + **_merge_extensions(left, right, expression), + ) + +def cross( + left: stp.Plan, + right: stp.Plan, + registry: ExtensionRegistry, +): + left_ns = infer_plan_schema(left) + right_ns = infer_plan_schema(right) + + ns = stt.NamedStruct( + struct=stt.Type.Struct( + types=list(left_ns.struct.types) + list(right_ns.struct.types), + nullability=stt.Type.Nullability.NULLABILITY_REQUIRED, + ), + names=list(left_ns.names) + list(right_ns.names), + ) + + rel = stalg.Rel( + cross=stalg.CrossRel( + left=left.relations[-1].root.input, + right=right.relations[-1].root.input + ) + ) + + return stp.Plan( + relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=ns.names))], + **_merge_extensions(left, right), + ) + +# TODO grouping sets +def aggregate( + input: stp.Plan, + grouping_expressions: Iterable[UnboundExtendedExpression], + measures: Iterable[UnboundExtendedExpression], + registry: ExtensionRegistry, +): + ns = infer_plan_schema(input) + + grouping_expressions = [e(ns, registry) for e in grouping_expressions] + measures = [e(ns, registry) for e in measures] + + rel = stalg.Rel( + aggregate=stalg.AggregateRel( + input=input.relations[-1].root.input, + grouping_expressions=[ + e.referred_expr[0].expression for e in grouping_expressions + ], + groupings=[ + stalg.AggregateRel.Grouping( + expression_references=range(len(grouping_expressions)), + grouping_expressions=[ + e.referred_expr[0].expression for e in grouping_expressions + ], + ) + ], + measures=[ + stalg.AggregateRel.Measure(measure=m.referred_expr[0].measure) + for m in measures + ], + ) + ) + + names = [e.referred_expr[0].output_names[0] for e in grouping_expressions] + [ + e.referred_expr[0].output_names[0] for e in measures + ] + + return stp.Plan( + relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=names))], + **_merge_extensions(input, *grouping_expressions, *measures), + ) diff --git a/tests/builders/plan/test_aggregate.py b/tests/builders/plan/test_aggregate.py new file mode 100644 index 0000000..50508da --- /dev/null +++ b/tests/builders/plan/test_aggregate.py @@ -0,0 +1,100 @@ +import substrait.gen.proto.type_pb2 as stt +import substrait.gen.proto.plan_pb2 as stp +import substrait.gen.proto.algebra_pb2 as stalg +import substrait.gen.proto.extensions.extensions_pb2 as ste +from substrait.builders.type import boolean, i64 +from substrait.builders.plan import read_named_table, aggregate +from substrait.builders.extended_expression import column, aggregate_function +from substrait.extension_registry import ExtensionRegistry +from substrait.type_inference import infer_plan_schema +import yaml + +content = """%YAML 1.2 +--- +aggregate_functions: + - name: "count" + description: Count a set of values + impls: + - args: + - name: x + value: any + nullability: DECLARED_OUTPUT + decomposable: MANY + intermediate: i64 + return: i64 +""" + + +registry = ExtensionRegistry(load_default_extensions=False) +registry.register_extension_dict(yaml.safe_load(content), uri="test_uri") + +struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) + +named_struct = stt.NamedStruct( + names=["id", "is_applicable"], struct=struct +) + +def test_aggregate(): + table = read_named_table('table', named_struct) + + group_expr = column('id') + measure_expr = aggregate_function('test_uri', 'count', column('is_applicable'), alias=['count']) + + actual = aggregate(table, + grouping_expressions=[group_expr], + measures=[measure_expr], + registry=registry) + + ns = infer_plan_schema(table) + + expected = stp.Plan( + extension_uris=[ + ste.SimpleExtensionURI( + extension_uri_anchor=1, + uri='test_uri' + ) + ], + extensions=[ + ste.SimpleExtensionDeclaration( + extension_function=ste.SimpleExtensionDeclaration.ExtensionFunction( + extension_uri_reference=1, + function_anchor=1, + name='count' + ) + ) + ], + relations=[ + stp.PlanRel( + root=stalg.RelRoot( + input=stalg.Rel( + aggregate=stalg.AggregateRel( + input=table.relations[-1].root.input, + grouping_expressions=[ + group_expr(ns, registry).referred_expr[0].expression + ], + groupings=[ + stalg.AggregateRel.Grouping( + grouping_expressions=[ + group_expr(ns, registry).referred_expr[0].expression + ], + expression_references=[0] + ) + ], + measures=[ + stalg.AggregateRel.Measure( + measure=measure_expr(ns, registry).referred_expr[0].measure + ) + ] + + ) + ), + names=['id', 'count'] + ) + ) + ] + ) + + print(actual) + print(expected) + + assert actual == expected \ No newline at end of file diff --git a/tests/builders/plan/test_cross.py b/tests/builders/plan/test_cross.py new file mode 100644 index 0000000..37e4487 --- /dev/null +++ b/tests/builders/plan/test_cross.py @@ -0,0 +1,44 @@ +import substrait.gen.proto.type_pb2 as stt +import substrait.gen.proto.plan_pb2 as stp +import substrait.gen.proto.algebra_pb2 as stalg +from substrait.builders.type import boolean, i64, string +from substrait.builders.plan import read_named_table, cross +from substrait.builders.extended_expression import literal +from substrait.extension_registry import ExtensionRegistry + +registry = ExtensionRegistry(load_default_extensions=False) + +struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) + +named_struct = stt.NamedStruct( + names=["id", "is_applicable"], struct=struct +) + +named_struct_2 = stt.NamedStruct( + names=["fk_id", "name"], struct=stt.Type.Struct(types=[i64(nullable=False), string()]) +) + +def test_join(): + table = read_named_table('table', named_struct) + table2 = read_named_table('table2', named_struct_2) + + actual = cross(table, table2, registry=registry) + + expected = stp.Plan( + relations=[ + stp.PlanRel( + root=stalg.RelRoot( + input=stalg.Rel( + cross=stalg.CrossRel( + left=table.relations[-1].root.input, + right=table2.relations[-1].root.input, + ) + ), + names=['id', 'is_applicable', 'fk_id', 'name'] + ) + ) + ] + ) + + assert actual == expected + diff --git a/tests/builders/plan/test_fetch.py b/tests/builders/plan/test_fetch.py new file mode 100644 index 0000000..27a64ee --- /dev/null +++ b/tests/builders/plan/test_fetch.py @@ -0,0 +1,44 @@ +import substrait.gen.proto.type_pb2 as stt +import substrait.gen.proto.plan_pb2 as stp +import substrait.gen.proto.algebra_pb2 as stalg +from substrait.builders.type import boolean, i64 +from substrait.builders.plan import read_named_table, fetch +from substrait.builders.extended_expression import literal +from substrait.type_inference import infer_plan_schema +from substrait.extension_registry import ExtensionRegistry + +registry = ExtensionRegistry(load_default_extensions=False) + +struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) + +named_struct = stt.NamedStruct( + names=["id", "is_applicable"], struct=struct +) + +def test_fetch(): + table = read_named_table('table', named_struct) + + offset = literal(10, i64()) + count = literal(5, i64()) + + actual = fetch(table, offset=offset, count=count, registry=registry) + + expected = stp.Plan( + relations=[ + stp.PlanRel( + root=stalg.RelRoot( + input=stalg.Rel( + fetch=stalg.FetchRel( + input=table.relations[-1].root.input, + offset_expr=offset(None, None).referred_expr[0].expression, + count_expr=count(None, None).referred_expr[0].expression + ) + ), + names=['id', 'is_applicable'] + ) + ) + ] + ) + + assert actual == expected + diff --git a/tests/builders/plan/test_join.py b/tests/builders/plan/test_join.py new file mode 100644 index 0000000..fd4cf4f --- /dev/null +++ b/tests/builders/plan/test_join.py @@ -0,0 +1,46 @@ +import substrait.gen.proto.type_pb2 as stt +import substrait.gen.proto.plan_pb2 as stp +import substrait.gen.proto.algebra_pb2 as stalg +from substrait.builders.type import boolean, i64, string +from substrait.builders.plan import read_named_table, join +from substrait.builders.extended_expression import literal +from substrait.extension_registry import ExtensionRegistry + +registry = ExtensionRegistry(load_default_extensions=False) + +struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) + +named_struct = stt.NamedStruct( + names=["id", "is_applicable"], struct=struct +) + +named_struct_2 = stt.NamedStruct( + names=["fk_id", "name"], struct=stt.Type.Struct(types=[i64(nullable=False), string()]) +) + +def test_join(): + table = read_named_table('table', named_struct) + table2 = read_named_table('table2', named_struct_2) + + actual = join(table, table2, literal(True, boolean()), stalg.JoinRel.JOIN_TYPE_INNER, registry=registry) + + expected = stp.Plan( + relations=[ + stp.PlanRel( + root=stalg.RelRoot( + input=stalg.Rel( + join=stalg.JoinRel( + left=table.relations[-1].root.input, + right=table2.relations[-1].root.input, + expression=literal(True, boolean())(None, None).referred_expr[0].expression, + type=stalg.JoinRel.JOIN_TYPE_INNER + ) + ), + names=['id', 'is_applicable', 'fk_id', 'name'] + ) + ) + ] + ) + + assert actual == expected + diff --git a/tests/builders/plan/test_set.py b/tests/builders/plan/test_set.py new file mode 100644 index 0000000..290996b --- /dev/null +++ b/tests/builders/plan/test_set.py @@ -0,0 +1,44 @@ +import substrait.gen.proto.type_pb2 as stt +import substrait.gen.proto.plan_pb2 as stp +import substrait.gen.proto.algebra_pb2 as stalg +from substrait.builders.type import boolean, i64 +from substrait.builders.plan import read_named_table, set +from substrait.builders.extended_expression import column +from substrait.type_inference import infer_plan_schema +from substrait.extension_registry import ExtensionRegistry + +registry = ExtensionRegistry(load_default_extensions=False) + +struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) + +named_struct = stt.NamedStruct( + names=["id", "is_applicable"], struct=struct +) + +def test_set(): + table = read_named_table('table', named_struct) + table2 = read_named_table('table2', named_struct) + + actual = set([table, table2], stalg.SetRel.SET_OP_UNION_ALL) + + + expected = stp.Plan( + relations=[ + stp.PlanRel( + root=stalg.RelRoot( + input=stalg.Rel( + set=stalg.SetRel( + inputs=[ + table.relations[-1].root.input, + table2.relations[-1].root.input, + ], + op=stalg.SetRel.SET_OP_UNION_ALL + ) + ), + names=['id', 'is_applicable'] + ) + ) + ] + ) + + assert actual == expected diff --git a/tests/builders/plan/test_sort.py b/tests/builders/plan/test_sort.py new file mode 100644 index 0000000..738da03 --- /dev/null +++ b/tests/builders/plan/test_sort.py @@ -0,0 +1,76 @@ +import substrait.gen.proto.type_pb2 as stt +import substrait.gen.proto.plan_pb2 as stp +import substrait.gen.proto.algebra_pb2 as stalg +from substrait.builders.type import boolean, i64 +from substrait.builders.plan import read_named_table, sort +from substrait.builders.extended_expression import column +from substrait.type_inference import infer_plan_schema +from substrait.extension_registry import ExtensionRegistry + +registry = ExtensionRegistry(load_default_extensions=False) + +struct = stt.Type.Struct(types=[i64(nullable=False), boolean()]) + +named_struct = stt.NamedStruct( + names=["id", "is_applicable"], struct=struct +) + +def test_sort_no_direction(): + table = read_named_table('table', named_struct) + + col = column('id') + + actual = sort(table, expressions=[col], registry=registry) + + expected = stp.Plan( + relations=[ + stp.PlanRel( + root=stalg.RelRoot( + input=stalg.Rel( + sort=stalg.SortRel( + input=table.relations[-1].root.input, + sorts=[ + stalg.SortField( + direction=stalg.SortField.SORT_DIRECTION_ASC_NULLS_LAST, + expr=col(infer_plan_schema(table), registry).referred_expr[0].expression + ) + ] + ) + ), + names=['id', 'is_applicable'] + ) + ) + ] + ) + + assert actual == expected + +def test_sort_direction(): + table = read_named_table('table', named_struct) + + col = column('id') + + actual = sort(table, expressions=[(col, stalg.SortField.SORT_DIRECTION_DESC_NULLS_FIRST)], registry=registry) + + expected = stp.Plan( + relations=[ + stp.PlanRel( + root=stalg.RelRoot( + input=stalg.Rel( + sort=stalg.SortRel( + input=table.relations[-1].root.input, + sorts=[ + stalg.SortField( + direction=stalg.SortField.SORT_DIRECTION_DESC_NULLS_FIRST, + expr=col(infer_plan_schema(table), registry).referred_expr[0].expression + ) + ] + ) + ), + names=['id', 'is_applicable'] + ) + ) + ] + ) + + assert actual == expected \ No newline at end of file From 9dbcf210eceece3588c7fcaa6b10e7680df1e294 Mon Sep 17 00:00:00 2001 From: tokoko Date: Mon, 28 Apr 2025 13:52:58 +0000 Subject: [PATCH 3/5] feat: make plan builders return UnboundPlan --- examples/builder_example.py | 138 +++++++++ src/substrait/builders/plan.py | 418 ++++++++++++++------------ src/substrait/builders/type.py | 3 + tests/builders/plan/test_aggregate.py | 10 +- tests/builders/plan/test_cross.py | 8 +- tests/builders/plan/test_fetch.py | 4 +- tests/builders/plan/test_filter.py | 4 +- tests/builders/plan/test_join.py | 6 +- tests/builders/plan/test_project.py | 4 +- tests/builders/plan/test_read.py | 4 +- tests/builders/plan/test_set.py | 6 +- tests/builders/plan/test_sort.py | 12 +- 12 files changed, 394 insertions(+), 223 deletions(-) create mode 100644 examples/builder_example.py diff --git a/examples/builder_example.py b/examples/builder_example.py new file mode 100644 index 0000000..b0c9ed6 --- /dev/null +++ b/examples/builder_example.py @@ -0,0 +1,138 @@ +from substrait.builders.plan import read_named_table, project, filter +from substrait.builders.extended_expression import column, scalar_function, literal +from substrait.builders.type import i64, boolean, struct, named_struct +from substrait.extension_registry import ExtensionRegistry + +registry = ExtensionRegistry(load_default_extensions=True) + +ns = named_struct( + names=["id", "is_applicable"], + struct=struct( + types=[ + i64(nullable=False), + boolean() + ] + ) +) + +table = read_named_table('example_table', ns) +table = filter(table, expression=column('is_applicable')) +table = filter(table, expression=scalar_function('functions_comparison.yaml', 'lt', column('id'), literal(100, i64()))) +table = project(table, expressions=[column('id')]) + +print(table(registry)) + +""" +extension_uris { + extension_uri_anchor: 13 + uri: "functions_comparison.yaml" +} +extensions { + extension_function { + extension_uri_reference: 13 + function_anchor: 495 + name: "lt" + } +} +relations { + root { + input { + project { + common { + emit { + output_mapping: 2 + } + } + input { + filter { + input { + filter { + input { + read { + common { + direct { + } + } + base_schema { + names: "id" + names: "is_applicable" + struct { + types { + i64 { + nullability: NULLABILITY_REQUIRED + } + } + types { + bool { + nullability: NULLABILITY_NULLABLE + } + } + nullability: NULLABILITY_NULLABLE + } + } + named_table { + names: "example_table" + } + } + } + condition { + selection { + direct_reference { + struct_field { + field: 1 + } + } + root_reference { + } + } + } + } + } + condition { + scalar_function { + function_reference: 495 + output_type { + bool { + nullability: NULLABILITY_NULLABLE + } + } + arguments { + value { + selection { + direct_reference { + struct_field { + } + } + root_reference { + } + } + } + } + arguments { + value { + literal { + i64: 100 + nullable: true + } + } + } + } + } + } + } + expressions { + selection { + direct_reference { + struct_field { + } + } + root_reference { + } + } + } + } + } + names: "id" + } +} +""" diff --git a/src/substrait/builders/plan.py b/src/substrait/builders/plan.py index 5fbca7d..7b348d5 100644 --- a/src/substrait/builders/plan.py +++ b/src/substrait/builders/plan.py @@ -1,4 +1,4 @@ -from typing import Iterable, Union +from typing import Iterable, Union, Callable import substrait.gen.proto.algebra_pb2 as stalg import substrait.gen.proto.plan_pb2 as stp @@ -9,6 +9,7 @@ from substrait.type_inference import infer_plan_schema from substrait.utils import merge_extension_declarations, merge_extension_uris +UnboundPlan = Callable[[ExtensionRegistry], stp.Plan] def _merge_extensions(*objs): return { @@ -17,144 +18,165 @@ def _merge_extensions(*objs): } -def read_named_table(names: Union[str, Iterable[str]], named_struct: stt.NamedStruct): - names = [names] if isinstance(names, str) else names +def read_named_table(names: Union[str, Iterable[str]], named_struct: stt.NamedStruct) -> UnboundPlan: + def resolve(registry: ExtensionRegistry) -> stp.Plan: + _names = [names] if isinstance(names, str) else names - rel = stalg.Rel( - read=stalg.ReadRel( - common=stalg.RelCommon(direct=stalg.RelCommon.Direct()), - base_schema=named_struct, - named_table=stalg.ReadRel.NamedTable(names=names), + rel = stalg.Rel( + read=stalg.ReadRel( + common=stalg.RelCommon(direct=stalg.RelCommon.Direct()), + base_schema=named_struct, + named_table=stalg.ReadRel.NamedTable(names=_names), + ) ) - ) - return stp.Plan( - relations=[stp.PlanRel(root=stalg.RelRoot( - input=rel, names=named_struct.names))] - ) + return stp.Plan( + relations=[stp.PlanRel(root=stalg.RelRoot( + input=rel, names=named_struct.names))] + ) + + return resolve def project( - plan: stp.Plan, expressions: Iterable[UnboundExtendedExpression], registry: ExtensionRegistry -) -> stp.Plan: - ns = infer_plan_schema(plan) - expressions: Iterable[stee.ExtendedExpression] = [ - e(ns, registry) for e in expressions] - - start_index = len(plan.relations[-1].root.names) - - names = [e.output_names[0] for ee in expressions for e in ee.referred_expr] - - rel = stalg.Rel( - project=stalg.ProjectRel( - common=stalg.RelCommon( - emit=stalg.RelCommon.Emit( - output_mapping=[i + start_index for i in range(len(names))] - ) - ), - input=plan.relations[-1].root.input, - expressions=[ - e.expression for ee in expressions for e in ee.referred_expr], + plan: Union[stp.Plan, UnboundPlan], expressions: Iterable[UnboundExtendedExpression] +) -> UnboundPlan: + def resolve(registry: ExtensionRegistry) -> stp.Plan: + _plan = plan if isinstance(plan, stp.Plan) else plan(registry) + ns = infer_plan_schema(_plan) + bound_expressions: Iterable[stee.ExtendedExpression] = [ + e(ns, registry) for e in expressions] + + start_index = len(_plan.relations[-1].root.names) + + names = [e.output_names[0] for ee in bound_expressions for e in ee.referred_expr] + + rel = stalg.Rel( + project=stalg.ProjectRel( + common=stalg.RelCommon( + emit=stalg.RelCommon.Emit( + output_mapping=[i + start_index for i in range(len(names))] + ) + ), + input=_plan.relations[-1].root.input, + expressions=[ + e.expression for ee in bound_expressions for e in ee.referred_expr], + ) ) - ) - return stp.Plan( - relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=names))], - **_merge_extensions(plan, *expressions), - ) + return stp.Plan( + relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=names))], + **_merge_extensions(_plan, *bound_expressions), + ) + + return resolve def filter( - plan: stp.Plan, expression: UnboundExtendedExpression, registry: ExtensionRegistry -) -> stp.Plan: - ns = infer_plan_schema(plan) - expression: stee.ExtendedExpression = expression(ns, registry) - - rel = stalg.Rel( - filter=stalg.FilterRel( - input=plan.relations[-1].root.input, - condition=expression.referred_expr[0].expression, + plan: Union[stp.Plan, UnboundPlan], expression: UnboundExtendedExpression +) -> UnboundPlan: + def resolve(registry: ExtensionRegistry) -> stp.Plan: + bound_plan = plan if isinstance(plan, stp.Plan) else plan(registry) + ns = infer_plan_schema(bound_plan) + bound_expression: stee.ExtendedExpression = expression(ns, registry) + + rel = stalg.Rel( + filter=stalg.FilterRel( + input=bound_plan.relations[-1].root.input, + condition=bound_expression.referred_expr[0].expression, + ) ) - ) - names = ns.names + names = ns.names - return stp.Plan( - relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=names))], - **_merge_extensions(plan, expression), - ) + return stp.Plan( + relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=names))], + **_merge_extensions(bound_plan, bound_expression), + ) + + return resolve def sort( plan: stp.Plan, - expressions: Iterable[Union[UnboundExtendedExpression, tuple[UnboundExtendedExpression, stalg.SortField.SortDirection.ValueType]]], - registry: ExtensionRegistry -) -> stp.Plan: - ns = infer_plan_schema(plan) - - expressions = [(e, stalg.SortField.SORT_DIRECTION_ASC_NULLS_LAST) if not isinstance(e, tuple) else e for e in expressions] - expressions = [(e[0](ns, registry), e[1]) for e in expressions] - - rel = stalg.Rel( - sort=stalg.SortRel( - input=plan.relations[-1].root.input, - sorts=[ - stalg.SortField( - expr=e[0].referred_expr[0].expression, - direction=e[1], - ) - for e in expressions - ], + expressions: Iterable[Union[UnboundExtendedExpression, tuple[UnboundExtendedExpression, stalg.SortField.SortDirection.ValueType]]] +) -> UnboundPlan: + def resolve(registry: ExtensionRegistry) -> stp.Plan: + bound_plan = plan if isinstance(plan, stp.Plan) else plan(registry) + ns = infer_plan_schema(bound_plan) + + bound_expressions = [(e, stalg.SortField.SORT_DIRECTION_ASC_NULLS_LAST) if not isinstance(e, tuple) else e for e in expressions] + bound_expressions = [(e[0](ns, registry), e[1]) for e in bound_expressions] + + rel = stalg.Rel( + sort=stalg.SortRel( + input=bound_plan.relations[-1].root.input, + sorts=[ + stalg.SortField( + expr=e[0].referred_expr[0].expression, + direction=e[1], + ) + for e in bound_expressions + ], + ) ) - ) - return stp.Plan( - relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=ns.names))], - **_merge_extensions(plan, *[e[0] for e in expressions]), - ) + return stp.Plan( + relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=ns.names))], + **_merge_extensions(bound_plan, *[e[0] for e in bound_expressions]), + ) + + return resolve -def set(inputs: Iterable[stp.Plan], op: stalg.SetRel.SetOp) -> stp.Plan: - rel = stalg.Rel( - set=stalg.SetRel( - inputs=[plan.relations[-1].root.input for plan in inputs], op=op +def set(inputs: Iterable[Union[stp.Plan, UnboundPlan]], op: stalg.SetRel.SetOp) -> UnboundPlan: + def resolve(registry: ExtensionRegistry) -> stp.Plan: + bound_inputs = [i if isinstance(i, stp.Plan) else i(registry) for i in inputs] + rel = stalg.Rel( + set=stalg.SetRel( + inputs=[plan.relations[-1].root.input for plan in bound_inputs], op=op + ) ) - ) - return stp.Plan( - relations=[ - stp.PlanRel( - root=stalg.RelRoot(input=rel, names=inputs[0].relations[-1].root.names) - ) - ], - **_merge_extensions(*inputs), - ) + return stp.Plan( + relations=[ + stp.PlanRel( + root=stalg.RelRoot(input=rel, names=bound_inputs[0].relations[-1].root.names) + ) + ], + **_merge_extensions(*bound_inputs), + ) + + return resolve def fetch(plan: stp.Plan, offset: UnboundExtendedExpression, - count: UnboundExtendedExpression, - registry: ExtensionRegistry): - ns = infer_plan_schema(plan) - - bound_offset = offset(ns, registry) - bound_count = count(ns, registry) - - rel = stalg.Rel( - fetch=stalg.FetchRel( - input=plan.relations[-1].root.input, - offset_expr=bound_offset.referred_expr[0].expression, - count_expr=bound_count.referred_expr[0].expression + count: UnboundExtendedExpression) -> UnboundPlan: + def resolve(registry: ExtensionRegistry) -> stp.Plan: + bound_plan = plan if isinstance(plan, stp.Plan) else plan(registry) + ns = infer_plan_schema(bound_plan) + + bound_offset = offset(ns, registry) + bound_count = count(ns, registry) + + rel = stalg.Rel( + fetch=stalg.FetchRel( + input=bound_plan.relations[-1].root.input, + offset_expr=bound_offset.referred_expr[0].expression, + count_expr=bound_count.referred_expr[0].expression + ) ) - ) - return stp.Plan( - relations=[ - stp.PlanRel( - root=stalg.RelRoot(input=rel, names=plan.relations[-1].root.names) - ) - ], - **_merge_extensions(plan, bound_offset, bound_count), - ) + return stp.Plan( + relations=[ + stp.PlanRel( + root=stalg.RelRoot(input=rel, names=bound_plan.relations[-1].root.names) + ) + ], + **_merge_extensions(bound_plan, bound_offset, bound_count), + ) + + return resolve def join( @@ -162,99 +184,111 @@ def join( right: stp.Plan, expression: UnboundExtendedExpression, type: stalg.JoinRel.JoinType, - registry: ExtensionRegistry, -): - left_ns = infer_plan_schema(left) - right_ns = infer_plan_schema(right) - ns = stt.NamedStruct( - struct=stt.Type.Struct( - types=list(left_ns.struct.types) + list(right_ns.struct.types), - nullability=stt.Type.Nullability.NULLABILITY_REQUIRED, - ), - names=list(left_ns.names) + list(right_ns.names), - ) - expression: stee.ExtendedExpression = expression(ns, registry) - - rel = stalg.Rel( - join=stalg.JoinRel( - left=left.relations[-1].root.input, - right=right.relations[-1].root.input, - expression=expression.referred_expr[0].expression, - type=type, +) -> UnboundPlan: + def resolve(registry: ExtensionRegistry) -> stp.Plan: + bound_left = left if isinstance(left, stp.Plan) else left(registry) + bound_right = right if isinstance(right, stp.Plan) else right(registry) + left_ns = infer_plan_schema(bound_left) + right_ns = infer_plan_schema(bound_right) + + ns = stt.NamedStruct( + struct=stt.Type.Struct( + types=list(left_ns.struct.types) + list(right_ns.struct.types), + nullability=stt.Type.Nullability.NULLABILITY_REQUIRED, + ), + names=list(left_ns.names) + list(right_ns.names), + ) + bound_expression: stee.ExtendedExpression = expression(ns, registry) + + rel = stalg.Rel( + join=stalg.JoinRel( + left=bound_left.relations[-1].root.input, + right=bound_right.relations[-1].root.input, + expression=bound_expression.referred_expr[0].expression, + type=type, + ) ) - ) - return stp.Plan( - relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=ns.names))], - **_merge_extensions(left, right, expression), - ) + return stp.Plan( + relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=ns.names))], + **_merge_extensions(bound_left, bound_right, bound_expression), + ) + + return resolve def cross( - left: stp.Plan, - right: stp.Plan, - registry: ExtensionRegistry, -): - left_ns = infer_plan_schema(left) - right_ns = infer_plan_schema(right) - - ns = stt.NamedStruct( - struct=stt.Type.Struct( - types=list(left_ns.struct.types) + list(right_ns.struct.types), - nullability=stt.Type.Nullability.NULLABILITY_REQUIRED, - ), - names=list(left_ns.names) + list(right_ns.names), - ) - - rel = stalg.Rel( - cross=stalg.CrossRel( - left=left.relations[-1].root.input, - right=right.relations[-1].root.input - ) - ) + left: Union[stp.Plan, UnboundPlan], + right: Union[stp.Plan, UnboundPlan], +) -> UnboundPlan: + def resolve(registry: ExtensionRegistry) -> stp.Plan: + bound_left = left if isinstance(left, stp.Plan) else left(registry) + bound_right = right if isinstance(right, stp.Plan) else right(registry) + left_ns = infer_plan_schema(bound_left) + right_ns = infer_plan_schema(bound_right) + + ns = stt.NamedStruct( + struct=stt.Type.Struct( + types=list(left_ns.struct.types) + list(right_ns.struct.types), + nullability=stt.Type.Nullability.NULLABILITY_REQUIRED, + ), + names=list(left_ns.names) + list(right_ns.names), + ) - return stp.Plan( - relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=ns.names))], - **_merge_extensions(left, right), - ) + rel = stalg.Rel( + cross=stalg.CrossRel( + left=bound_left.relations[-1].root.input, + right=bound_right.relations[-1].root.input + ) + ) + + return stp.Plan( + relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=ns.names))], + **_merge_extensions(bound_left, bound_right), + ) + + return resolve # TODO grouping sets def aggregate( input: stp.Plan, grouping_expressions: Iterable[UnboundExtendedExpression], measures: Iterable[UnboundExtendedExpression], - registry: ExtensionRegistry, -): - ns = infer_plan_schema(input) - - grouping_expressions = [e(ns, registry) for e in grouping_expressions] - measures = [e(ns, registry) for e in measures] - - rel = stalg.Rel( - aggregate=stalg.AggregateRel( - input=input.relations[-1].root.input, - grouping_expressions=[ - e.referred_expr[0].expression for e in grouping_expressions - ], - groupings=[ - stalg.AggregateRel.Grouping( - expression_references=range(len(grouping_expressions)), - grouping_expressions=[ - e.referred_expr[0].expression for e in grouping_expressions - ], - ) - ], - measures=[ - stalg.AggregateRel.Measure(measure=m.referred_expr[0].measure) - for m in measures - ], +) -> UnboundPlan: + def resolve(registry: ExtensionRegistry) -> stp.Plan: + bound_input = input if isinstance(input, stp.Plan) else input(registry) + ns = infer_plan_schema(bound_input) + + bound_grouping_expressions = [e(ns, registry) for e in grouping_expressions] + bound_measures = [e(ns, registry) for e in measures] + + rel = stalg.Rel( + aggregate=stalg.AggregateRel( + input=bound_input.relations[-1].root.input, + grouping_expressions=[ + e.referred_expr[0].expression for e in bound_grouping_expressions + ], + groupings=[ + stalg.AggregateRel.Grouping( + expression_references=range(len(bound_grouping_expressions)), + grouping_expressions=[ + e.referred_expr[0].expression for e in bound_grouping_expressions + ], + ) + ], + measures=[ + stalg.AggregateRel.Measure(measure=m.referred_expr[0].measure) + for m in bound_measures + ], + ) ) - ) - names = [e.referred_expr[0].output_names[0] for e in grouping_expressions] + [ - e.referred_expr[0].output_names[0] for e in measures - ] + names = [e.referred_expr[0].output_names[0] for e in bound_grouping_expressions] + [ + e.referred_expr[0].output_names[0] for e in bound_measures + ] - return stp.Plan( - relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=names))], - **_merge_extensions(input, *grouping_expressions, *measures), - ) + return stp.Plan( + relations=[stp.PlanRel(root=stalg.RelRoot(input=rel, names=names))], + **_merge_extensions(bound_input, *bound_grouping_expressions, *bound_measures), + ) + + return resolve diff --git a/src/substrait/builders/type.py b/src/substrait/builders/type.py index fd6bd9b..a3ba138 100644 --- a/src/substrait/builders/type.py +++ b/src/substrait/builders/type.py @@ -71,3 +71,6 @@ def list(type: stt.Type, nullable=True): def map(key: stt.Type, value: stt.Type, nullable=True): return stt.Type(map=stt.Type.Map(key=key, value=value, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + +def named_struct(names: Iterable[str], struct: stt.Type): + return stt.NamedStruct(names=names, struct=struct.struct) diff --git a/tests/builders/plan/test_aggregate.py b/tests/builders/plan/test_aggregate.py index 50508da..175519e 100644 --- a/tests/builders/plan/test_aggregate.py +++ b/tests/builders/plan/test_aggregate.py @@ -42,10 +42,9 @@ def test_aggregate(): actual = aggregate(table, grouping_expressions=[group_expr], - measures=[measure_expr], - registry=registry) + measures=[measure_expr])(registry) - ns = infer_plan_schema(table) + ns = infer_plan_schema(table(None)) expected = stp.Plan( extension_uris=[ @@ -68,7 +67,7 @@ def test_aggregate(): root=stalg.RelRoot( input=stalg.Rel( aggregate=stalg.AggregateRel( - input=table.relations[-1].root.input, + input=table(None).relations[-1].root.input, grouping_expressions=[ group_expr(ns, registry).referred_expr[0].expression ], @@ -94,7 +93,4 @@ def test_aggregate(): ] ) - print(actual) - print(expected) - assert actual == expected \ No newline at end of file diff --git a/tests/builders/plan/test_cross.py b/tests/builders/plan/test_cross.py index 37e4487..42cfb05 100644 --- a/tests/builders/plan/test_cross.py +++ b/tests/builders/plan/test_cross.py @@ -18,11 +18,11 @@ names=["fk_id", "name"], struct=stt.Type.Struct(types=[i64(nullable=False), string()]) ) -def test_join(): +def test_cross_join(): table = read_named_table('table', named_struct) table2 = read_named_table('table2', named_struct_2) - actual = cross(table, table2, registry=registry) + actual = cross(table, table2)(registry) expected = stp.Plan( relations=[ @@ -30,8 +30,8 @@ def test_join(): root=stalg.RelRoot( input=stalg.Rel( cross=stalg.CrossRel( - left=table.relations[-1].root.input, - right=table2.relations[-1].root.input, + left=table(None).relations[-1].root.input, + right=table2(None).relations[-1].root.input, ) ), names=['id', 'is_applicable', 'fk_id', 'name'] diff --git a/tests/builders/plan/test_fetch.py b/tests/builders/plan/test_fetch.py index 27a64ee..f53a9b9 100644 --- a/tests/builders/plan/test_fetch.py +++ b/tests/builders/plan/test_fetch.py @@ -21,7 +21,7 @@ def test_fetch(): offset = literal(10, i64()) count = literal(5, i64()) - actual = fetch(table, offset=offset, count=count, registry=registry) + actual = fetch(table, offset=offset, count=count)(registry) expected = stp.Plan( relations=[ @@ -29,7 +29,7 @@ def test_fetch(): root=stalg.RelRoot( input=stalg.Rel( fetch=stalg.FetchRel( - input=table.relations[-1].root.input, + input=table(None).relations[-1].root.input, offset_expr=offset(None, None).referred_expr[0].expression, count_expr=count(None, None).referred_expr[0].expression ) diff --git a/tests/builders/plan/test_filter.py b/tests/builders/plan/test_filter.py index 9375878..e40ed22 100644 --- a/tests/builders/plan/test_filter.py +++ b/tests/builders/plan/test_filter.py @@ -17,7 +17,7 @@ def test_filter(): table = read_named_table('table', named_struct) - actual = filter(table, literal(True, boolean()), registry) + actual = filter(table, literal(True, boolean()))(registry) expected = stp.Plan( relations=[ @@ -25,7 +25,7 @@ def test_filter(): root=stalg.RelRoot( input=stalg.Rel( filter=stalg.FilterRel( - input=table.relations[-1].root.input, + input=table(None).relations[-1].root.input, condition=stalg.Expression( literal=stalg.Expression.Literal( boolean=True, diff --git a/tests/builders/plan/test_join.py b/tests/builders/plan/test_join.py index fd4cf4f..42ad30a 100644 --- a/tests/builders/plan/test_join.py +++ b/tests/builders/plan/test_join.py @@ -22,7 +22,7 @@ def test_join(): table = read_named_table('table', named_struct) table2 = read_named_table('table2', named_struct_2) - actual = join(table, table2, literal(True, boolean()), stalg.JoinRel.JOIN_TYPE_INNER, registry=registry) + actual = join(table, table2, literal(True, boolean()), stalg.JoinRel.JOIN_TYPE_INNER)(registry) expected = stp.Plan( relations=[ @@ -30,8 +30,8 @@ def test_join(): root=stalg.RelRoot( input=stalg.Rel( join=stalg.JoinRel( - left=table.relations[-1].root.input, - right=table2.relations[-1].root.input, + left=table(None).relations[-1].root.input, + right=table2(None).relations[-1].root.input, expression=literal(True, boolean())(None, None).referred_expr[0].expression, type=stalg.JoinRel.JOIN_TYPE_INNER ) diff --git a/tests/builders/plan/test_project.py b/tests/builders/plan/test_project.py index 4e8cbff..9535a32 100644 --- a/tests/builders/plan/test_project.py +++ b/tests/builders/plan/test_project.py @@ -17,7 +17,7 @@ def test_project(): table = read_named_table('table', named_struct) - actual = project(table, [column('id')], registry=registry) + actual = project(table, [column('id')])(registry) expected = stp.Plan( relations=[ @@ -26,7 +26,7 @@ def test_project(): input=stalg.Rel( project=stalg.ProjectRel( common=stalg.RelCommon(emit=stalg.RelCommon.Emit(output_mapping=[2])), - input=table.relations[-1].root.input, + input=table(None).relations[-1].root.input, expressions=[ stalg.Expression( selection=stalg.Expression.FieldReference( diff --git a/tests/builders/plan/test_read.py b/tests/builders/plan/test_read.py index 73cb1aa..7e8fdf7 100644 --- a/tests/builders/plan/test_read.py +++ b/tests/builders/plan/test_read.py @@ -11,7 +11,7 @@ ) def test_read_rel(): - actual = read_named_table('example_table', named_struct) + actual = read_named_table('example_table', named_struct)(None) expected = stp.Plan( relations=[ @@ -33,7 +33,7 @@ def test_read_rel(): assert actual == expected def test_read_rel_db(): - actual = read_named_table(['example_db', 'example_table'], named_struct) + actual = read_named_table(['example_db', 'example_table'], named_struct)(None) expected = stp.Plan( relations=[ diff --git a/tests/builders/plan/test_set.py b/tests/builders/plan/test_set.py index 290996b..c701024 100644 --- a/tests/builders/plan/test_set.py +++ b/tests/builders/plan/test_set.py @@ -19,7 +19,7 @@ def test_set(): table = read_named_table('table', named_struct) table2 = read_named_table('table2', named_struct) - actual = set([table, table2], stalg.SetRel.SET_OP_UNION_ALL) + actual = set([table, table2], stalg.SetRel.SET_OP_UNION_ALL)(None) expected = stp.Plan( @@ -29,8 +29,8 @@ def test_set(): input=stalg.Rel( set=stalg.SetRel( inputs=[ - table.relations[-1].root.input, - table2.relations[-1].root.input, + table(None).relations[-1].root.input, + table2(None).relations[-1].root.input, ], op=stalg.SetRel.SET_OP_UNION_ALL ) diff --git a/tests/builders/plan/test_sort.py b/tests/builders/plan/test_sort.py index 738da03..66cfbd5 100644 --- a/tests/builders/plan/test_sort.py +++ b/tests/builders/plan/test_sort.py @@ -20,7 +20,7 @@ def test_sort_no_direction(): col = column('id') - actual = sort(table, expressions=[col], registry=registry) + actual = sort(table, expressions=[col])(registry) expected = stp.Plan( relations=[ @@ -28,11 +28,11 @@ def test_sort_no_direction(): root=stalg.RelRoot( input=stalg.Rel( sort=stalg.SortRel( - input=table.relations[-1].root.input, + input=table(None).relations[-1].root.input, sorts=[ stalg.SortField( direction=stalg.SortField.SORT_DIRECTION_ASC_NULLS_LAST, - expr=col(infer_plan_schema(table), registry).referred_expr[0].expression + expr=col(infer_plan_schema(table(None)), registry).referred_expr[0].expression ) ] ) @@ -50,7 +50,7 @@ def test_sort_direction(): col = column('id') - actual = sort(table, expressions=[(col, stalg.SortField.SORT_DIRECTION_DESC_NULLS_FIRST)], registry=registry) + actual = sort(table, expressions=[(col, stalg.SortField.SORT_DIRECTION_DESC_NULLS_FIRST)])(registry) expected = stp.Plan( relations=[ @@ -58,11 +58,11 @@ def test_sort_direction(): root=stalg.RelRoot( input=stalg.Rel( sort=stalg.SortRel( - input=table.relations[-1].root.input, + input=table(None).relations[-1].root.input, sorts=[ stalg.SortField( direction=stalg.SortField.SORT_DIRECTION_DESC_NULLS_FIRST, - expr=col(infer_plan_schema(table), registry).referred_expr[0].expression + expr=col(infer_plan_schema(table(None)), registry).referred_expr[0].expression ) ] ) From 7aab1ec25f62328722b4052c3f95d6deb90c89e2 Mon Sep 17 00:00:00 2001 From: tokoko Date: Tue, 29 Apr 2025 16:45:53 +0000 Subject: [PATCH 4/5] fix: plan builder type annotations --- src/substrait/builders/plan.py | 27 ++++++++++++------- src/substrait/builders/type.py | 48 +++++++++++++++++----------------- 2 files changed, 41 insertions(+), 34 deletions(-) diff --git a/src/substrait/builders/plan.py b/src/substrait/builders/plan.py index 7b348d5..16b879e 100644 --- a/src/substrait/builders/plan.py +++ b/src/substrait/builders/plan.py @@ -1,3 +1,8 @@ +""" +Plan builders take either Plan or UnboundPlan objects as input rather than plain Rels. +This is to make sure that additional information like extension types of functions are not lost. +""" + from typing import Iterable, Union, Callable import substrait.gen.proto.algebra_pb2 as stalg @@ -11,6 +16,8 @@ UnboundPlan = Callable[[ExtensionRegistry], stp.Plan] +PlanOrUnbound = Union[stp.Plan, UnboundPlan] + def _merge_extensions(*objs): return { "extension_uris": merge_extension_uris(*[b.extension_uris for b in objs]), @@ -39,7 +46,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan: def project( - plan: Union[stp.Plan, UnboundPlan], expressions: Iterable[UnboundExtendedExpression] + plan: PlanOrUnbound, expressions: Iterable[UnboundExtendedExpression] ) -> UnboundPlan: def resolve(registry: ExtensionRegistry) -> stp.Plan: _plan = plan if isinstance(plan, stp.Plan) else plan(registry) @@ -73,7 +80,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan: def filter( - plan: Union[stp.Plan, UnboundPlan], expression: UnboundExtendedExpression + plan: PlanOrUnbound, expression: UnboundExtendedExpression ) -> UnboundPlan: def resolve(registry: ExtensionRegistry) -> stp.Plan: bound_plan = plan if isinstance(plan, stp.Plan) else plan(registry) @@ -98,7 +105,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan: def sort( - plan: stp.Plan, + plan: PlanOrUnbound, expressions: Iterable[Union[UnboundExtendedExpression, tuple[UnboundExtendedExpression, stalg.SortField.SortDirection.ValueType]]] ) -> UnboundPlan: def resolve(registry: ExtensionRegistry) -> stp.Plan: @@ -129,7 +136,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan: return resolve -def set(inputs: Iterable[Union[stp.Plan, UnboundPlan]], op: stalg.SetRel.SetOp) -> UnboundPlan: +def set(inputs: Iterable[PlanOrUnbound], op: stalg.SetRel.SetOp) -> UnboundPlan: def resolve(registry: ExtensionRegistry) -> stp.Plan: bound_inputs = [i if isinstance(i, stp.Plan) else i(registry) for i in inputs] rel = stalg.Rel( @@ -149,7 +156,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan: return resolve -def fetch(plan: stp.Plan, +def fetch(plan: PlanOrUnbound, offset: UnboundExtendedExpression, count: UnboundExtendedExpression) -> UnboundPlan: def resolve(registry: ExtensionRegistry) -> stp.Plan: @@ -180,8 +187,8 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan: def join( - left: stp.Plan, - right: stp.Plan, + left: PlanOrUnbound, + right: PlanOrUnbound, expression: UnboundExtendedExpression, type: stalg.JoinRel.JoinType, ) -> UnboundPlan: @@ -217,8 +224,8 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan: return resolve def cross( - left: Union[stp.Plan, UnboundPlan], - right: Union[stp.Plan, UnboundPlan], + left: PlanOrUnbound, + right: PlanOrUnbound, ) -> UnboundPlan: def resolve(registry: ExtensionRegistry) -> stp.Plan: bound_left = left if isinstance(left, stp.Plan) else left(registry) @@ -250,7 +257,7 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan: # TODO grouping sets def aggregate( - input: stp.Plan, + input: PlanOrUnbound, grouping_expressions: Iterable[UnboundExtendedExpression], measures: Iterable[UnboundExtendedExpression], ) -> UnboundPlan: diff --git a/src/substrait/builders/type.py b/src/substrait/builders/type.py index a3ba138..d7468f0 100644 --- a/src/substrait/builders/type.py +++ b/src/substrait/builders/type.py @@ -1,76 +1,76 @@ from typing import Iterable import substrait.gen.proto.type_pb2 as stt -def boolean(nullable=True): +def boolean(nullable=True) -> stt.Type: return stt.Type(bool=stt.Type.Boolean(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) -def i8(nullable=True): +def i8(nullable=True) -> stt.Type: return stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) -def i16(nullable=True): +def i16(nullable=True) -> stt.Type: return stt.Type(i16=stt.Type.I16(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) -def i32(nullable=True): +def i32(nullable=True) -> stt.Type: return stt.Type(i32=stt.Type.I32(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) -def i64(nullable=True): +def i64(nullable=True) -> stt.Type: return stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) -def fp32(nullable=True): +def fp32(nullable=True) -> stt.Type: return stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) -def fp64(nullable=True): +def fp64(nullable=True) -> stt.Type: return stt.Type(fp64=stt.Type.FP64(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) -def string(nullable=True): +def string(nullable=True) -> stt.Type: return stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) -def binary(nullable=True): +def binary(nullable=True) -> stt.Type: return stt.Type(binary=stt.Type.Binary(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) -def date(nullable=True): +def date(nullable=True) -> stt.Type: return stt.Type(date=stt.Type.Date(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) -def interval_year(nullable=True): +def interval_year(nullable=True) -> stt.Type: return stt.Type(interval_year=stt.Type.IntervalYear(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) -def interval_day(precision: int, nullable=True): +def interval_day(precision: int, nullable=True) -> stt.Type: return stt.Type(interval_day=stt.Type.IntervalDay(precision=precision, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) -def interval_compound(precision: int, nullable=True): +def interval_compound(precision: int, nullable=True) -> stt.Type: return stt.Type(interval_compound=stt.Type.IntervalCompound(precision=precision, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) -def uuid(nullable=True): +def uuid(nullable=True) -> stt.Type: return stt.Type(uuid=stt.Type.UUID(nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) -def fixed_char(length: int, nullable=True): +def fixed_char(length: int, nullable=True) -> stt.Type: return stt.Type(fixed_char=stt.Type.FixedChar(length=length, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) -def var_char(length: int, nullable=True): +def var_char(length: int, nullable=True) -> stt.Type: return stt.Type(var_char=stt.Type.VarChar(length=length, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) -def fixed_binary(length: int, nullable=True): +def fixed_binary(length: int, nullable=True) -> stt.Type: return stt.Type(fixed_binary=stt.Type.FixedBinary(length=length, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) -def decimal(scale: int, precision: int, nullable=True): +def decimal(scale: int, precision: int, nullable=True) -> stt.Type: return stt.Type(decimal=stt.Type.Decimal(scale=scale, precision=precision, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) # PrecisionTime -def precision_timestamp(precision: int, nullable=True): +def precision_timestamp(precision: int, nullable=True) -> stt.Type: return stt.Type(precision_timestamp=stt.Type.PrecisionTimestamp(precision=precision, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) -def precision_timestamp_tz(precision: int, nullable=True): +def precision_timestamp_tz(precision: int, nullable=True) -> stt.Type: return stt.Type(precision_timestamp_tz=stt.Type.PrecisionTimestampTZ(precision=precision, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) -def struct(types: Iterable[stt.Type], nullable=True): +def struct(types: Iterable[stt.Type], nullable=True) -> stt.Type: return stt.Type(struct=stt.Type.Struct(types=types, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) -def list(type: stt.Type, nullable=True): +def list(type: stt.Type, nullable=True) -> stt.Type: return stt.Type(list=stt.Type.List(type=type, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) -def map(key: stt.Type, value: stt.Type, nullable=True): +def map(key: stt.Type, value: stt.Type, nullable=True) -> stt.Type: return stt.Type(map=stt.Type.Map(key=key, value=value, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) -def named_struct(names: Iterable[str], struct: stt.Type): +def named_struct(names: Iterable[str], struct: stt.Type) -> stt.NamedStruct: return stt.NamedStruct(names=names, struct=struct.struct) From 5dc7c977957ba754f0309b75d79ec0178495e8ec Mon Sep 17 00:00:00 2001 From: tokoko Date: Tue, 29 Apr 2025 16:46:34 +0000 Subject: [PATCH 5/5] fix: plan builder type annotations --- src/substrait/builders/plan.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/substrait/builders/plan.py b/src/substrait/builders/plan.py index 16b879e..3075c4b 100644 --- a/src/substrait/builders/plan.py +++ b/src/substrait/builders/plan.py @@ -1,6 +1,8 @@ """ Plan builders take either Plan or UnboundPlan objects as input rather than plain Rels. This is to make sure that additional information like extension types of functions are not lost. +All builders return UnboundPlan objects that can be materialized to a Plan using an ExtensionRegistry. +See `examples/builder_example.py` for usage. """ from typing import Iterable, Union, Callable