diff --git a/src/substrait/type_inference.py b/src/substrait/type_inference.py new file mode 100644 index 0000000..080af9a --- /dev/null +++ b/src/substrait/type_inference.py @@ -0,0 +1,336 @@ +import substrait.gen.proto.algebra_pb2 as stalg +import substrait.gen.proto.type_pb2 as stt + + +def infer_literal_type(literal: stalg.Expression.Literal) -> stt.Type: + literal_type = literal.WhichOneof("literal_type") + + nullability = ( + stt.Type.Nullability.NULLABILITY_NULLABLE + if literal.nullable + else stt.Type.Nullability.NULLABILITY_REQUIRED + ) + + if literal_type == "boolean": + return stt.Type(bool=stt.Type.Boolean(nullability=nullability)) + elif literal_type == "i8": + return stt.Type(i8=stt.Type.I8(nullability=nullability)) + elif literal_type == "i16": + return stt.Type(i16=stt.Type.I16(nullability=nullability)) + elif literal_type == "i32": + return stt.Type(i32=stt.Type.I32(nullability=nullability)) + elif literal_type == "i64": + return stt.Type(i64=stt.Type.I64(nullability=nullability)) + elif literal_type == "fp32": + return stt.Type(fp32=stt.Type.FP32(nullability=nullability)) + elif literal_type == "fp64": + return stt.Type(fp64=stt.Type.FP64(nullability=nullability)) + elif literal_type == "string": + return stt.Type(string=stt.Type.String(nullability=nullability)) + elif literal_type == "binary": + return stt.Type(binary=stt.Type.Binary(nullability=nullability)) + elif literal_type == "timestamp": + return stt.Type(timestamp=stt.Type.Timestamp(nullability=nullability)) + elif literal_type == "date": + return stt.Type(date=stt.Type.Date(nullability=nullability)) + elif literal_type == "time": + return stt.Type(time=stt.Type.Time(nullability=nullability)) + elif literal_type == "interval_year_to_month": + return stt.Type(interval_year=stt.Type.IntervalYear(nullability=nullability)) + elif literal_type == "interval_day_to_second": + return stt.Type( + interval_day=stt.Type.IntervalDay( + precision=literal.interval_day_to_second.precision, + nullability=nullability, + ) + ) + elif literal_type == "interval_compound": + return stt.Type( + interval_compound=stt.Type.IntervalCompound( + nullability=nullability, + precision=literal.interval_compound.interval_day_to_second.precision, + ) + ) + elif literal_type == "fixed_char": + return stt.Type( + fixed_char=stt.Type.FixedChar( + length=len(literal.fixed_char), nullability=nullability + ) + ) + elif literal_type == "var_char": + return stt.Type( + varchar=stt.Type.VarChar( + length=literal.var_char.length, nullability=nullability + ) + ) + elif literal_type == "fixed_binary": + return stt.Type( + fixed_binary=stt.Type.FixedBinary( + length=len(literal.fixed_binary), nullability=nullability + ) + ) + elif literal_type == "decimal": + return stt.Type( + decimal=stt.Type.Decimal( + scale=literal.decimal.scale, + precision=literal.decimal.precision, + nullability=nullability, + ) + ) + elif literal_type == "precision_timestamp": + return stt.Type( + precision_timestamp=stt.Type.PrecisionTimestamp( + precision=literal.precision_timestamp.precision, nullability=nullability + ) + ) + elif literal_type == "precision_timestamp_tz": + return stt.Type( + precision_timestamp_tz=stt.Type.PrecisionTimestampTZ( + precision=literal.precision_timestamp_tz.precision, + nullability=nullability, + ) + ) + elif literal_type == "struct": + return stt.Type( + struct=stt.Type.Struct( + types=[infer_literal_type(f) for f in literal.struct.fields], + nullability=nullability, + ) + ) + elif literal_type == "map": + return stt.Type( + map=stt.Type.Map( + key=infer_literal_type(literal.map.key_values[0].key), + value=infer_literal_type(literal.map.key_values[0].value), + nullability=nullability, + ) + ) + elif literal_type == "timestamp_tz": + return stt.Type(timestamp_tz=stt.Type.TimestampTZ(nullability=nullability)) + elif literal_type == "uuid": + return stt.Type(uuid=stt.Type.UUID(nullability=nullability)) + elif literal_type == "null": + return literal.null + elif literal_type == "list": + return stt.Type( + list=stt.Type.List( + type=infer_literal_type(literal.list.values[0]), nullability=nullability + ) + ) + elif literal_type == "empty_list": + return stt.Type(list=literal.empty_list) + elif literal_type == "empty_map": + return stt.Type(map=literal.empty_map) + else: + raise Exception(f"Unknown literal_type {literal_type}") + + +def infer_nested_type(nested: stalg.Expression.Nested) -> stt.Type: + nested_type = nested.WhichOneof("nested_type") + + nullability = ( + stt.Type.Nullability.NULLABILITY_NULLABLE + if nested.nullable + else stt.Type.Nullability.NULLABILITY_REQUIRED + ) + + if nested_type == "struct": + return stt.Type( + struct=stt.Type.Struct( + types=[infer_expression_type(f) for f in nested.struct.fields], + nullability=nullability, + ) + ) + elif nested_type == "list": + return stt.Type( + list=stt.Type.List( + type=infer_expression_type(nested.list.values[0]), + nullability=nullability, + ) + ) + elif nested_type == "map": + return stt.Type( + map=stt.Type.Map( + key=infer_expression_type(nested.map.key_values[0].key), + value=infer_expression_type(nested.map.key_values[0].value), + nullability=nullability, + ) + ) + else: + raise Exception(f"Unknown nested_type {nested_type}") + + +def infer_expression_type( + expression: stalg.Expression, parent_schema: stt.Type.Struct +) -> stt.Type: + rex_type = expression.WhichOneof("rex_type") + if rex_type == "selection": + root_type = expression.selection.WhichOneof("root_type") + assert root_type == "root_reference" + + reference_type = expression.selection.WhichOneof("reference_type") + + if reference_type == "direct_reference": + segment = expression.selection.direct_reference + + segment_reference_type = segment.WhichOneof("reference_type") + + if segment_reference_type == "struct_field": + return parent_schema.types[segment.struct_field.field] + else: + raise Exception(f"Unknown reference_type {reference_type}") + else: + raise Exception(f"Unknown reference_type {reference_type}") + + elif rex_type == "literal": + return infer_literal_type(expression.literal) + elif rex_type == "scalar_function": + return expression.scalar_function.output_type + elif rex_type == "window_function": + return expression.window_function.output_type + elif rex_type == "if_then": + return infer_expression_type(expression.if_then.ifs[0].then) + elif rex_type == "switch_expression": + return infer_expression_type(expression.switch_expression.ifs[0].then) + elif rex_type == "cast": + return expression.cast.type + elif rex_type == "singular_or_list" or rex_type == "multi_or_list": + return stt.Type( + bool=stt.Type.Boolean(nullability=stt.Type.Nullability.NULLABILITY_NULLABLE) + ) + elif rex_type == "nested": + return infer_nested_type(expression.nested) + elif rex_type == "subquery": + subquery_type = expression.subquery.WhichOneof("subquery_type") + + if subquery_type == "scalar": + scalar_rel = infer_rel_schema(expression.subquery.scalar.input) + return scalar_rel.types[0] + elif ( + subquery_type == "in_predicate" + or subquery_type == "set_comparison" + or subquery_type == "set_predicate" + ): + stt.Type.Boolean( + nullability=stt.Type.Nullability.NULLABILITY_NULLABLE + ) # can this be a null? + else: + raise Exception(f"Unknown subquery_type {subquery_type}") + else: + raise Exception(f"Unknown rex_type {rex_type}") + + +def infer_rel_schema(rel: stalg.Rel) -> stt.Type.Struct: + rel_type = rel.WhichOneof("rel_type") + + if rel_type == "read": + (common, struct) = (rel.read.common, rel.read.base_schema.struct) + elif rel_type == "filter": + (common, struct) = (rel.filter.common, infer_rel_schema(rel.filter.input)) + elif rel_type == "fetch": + (common, struct) = (rel.fetch.common, infer_rel_schema(rel.fetch.input)) + elif rel_type == "aggregate": + parent_schema = infer_rel_schema(rel.aggregate.input) + grouping_types = [ + infer_expression_type(g, parent_schema) + for g in rel.aggregate.grouping_expressions + ] + measure_types = [m.measure.output_type for m in rel.aggregate.measures] + + grouping_identifier_types = ( + [] + if len(rel.aggregate.groupings) <= 1 + else [stt.Type(i32=stt.Type.I32(nullability=stt.Type.NULLABILITY_REQUIRED))] + ) + + raw_schema = stt.Type.Struct( + types=grouping_types + measure_types + grouping_identifier_types, + nullability=parent_schema.nullability, + ) + + (common, struct) = (rel.aggregate.common, raw_schema) + elif rel_type == "sort": + (common, struct) = (rel.sort.common, infer_rel_schema(rel.sort.input)) + elif rel_type == "project": + parent_schema = infer_rel_schema(rel.project.input) + expression_types = [ + infer_expression_type(e, parent_schema) for e in rel.project.expressions + ] + raw_schema = stt.Type.Struct( + types=list(parent_schema.types) + expression_types, + nullability=parent_schema.nullability, + ) + + (common, struct) = (rel.project.common, raw_schema) + elif rel_type == "set": + (common, struct) = (rel.fetch.common, infer_rel_schema(rel.set.inputs[0])) + elif rel_type == "cross": + left_schema = infer_rel_schema(rel.cross.left) + right_schema = infer_rel_schema(rel.cross.right) + + raw_schema = stt.Type.Struct( + types=list(left_schema.types) + list(right_schema.types), + nullability=stt.Type.Nullability.NULLABILITY_REQUIRED, + ) + + (common, struct) = (rel.cross.common, raw_schema) + elif rel_type == "join": + if rel.join.type in [ + stalg.JoinRel.JOIN_TYPE_INNER, + stalg.JoinRel.JOIN_TYPE_OUTER, + stalg.JoinRel.JOIN_TYPE_LEFT, + stalg.JoinRel.JOIN_TYPE_RIGHT, + stalg.JoinRel.JOIN_TYPE_LEFT_SINGLE, + stalg.JoinRel.JOIN_TYPE_RIGHT_SINGLE, + ]: + raw_schema = stt.Type.Struct( + types=list(infer_rel_schema(rel.join.left).types) + + list(infer_rel_schema(rel.join.right).types), + nullability=stt.Type.Nullability.NULLABILITY_REQUIRED, + ) + elif rel.join.type in [ + stalg.JoinRel.JOIN_TYPE_LEFT_ANTI, + stalg.JoinRel.JOIN_TYPE_LEFT_SEMI, + ]: + raw_schema = stt.Type.Struct( + types=infer_rel_schema(rel.join.left).types, + nullability=stt.Type.Nullability.NULLABILITY_REQUIRED, + ) + elif rel.join.type in [ + stalg.JoinRel.JOIN_TYPE_RIGHT_ANTI, + stalg.JoinRel.JOIN_TYPE_RIGHT_SEMI, + ]: + raw_schema = stt.Type.Struct( + types=infer_rel_schema(rel.join.right).types, + nullability=stt.Type.Nullability.NULLABILITY_REQUIRED, + ) + elif rel.join.type in [ + stalg.JoinRel.JOIN_TYPE_LEFT_MARK, + stalg.JoinRel.JOIN_TYPE_RIGHT_MARK, + ]: + raw_schema = stt.Type.Struct( + types=list(infer_rel_schema(rel.join.left).types) + + list(infer_rel_schema(rel.join.right).types) + + [ + stt.Type( + bool=stt.Type.Boolean(nullability=stt.Type.NULLABILITY_NULLABLE) + ) + ], + nullability=stt.Type.Nullability.NULLABILITY_REQUIRED, + ) + else: + raise Exception(f"Unhandled join_type {rel.join.type}") + + (common, struct) = (rel.join.common, raw_schema) + else: + raise Exception(f"Unhandled rel_type {rel_type}") + + emit_kind = common.WhichOneof("emit_kind") or "direct" + + if emit_kind == "direct": + return struct + else: + return stt.Type.Struct( + types=[struct.types[i] for i in common.emit.output_mapping], + nullability=struct.nullability, + ) diff --git a/tests/test_literal_type_inference.py b/tests/test_literal_type_inference.py new file mode 100644 index 0000000..07274a2 --- /dev/null +++ b/tests/test_literal_type_inference.py @@ -0,0 +1,313 @@ +import pytest +import substrait.gen.proto.algebra_pb2 as stalg +import substrait.gen.proto.type_pb2 as stt +from substrait.type_inference import infer_literal_type + +testcases = [ + ( + stalg.Expression.Literal(boolean=True, nullable=True), + stt.Type(bool=stt.Type.Boolean(nullability=stt.Type.NULLABILITY_NULLABLE)), + ), + ( + stalg.Expression.Literal(i8=100, nullable=True), + stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_NULLABLE)), + ), + ( + stalg.Expression.Literal(i16=100, nullable=True), + stt.Type(i16=stt.Type.I16(nullability=stt.Type.NULLABILITY_NULLABLE)), + ), + ( + stalg.Expression.Literal(i32=100, nullable=True), + stt.Type(i32=stt.Type.I32(nullability=stt.Type.NULLABILITY_NULLABLE)), + ), + ( + stalg.Expression.Literal(i64=100, nullable=True), + stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_NULLABLE)), + ), + ( + stalg.Expression.Literal(fp32=100.5, nullable=True), + stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)), + ), + ( + stalg.Expression.Literal(fp64=100.5, nullable=True), + stt.Type(fp64=stt.Type.FP64(nullability=stt.Type.NULLABILITY_NULLABLE)), + ), + ( + stalg.Expression.Literal(string="substrait", nullable=True), + stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)), + ), + ( + stalg.Expression.Literal(binary=b"\xde", nullable=True), + stt.Type(binary=stt.Type.Binary(nullability=stt.Type.NULLABILITY_NULLABLE)), + ), + ( + stalg.Expression.Literal(timestamp=1000000, nullable=True), + stt.Type( + timestamp=stt.Type.Timestamp(nullability=stt.Type.NULLABILITY_NULLABLE) + ), + ), + ( + stalg.Expression.Literal(date=1000, nullable=True), + stt.Type(date=stt.Type.Date(nullability=stt.Type.NULLABILITY_NULLABLE)), + ), + ( + stalg.Expression.Literal(time=1000, nullable=True), + stt.Type(time=stt.Type.Time(nullability=stt.Type.NULLABILITY_NULLABLE)), + ), + ( + stalg.Expression.Literal( + interval_year_to_month=stalg.Expression.Literal.IntervalYearToMonth( + years=1, months=5 + ), + nullable=True, + ), + stt.Type( + interval_year=stt.Type.IntervalYear( + nullability=stt.Type.NULLABILITY_NULLABLE + ) + ), + ), + ( + stalg.Expression.Literal( + interval_day_to_second=stalg.Expression.Literal.IntervalDayToSecond( + days=1, seconds=100 + ), + nullable=True, + ), + stt.Type( + interval_day=stt.Type.IntervalDay( + precision=0, nullability=stt.Type.NULLABILITY_NULLABLE + ) + ), + ), + ( + stalg.Expression.Literal( + interval_day_to_second=stalg.Expression.Literal.IntervalDayToSecond( + days=1, seconds=100, precision=3, subseconds=10 + ), + nullable=True, + ), + stt.Type( + interval_day=stt.Type.IntervalDay( + precision=3, nullability=stt.Type.NULLABILITY_NULLABLE + ) + ), + ), + ( + stalg.Expression.Literal( + interval_compound=stalg.Expression.Literal.IntervalCompound( + interval_year_to_month=stalg.Expression.Literal.IntervalYearToMonth( + years=1, months=5 + ), + interval_day_to_second=stalg.Expression.Literal.IntervalDayToSecond( + days=1, seconds=100 + ), + ), + nullable=True, + ), + stt.Type( + interval_compound=stt.Type.IntervalCompound( + precision=0, nullability=stt.Type.NULLABILITY_NULLABLE + ) + ), + ), + ( + stalg.Expression.Literal( + fixed_char="substrait", + nullable=True, + ), + stt.Type( + fixed_char=stt.Type.FixedChar( + length=9, nullability=stt.Type.NULLABILITY_NULLABLE + ) + ), + ), + ( + stalg.Expression.Literal( + var_char=stalg.Expression.Literal.VarChar(value="substrait", length=10), + nullable=True, + ), + stt.Type( + varchar=stt.Type.VarChar( + length=10, nullability=stt.Type.NULLABILITY_NULLABLE + ) + ), + ), + ( + stalg.Expression.Literal( + fixed_binary=b"substrait", + nullable=True, + ), + stt.Type( + fixed_binary=stt.Type.FixedBinary( + length=9, nullability=stt.Type.NULLABILITY_NULLABLE + ) + ), + ), + ( + stalg.Expression.Literal( + decimal=stalg.Expression.Literal.Decimal( + value=b"somenumber", precision=10, scale=2 + ), + nullable=True, + ), + stt.Type( + decimal=stt.Type.Decimal( + precision=10, scale=2, nullability=stt.Type.NULLABILITY_NULLABLE + ) + ), + ), + ( + stalg.Expression.Literal( + precision_timestamp=stalg.Expression.Literal.PrecisionTimestamp( + precision=3, value=1000 + ), + nullable=True, + ), + stt.Type( + precision_timestamp=stt.Type.PrecisionTimestamp( + precision=3, nullability=stt.Type.NULLABILITY_NULLABLE + ) + ), + ), + ( + stalg.Expression.Literal( + precision_timestamp_tz=stalg.Expression.Literal.PrecisionTimestamp( + precision=3, value=1000 + ), + nullable=True, + ), + stt.Type( + precision_timestamp_tz=stt.Type.PrecisionTimestampTZ( + precision=3, nullability=stt.Type.NULLABILITY_NULLABLE + ) + ), + ), + ( + stalg.Expression.Literal( + struct=stalg.Expression.Literal.Struct( + fields=[ + stalg.Expression.Literal(boolean=True, nullable=False), + stalg.Expression.Literal(i8=100, nullable=False), + ] + ), + nullable=True, + ), + stt.Type( + struct=stt.Type.Struct( + types=[ + stt.Type( + bool=stt.Type.Boolean(nullability=stt.Type.NULLABILITY_REQUIRED) + ), + stt.Type(i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED)), + ], + nullability=stt.Type.NULLABILITY_NULLABLE, + ) + ), + ), + ( + stalg.Expression.Literal( + map=stalg.Expression.Literal.Map( + key_values=[ + stalg.Expression.Literal.Map.KeyValue( + key=stalg.Expression.Literal(boolean=True, nullable=False), + value=stalg.Expression.Literal(i8=100, nullable=False), + ) + ], + ), + nullable=True, + ), + stt.Type( + map=stt.Type.Map( + key=stt.Type( + bool=stt.Type.Boolean(nullability=stt.Type.NULLABILITY_REQUIRED) + ), + value=stt.Type( + i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED) + ), + nullability=stt.Type.NULLABILITY_NULLABLE, + ) + ), + ), + ( + stalg.Expression.Literal( + uuid=b"uuid", + nullable=True, + ), + stt.Type(uuid=stt.Type.UUID(nullability=stt.Type.NULLABILITY_NULLABLE)), + ), + ( + stalg.Expression.Literal( + null=stt.Type( + bool=stt.Type.Boolean(nullability=stt.Type.NULLABILITY_NULLABLE) + ), + nullable=False, # this should be ignored + ), + stt.Type(bool=stt.Type.Boolean(nullability=stt.Type.NULLABILITY_NULLABLE)), + ), + ( + stalg.Expression.Literal( + list=stalg.Expression.Literal.List( + values=[stalg.Expression.Literal(i8=100, nullable=False)], + ), + nullable=True, + ), + stt.Type( + list=stt.Type.List( + type=stt.Type( + i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED) + ), + nullability=stt.Type.NULLABILITY_NULLABLE, + ) + ), + ), + ( + stalg.Expression.Literal( + empty_list=stt.Type.List( + type=stt.Type( + i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED) + ), + nullability=stt.Type.NULLABILITY_REQUIRED, + ), + nullable=True, + ), + stt.Type( + list=stt.Type.List( + type=stt.Type( + i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED) + ), + nullability=stt.Type.NULLABILITY_REQUIRED, + ) + ), + ), + ( + stalg.Expression.Literal( + empty_map=stt.Type.Map( + key=stt.Type( + bool=stt.Type.Boolean(nullability=stt.Type.NULLABILITY_REQUIRED) + ), + value=stt.Type( + i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED) + ), + nullability=stt.Type.NULLABILITY_NULLABLE, + ), + nullable=False, + ), + stt.Type( + map=stt.Type.Map( + key=stt.Type( + bool=stt.Type.Boolean(nullability=stt.Type.NULLABILITY_REQUIRED) + ), + value=stt.Type( + i8=stt.Type.I8(nullability=stt.Type.NULLABILITY_REQUIRED) + ), + nullability=stt.Type.NULLABILITY_NULLABLE, + ) + ), + ), +] + + +@pytest.mark.parametrize("testcase", testcases) +def test_inference_literal_bool(testcase): + assert infer_literal_type(testcase[0]) == testcase[1] diff --git a/tests/test_type_inference.py b/tests/test_type_inference.py new file mode 100644 index 0000000..d761672 --- /dev/null +++ b/tests/test_type_inference.py @@ -0,0 +1,314 @@ +import substrait.gen.proto.algebra_pb2 as stalg +import substrait.gen.proto.type_pb2 as stt +from substrait.type_inference import infer_rel_schema + + +struct = stt.Type.Struct( + types=[ + stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)), + stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)), + stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)), + ] +) + +named_struct = stt.NamedStruct( + names=["order_id", "description", "order_total"], struct=struct +) + +read_rel = stalg.Rel( + read=stalg.ReadRel( + base_schema=named_struct, named_table=stalg.ReadRel.NamedTable(names=["table"]) + ) +) + +right_struct = stt.Type.Struct( + types=[ + stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)), + stt.Type(bool=stt.Type.Boolean(nullability=stt.Type.NULLABILITY_NULLABLE)), + ] +) + +right_named_struct = stt.NamedStruct( + names=["order_id", "is_refundable"], struct=right_struct +) + +right_read_rel = stalg.Rel( + read=stalg.ReadRel( + base_schema=right_named_struct, + named_table=stalg.ReadRel.NamedTable(names=["table2"]), + ) +) + + +def test_inference_read_named_table(): + assert infer_rel_schema(read_rel) == struct + + +def test_inference_project_emit(): + rel = stalg.Rel( + project=stalg.ProjectRel( + input=read_rel, + common=stalg.RelCommon(emit=stalg.RelCommon.Emit(output_mapping=[0, 2])), + ) + ) + + expected = stt.Type.Struct( + types=[ + stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)), + stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)), + ] + ) + + assert infer_rel_schema(rel) == expected + + +def test_inference_project_literal(): + rel = stalg.Rel( + project=stalg.ProjectRel( + input=read_rel, + expressions=[ + stalg.Expression( + literal=stalg.Expression.Literal(boolean=True, nullable=False) + ) + ], + ) + ) + + expected = stt.Type.Struct( + types=[ + stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)), + stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)), + stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)), + stt.Type(bool=stt.Type.Boolean(nullability=stt.Type.NULLABILITY_REQUIRED)), + ] + ) + + assert infer_rel_schema(rel) == expected + + +def test_inference_project_scalar_function(): + rel = stalg.Rel( + project=stalg.ProjectRel( + input=read_rel, + expressions=[ + stalg.Expression( + scalar_function=stalg.Expression.ScalarFunction( + function_reference=0, + output_type=stt.Type( + bool=stt.Type.Boolean( + nullability=stt.Type.NULLABILITY_REQUIRED + ) + ), + ) + ) + ], + ) + ) + + expected = stt.Type.Struct( + types=[ + stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)), + stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)), + stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)), + stt.Type(bool=stt.Type.Boolean(nullability=stt.Type.NULLABILITY_REQUIRED)), + ] + ) + + assert infer_rel_schema(rel) == expected + + +def test_inference_aggregate(): + rel = stalg.Rel( + aggregate=stalg.AggregateRel( + input=read_rel, + grouping_expressions=[ + stalg.Expression( + selection=stalg.Expression.FieldReference( + root_reference=stalg.Expression.FieldReference.RootReference(), + direct_reference=stalg.Expression.ReferenceSegment( + struct_field=stalg.Expression.ReferenceSegment.StructField( + field=1, + ), + ), + ) + ) + ], + groupings=[stalg.AggregateRel.Grouping(expression_references=[0])], + measures=[ + stalg.AggregateRel.Measure( + measure=stalg.AggregateFunction( + function_reference=0, + output_type=stt.Type( + bool=stt.Type.Boolean( + nullability=stt.Type.NULLABILITY_REQUIRED + ) + ), + ) + ) + ], + ) + ) + + expected = stt.Type.Struct( + types=[ + stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)), + stt.Type(bool=stt.Type.Boolean(nullability=stt.Type.NULLABILITY_REQUIRED)), + ] + ) + + assert infer_rel_schema(rel) == expected + + +def test_inference_aggregate_multiple_groupings(): + rel = stalg.Rel( + aggregate=stalg.AggregateRel( + input=read_rel, + grouping_expressions=[ + stalg.Expression( + selection=stalg.Expression.FieldReference( + root_reference=stalg.Expression.FieldReference.RootReference(), + direct_reference=stalg.Expression.ReferenceSegment( + struct_field=stalg.Expression.ReferenceSegment.StructField( + field=1, + ), + ), + ) + ) + ], + groupings=[ + stalg.AggregateRel.Grouping(expression_references=[]), + stalg.AggregateRel.Grouping(expression_references=[0]), + ], + measures=[ + stalg.AggregateRel.Measure( + measure=stalg.AggregateFunction( + function_reference=0, + output_type=stt.Type( + bool=stt.Type.Boolean( + nullability=stt.Type.NULLABILITY_REQUIRED + ) + ), + ) + ) + ], + ) + ) + + expected = stt.Type.Struct( + types=[ + stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)), + stt.Type(bool=stt.Type.Boolean(nullability=stt.Type.NULLABILITY_REQUIRED)), + stt.Type(i32=stt.Type.I32(nullability=stt.Type.NULLABILITY_REQUIRED)), + ] + ) + + assert infer_rel_schema(rel) == expected + + +def test_inference_cross(): + rel = stalg.Rel(cross=stalg.CrossRel(left=read_rel, right=right_read_rel)) + + expected = stt.Type.Struct( + types=[ + stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)), + stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)), + stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)), + stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)), + stt.Type(bool=stt.Type.Boolean(nullability=stt.Type.NULLABILITY_NULLABLE)), + ], + nullability=stt.Type.Nullability.NULLABILITY_REQUIRED, + ) + + assert infer_rel_schema(rel) == expected + + +def test_inference_join_inner(): + rel = stalg.Rel( + join=stalg.JoinRel( + left=read_rel, + right=right_read_rel, + type=stalg.JoinRel.JOIN_TYPE_INNER, + expression=None, + ) + ) + + expected = stt.Type.Struct( + types=[ + stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)), + stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)), + stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)), + stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)), + stt.Type(bool=stt.Type.Boolean(nullability=stt.Type.NULLABILITY_NULLABLE)), + ], + nullability=stt.Type.Nullability.NULLABILITY_REQUIRED, + ) + + assert infer_rel_schema(rel) == expected + + +def test_inference_join_left_anti(): + rel = stalg.Rel( + join=stalg.JoinRel( + left=read_rel, + right=right_read_rel, + type=stalg.JoinRel.JOIN_TYPE_LEFT_ANTI, + expression=None, + ) + ) + + expected = stt.Type.Struct( + types=[ + stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)), + stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)), + stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)), + ], + nullability=stt.Type.Nullability.NULLABILITY_REQUIRED, + ) + + assert infer_rel_schema(rel) == expected + + +def test_inference_join_right_anti(): + rel = stalg.Rel( + join=stalg.JoinRel( + left=read_rel, + right=right_read_rel, + type=stalg.JoinRel.JOIN_TYPE_RIGHT_ANTI, + expression=None, + ) + ) + + expected = stt.Type.Struct( + types=[ + stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)), + stt.Type(bool=stt.Type.Boolean(nullability=stt.Type.NULLABILITY_NULLABLE)), + ], + nullability=stt.Type.Nullability.NULLABILITY_REQUIRED, + ) + + assert infer_rel_schema(rel) == expected + + +def test_inference_join_left_mark(): + rel = stalg.Rel( + join=stalg.JoinRel( + left=read_rel, + right=right_read_rel, + type=stalg.JoinRel.JOIN_TYPE_LEFT_MARK, + expression=None, + ) + ) + + expected = stt.Type.Struct( + types=[ + stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)), + stt.Type(string=stt.Type.String(nullability=stt.Type.NULLABILITY_NULLABLE)), + stt.Type(fp32=stt.Type.FP32(nullability=stt.Type.NULLABILITY_NULLABLE)), + stt.Type(i64=stt.Type.I64(nullability=stt.Type.NULLABILITY_REQUIRED)), + stt.Type(bool=stt.Type.Boolean(nullability=stt.Type.NULLABILITY_NULLABLE)), + stt.Type(bool=stt.Type.Boolean(nullability=stt.Type.NULLABILITY_NULLABLE)), + ], + nullability=stt.Type.Nullability.NULLABILITY_REQUIRED, + ) + + assert infer_rel_schema(rel) == expected