diff --git a/src/substrait/builders/extended_expression.py b/src/substrait/builders/extended_expression.py index f757980..fa4b444 100644 --- a/src/substrait/builders/extended_expression.py +++ b/src/substrait/builders/extended_expression.py @@ -1,3 +1,4 @@ +from datetime import date import itertools import substrait.gen.proto.algebra_pb2 as stalg import substrait.gen.proto.type_pb2 as stp @@ -41,6 +42,36 @@ def resolve(base_schema: stp.NamedStruct, registry: ExtensionRegistry) -> stee.E literal = stalg.Expression.Literal(fp64=value, nullable=type.fp64.nullability == stp.Type.NULLABILITY_NULLABLE) elif kind == "string": literal = stalg.Expression.Literal(string=value, nullable=type.string.nullability == stp.Type.NULLABILITY_NULLABLE) + elif kind == "binary": + literal = stalg.Expression.Literal(binary=value, nullable=type.binary.nullability == stp.Type.NULLABILITY_NULLABLE) + elif kind == "date": + date_value = (value - date(1970,1,1)).days if isinstance(value, date) else value + literal = stalg.Expression.Literal(date=date_value, nullable=type.date.nullability == stp.Type.NULLABILITY_NULLABLE) + # TODO + # IntervalYearToMonth interval_year_to_month = 19; + # IntervalDayToSecond interval_day_to_second = 20; + # IntervalCompound interval_compound = 36; + elif kind == "fixed_char": + literal = stalg.Expression.Literal(fixed_char=value, nullable=type.fixed_char.nullability == stp.Type.NULLABILITY_NULLABLE) + elif kind == "varchar": + literal = stalg.Expression.Literal( + var_char=stalg.Expression.Literal.VarChar(value=value, length=type.varchar.length), + nullable=type.varchar.nullability == stp.Type.NULLABILITY_NULLABLE + ) + elif kind == "fixed_binary": + literal = stalg.Expression.Literal(fixed_binary=value, nullable=type.fixed_binary.nullability == stp.Type.NULLABILITY_NULLABLE) + # TODO + # Decimal decimal = 24; + # PrecisionTime precision_time = 37; // Time in precision units past midnight. + # PrecisionTimestamp precision_timestamp = 34; + # PrecisionTimestamp precision_timestamp_tz = 35; + # Struct struct = 25; + # Map map = 26; + # bytes uuid = 28; + # Type null = 29; // a typed null literal + # List list = 30; + # Type.List empty_list = 31; + # Type.Map empty_map = 32; else: raise Exception(f"Unknown literal type - {type}") diff --git a/src/substrait/builders/type.py b/src/substrait/builders/type.py index d7468f0..f20c441 100644 --- a/src/substrait/builders/type.py +++ b/src/substrait/builders/type.py @@ -47,7 +47,7 @@ def fixed_char(length: int, nullable=True) -> stt.Type: return stt.Type(fixed_char=stt.Type.FixedChar(length=length, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) def var_char(length: int, nullable=True) -> stt.Type: - return stt.Type(var_char=stt.Type.VarChar(length=length, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) + return stt.Type(varchar=stt.Type.VarChar(length=length, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) def fixed_binary(length: int, nullable=True) -> stt.Type: return stt.Type(fixed_binary=stt.Type.FixedBinary(length=length, nullability=stt.Type.NULLABILITY_NULLABLE if nullable else stt.Type.NULLABILITY_REQUIRED)) diff --git a/tests/builders/extended_expression/test_literal.py b/tests/builders/extended_expression/test_literal.py new file mode 100644 index 0000000..45f86dd --- /dev/null +++ b/tests/builders/extended_expression/test_literal.py @@ -0,0 +1,36 @@ +from datetime import date +import substrait.gen.proto.algebra_pb2 as stalg +import substrait.gen.proto.type_pb2 as stt +import substrait.gen.proto.extended_expression_pb2 as stee +from substrait.builders.extended_expression import literal +from substrait.builders import type as sttb + +def extract_literal(builder): + return builder(None, None).referred_expr[0].expression.literal + +def test_boolean(): + assert extract_literal(literal(True, sttb.boolean())) == stalg.Expression.Literal(boolean=True, nullable=True) + assert extract_literal(literal(False, sttb.boolean())) == stalg.Expression.Literal(boolean=False, nullable=True) + +def test_integer(): + assert extract_literal(literal(100, sttb.i16())) == stalg.Expression.Literal(i16=100, nullable=True) + +def test_string(): + assert extract_literal(literal("Hello", sttb.string())) == stalg.Expression.Literal(string="Hello", nullable=True) + +def test_binary(): + assert extract_literal(literal(b"Hello", sttb.binary())) == stalg.Expression.Literal(binary=b"Hello", nullable=True) + +def test_date(): + assert extract_literal(literal(1000, sttb.date())) == stalg.Expression.Literal(date=1000, nullable=True) + assert extract_literal(literal(date(1970, 1, 11), sttb.date())) == stalg.Expression.Literal(date=10, nullable=True) + +def test_fixed_char(): + assert extract_literal(literal("Hello", sttb.fixed_char(length=5))) == stalg.Expression.Literal(fixed_char="Hello", nullable=True) + +def test_var_char(): + assert extract_literal(literal("Hello", sttb.var_char(length=5))) \ + == stalg.Expression.Literal(var_char=stalg.Expression.Literal.VarChar(value="Hello", length=5), nullable=True) + +def test_fixed_binary(): + assert extract_literal(literal(b"Hello", sttb.fixed_binary(length=5))) == stalg.Expression.Literal(fixed_binary=b"Hello", nullable=True)