diff --git a/pyproject.toml b/pyproject.toml index 808aa28..4c4ab62 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,9 +12,9 @@ dynamic = ["version"] write_to = "src/substrait/_version.py" [project.optional-dependencies] -extensions = ["antlr4-python3-runtime"] +extensions = ["antlr4-python3-runtime", "pyyaml"] gen_proto = ["protobuf == 3.20.1", "protoletariat >= 2.0.0"] -test = ["pytest >= 7.0.0", "antlr4-python3-runtime"] +test = ["pytest >= 7.0.0", "antlr4-python3-runtime", "pyyaml"] [tool.pytest.ini_options] pythonpath = "src" @@ -31,7 +31,7 @@ target-version = "py39" # never autoformat upstream or generated code exclude = ["third_party/", "src/substrait/gen"] # do not autofix the following (will still get flagged in lint) -unfixable = [ +lint.unfixable = [ "F401", # unused imports "T201", # print statements ] diff --git a/src/substrait/derivation_expression.py b/src/substrait/derivation_expression.py index 276d518..f5e68c1 100644 --- a/src/substrait/derivation_expression.py +++ b/src/substrait/derivation_expression.py @@ -37,7 +37,6 @@ def _evaluate(x, values: dict): elif type(x) == SubstraitTypeParser.FunctionCallContext: exprs = [_evaluate(e, values) for e in x.expr()] func = x.Identifier().symbol.text - if func == "min": return min(*exprs) elif func == "max": @@ -48,27 +47,39 @@ def _evaluate(x, values: dict): scalar_type = x.scalarType() parametrized_type = x.parameterizedType() if scalar_type: + nullability = ( + Type.NULLABILITY_NULLABLE if x.isnull else Type.NULLABILITY_REQUIRED + ) if isinstance(scalar_type, SubstraitTypeParser.I8Context): - return Type(i8=Type.I8()) + return Type(i8=Type.I8(nullability=nullability)) elif isinstance(scalar_type, SubstraitTypeParser.I16Context): - return Type(i16=Type.I16()) + return Type(i16=Type.I16(nullability=nullability)) elif isinstance(scalar_type, SubstraitTypeParser.I32Context): - return Type(i32=Type.I32()) + return Type(i32=Type.I32(nullability=nullability)) elif isinstance(scalar_type, 
SubstraitTypeParser.I64Context): - return Type(i64=Type.I64()) + return Type(i64=Type.I64(nullability=nullability)) elif isinstance(scalar_type, SubstraitTypeParser.Fp32Context): - return Type(fp32=Type.FP32()) + return Type(fp32=Type.FP32(nullability=nullability)) elif isinstance(scalar_type, SubstraitTypeParser.Fp64Context): - return Type(fp64=Type.FP64()) + return Type(fp64=Type.FP64(nullability=nullability)) elif isinstance(scalar_type, SubstraitTypeParser.BooleanContext): - return Type(bool=Type.Boolean()) + return Type(bool=Type.Boolean(nullability=nullability)) else: raise Exception(f"Unknown scalar type {type(scalar_type)}") elif parametrized_type: if isinstance(parametrized_type, SubstraitTypeParser.DecimalContext): precision = _evaluate(parametrized_type.precision, values) scale = _evaluate(parametrized_type.scale, values) - return Type(decimal=Type.Decimal(precision=precision, scale=scale)) + nullability = ( + Type.NULLABILITY_NULLABLE + if parametrized_type.isnull + else Type.NULLABILITY_REQUIRED + ) + return Type( + decimal=Type.Decimal( + precision=precision, scale=scale, nullability=nullability + ) + ) raise Exception(f"Unknown parametrized type {type(parametrized_type)}") else: raise Exception("either scalar_type or parametrized_type is required") @@ -91,12 +102,18 @@ def _evaluate(x, values: dict): return _evaluate(x.finalType, values) elif type(x) == SubstraitTypeParser.TypeLiteralContext: return _evaluate(x.type_(), values) + elif type(x) == SubstraitTypeParser.NumericLiteralContext: + return int(str(x.Number())) else: raise Exception(f"Unknown token type {type(x)}") -def evaluate(x: str, values: Optional[dict] = None): +def _parse(x: str): lexer = SubstraitTypeLexer(InputStream(x)) stream = CommonTokenStream(lexer) parser = SubstraitTypeParser(stream) - return _evaluate(parser.expr(), values) + return parser.expr() + + +def evaluate(x: str, values: Optional[dict] = None): + return _evaluate(_parse(x), values) diff --git 
a/src/substrait/function_registry.py b/src/substrait/function_registry.py new file mode 100644 index 0000000..101f2d7 --- /dev/null +++ b/src/substrait/function_registry.py @@ -0,0 +1,287 @@ +from substrait.gen.proto.parameterized_types_pb2 import ParameterizedType +from substrait.gen.proto.type_pb2 import Type +from importlib.resources import files as importlib_files +import itertools +from collections import defaultdict +from collections.abc import Mapping +from pathlib import Path +from typing import Any, Optional, Union +from .derivation_expression import evaluate, _evaluate, _parse + +import yaml + + +DEFAULT_URI_PREFIX = "https://github.com/substrait-io/substrait/blob/main/extensions" + + +# mapping from argument types to shortened signature names: https://substrait.io/extensions/#function-signature-compound-names +_normalized_key_names = { + "i8": "i8", + "i16": "i16", + "i32": "i32", + "i64": "i64", + "fp32": "fp32", + "fp64": "fp64", + "string": "str", + "binary": "vbin", + "boolean": "bool", + "timestamp": "ts", + "timestamp_tz": "tstz", + "date": "date", + "time": "time", + "interval_year": "iyear", + "interval_day": "iday", + "interval_compound": "icompound", + "uuid": "uuid", + "fixedchar": "fchar", + "varchar": "vchar", + "fixedbinary": "fbin", + "decimal": "dec", + "precision_time": "pt", + "precision_timestamp": "pts", + "precision_timestamp_tz": "ptstz", + "struct": "struct", + "list": "list", + "map": "map", +} + + +def normalize_substrait_type_names(typ: str) -> str: + # Strip type specifiers + typ = typ.split("<")[0] + # First strip nullability marker + typ = typ.strip("?").lower() + + if typ.startswith("any"): + return "any" + elif typ.startswith("u!"): + return typ + elif typ in _normalized_key_names: + return _normalized_key_names[typ] + else: + raise Exception(f"Unrecognized substrait type {typ}") + + +def violates_integer_option(actual: int, option, parameters: dict): + if isinstance(option, SubstraitTypeParser.NumericLiteralContext): + 
return actual != int(str(option.Number())) + elif isinstance(option, SubstraitTypeParser.NumericParameterNameContext): + parameter_name = str(option.Identifier()) + if parameter_name in parameters and parameters[parameter_name] != actual: + return True + else: + parameters[parameter_name] = actual + else: + raise Exception( + f"Input should be either NumericLiteralContext or NumericParameterNameContext, got {type(option)} instead" + ) + + return False + + +from substrait.gen.antlr.SubstraitTypeParser import SubstraitTypeParser + + +def types_equal(type1: Type, type2: Type, check_nullability=False): + if check_nullability: + return type1 == type2 + else: + x, y = Type(), Type() + x.CopyFrom(type1) + y.CopyFrom(type2) + x.__getattribute__( + x.WhichOneof("kind") + ).nullability = Type.Nullability.NULLABILITY_UNSPECIFIED + y.__getattribute__( + y.WhichOneof("kind") + ).nullability = Type.Nullability.NULLABILITY_UNSPECIFIED + return x == y + + +def covers( + covered: Type, + covering: SubstraitTypeParser.TypeLiteralContext, + parameters: dict, + check_nullability=False, +): + if isinstance(covering, SubstraitTypeParser.TypeParamContext): + parameter_name = str(covering.Identifier()) + + if parameter_name in parameters: + covering = parameters[parameter_name] + + return types_equal(covering, covered, check_nullability) + else: + parameters[parameter_name] = covered + return True + + covering = covering.type_() + scalar_type = covering.scalarType() + if scalar_type: + covering = _evaluate(covering, {}) + return types_equal(covering, covered, check_nullability) + + parameterized_type = covering.parameterizedType() + if parameterized_type: + if isinstance(parameterized_type, SubstraitTypeParser.DecimalContext): + if covered.WhichOneof("kind") != "decimal": + return False + + nullability = ( + Type.NULLABILITY_NULLABLE + if parameterized_type.isnull + else Type.NULLABILITY_REQUIRED + ) + + if ( + check_nullability + and nullability + != 
covered.__getattribute__(covered.WhichOneof("kind")).nullability + ): + return False + + return not ( + violates_integer_option( + covered.decimal.scale, parameterized_type.scale, parameters + ) + or violates_integer_option( + covered.decimal.precision, parameterized_type.precision, parameters + ) + ) + else: + raise Exception(f"Unhandled type {type(parameterized_type)}") + + any_type = covering.anyType() + if any_type: + return True + + +class FunctionEntry: + def __init__( + self, uri: str, name: str, impl: Mapping[str, Any], anchor: int + ) -> None: + self.name = name + self.normalized_inputs: list = [] + self.uri: str = uri + self.anchor = anchor + self.arguments = [] + self.rtn = impl["return"] + self.nullability = impl.get("nullability", "MIRROR") + self.variadic = impl.get("variadic", False) + if input_args := impl.get("args", []): + for val in input_args: + if typ := val.get("value"): + self.arguments.append(_parse(typ)) + self.normalized_inputs.append(normalize_substrait_type_names(typ)) + elif arg_name := val.get("name", None): + self.arguments.append(val.get("options")) + self.normalized_inputs.append("req") + + def __repr__(self) -> str: + return f"{self.name}:{'_'.join(self.normalized_inputs)}" + + def satisfies_signature(self, signature: tuple) -> Optional[str]: + if self.variadic: + min_args_allowed = self.variadic.get("min", 0) + if len(signature) < min_args_allowed: + return None + inputs = [self.arguments[0]] * len(signature) + else: + inputs = self.arguments + if len(inputs) != len(signature): + return None + + zipped_args = list(zip(inputs, signature)) + + parameters = {} + + for x, y in zipped_args: + if type(y) == str: + if y not in x: + return None + else: + if not covers( + y, x, parameters, check_nullability=self.nullability == "DISCRETE" + ): + return None + + output_type = evaluate(self.rtn, parameters) + + if self.nullability == "MIRROR": + sig_contains_nullable = any( + [ + p.__getattribute__(p.WhichOneof("kind")).nullability + == 
Type.NULLABILITY_NULLABLE + for p in signature + if type(p) == Type + ] + ) + output_type.__getattribute__(output_type.WhichOneof("kind")).nullability = ( + Type.NULLABILITY_NULLABLE + if sig_contains_nullable + else Type.NULLABILITY_REQUIRED + ) + + return output_type + + +class FunctionRegistry: + def __init__(self, load_default_extensions=True) -> None: + self._function_mapping: dict = defaultdict(dict) + self._id_generator = itertools.count(1) + + self._uri_aliases = {} + + if load_default_extensions: + for fpath in importlib_files("substrait.extensions").glob( # type: ignore + "functions*.yaml" + ): + uri = f"{DEFAULT_URI_PREFIX}/{fpath.name}" + self._uri_aliases[fpath.name] = uri + self.register_extension_yaml(fpath, uri) + + def register_extension_yaml( + self, + fname: Union[str, Path], + uri: str, + ) -> None: + fname = Path(fname) + with open(fname) as f: # type: ignore + extension_definitions = yaml.safe_load(f) + + self.register_extension_dict(extension_definitions, uri) + + def register_extension_dict(self, definitions: dict, uri: str) -> None: + for named_functions in definitions.values(): + for function in named_functions: + for impl in function.get("impls", []): + func = FunctionEntry( + uri, function["name"], impl, next(self._id_generator) + ) + if ( + func.uri in self._function_mapping + and function["name"] in self._function_mapping[func.uri] + ): + self._function_mapping[func.uri][function["name"]].append(func) + else: + self._function_mapping[func.uri][function["name"]] = [func] + + # TODO add an optional return type check + def lookup_function( + self, uri: str, function_name: str, signature: tuple + ) -> Optional[tuple[FunctionEntry, Type]]: + uri = self._uri_aliases.get(uri, uri) + + if ( + uri not in self._function_mapping + or function_name not in self._function_mapping[uri] + ): + return None + functions = self._function_mapping[uri][function_name] + for f in functions: + assert isinstance(f, FunctionEntry) + rtn = 
f.satisfies_signature(signature) + if rtn is not None: + return (f, rtn) + + return None diff --git a/tests/test_derivation_expression.py b/tests/test_derivation_expression.py index 5df2e2d..4b11b3d 100644 --- a/tests/test_derivation_expression.py +++ b/tests/test_derivation_expression.py @@ -24,29 +24,59 @@ def test_ternary(): def test_multiline(): - assert ( - evaluate( - """temp = min(var, 7) + max(var, 7) + assert evaluate( + """temp = min(var, 7) + max(var, 7) decimal""", - {"var": 5}, + {"var": 5}, + ) == Type( + decimal=Type.Decimal( + precision=13, scale=11, nullability=Type.NULLABILITY_REQUIRED ) - == Type(decimal=Type.Decimal(precision=13, scale=11)) ) def test_simple_data_types(): - assert evaluate("i8") == Type(i8=Type.I8()) - assert evaluate("i16") == Type(i16=Type.I16()) - assert evaluate("i32") == Type(i32=Type.I32()) - assert evaluate("i64") == Type(i64=Type.I64()) - assert evaluate("fp32") == Type(fp32=Type.FP32()) - assert evaluate("fp64") == Type(fp64=Type.FP64()) - assert evaluate("boolean") == Type(bool=Type.Boolean()) + assert evaluate("i8") == Type(i8=Type.I8(nullability=Type.NULLABILITY_REQUIRED)) + assert evaluate("i16") == Type(i16=Type.I16(nullability=Type.NULLABILITY_REQUIRED)) + assert evaluate("i32") == Type(i32=Type.I32(nullability=Type.NULLABILITY_REQUIRED)) + assert evaluate("i64") == Type(i64=Type.I64(nullability=Type.NULLABILITY_REQUIRED)) + assert evaluate("fp32") == Type( + fp32=Type.FP32(nullability=Type.NULLABILITY_REQUIRED) + ) + assert evaluate("fp64") == Type( + fp64=Type.FP64(nullability=Type.NULLABILITY_REQUIRED) + ) + assert evaluate("boolean") == Type( + bool=Type.Boolean(nullability=Type.NULLABILITY_REQUIRED) + ) + assert evaluate("i8?") == Type(i8=Type.I8(nullability=Type.NULLABILITY_NULLABLE)) + assert evaluate("i16?") == Type(i16=Type.I16(nullability=Type.NULLABILITY_NULLABLE)) + assert evaluate("i32?") == Type(i32=Type.I32(nullability=Type.NULLABILITY_NULLABLE)) + assert evaluate("i64?") == 
Type(i64=Type.I64(nullability=Type.NULLABILITY_NULLABLE)) + assert evaluate("fp32?") == Type( + fp32=Type.FP32(nullability=Type.NULLABILITY_NULLABLE) + ) + assert evaluate("fp64?") == Type( + fp64=Type.FP64(nullability=Type.NULLABILITY_NULLABLE) + ) + assert evaluate("boolean?") == Type( + bool=Type.Boolean(nullability=Type.NULLABILITY_NULLABLE) + ) def test_data_type(): assert evaluate("decimal<P + 1, S + 1>", {"S": 10, "P": 20}) == Type( - decimal=Type.Decimal(precision=21, scale=11) + decimal=Type.Decimal( + precision=21, scale=11, nullability=Type.NULLABILITY_REQUIRED + ) + ) + + +def test_data_type_nullable(): + assert evaluate("decimal?<P + 1, S + 1>", {"S": 10, "P": 20}) == Type( + decimal=Type.Decimal( + precision=21, scale=11, nullability=Type.NULLABILITY_NULLABLE + ) ) @@ -59,7 +89,11 @@ def func(P1, S1, P2, S2): prec = min(init_prec, 38) scale_after_borrow = max(init_scale - delta, min_scale) scale = scale_after_borrow if init_prec > 38 else init_scale - return Type(decimal=Type.Decimal(precision=prec, scale=scale)) + return Type( + decimal=Type.Decimal( + precision=prec, scale=scale, nullability=Type.NULLABILITY_REQUIRED + ) + ) args = {"P1": 10, "S1": 8, "P2": 14, "S2": 2} @@ -78,4 +112,4 @@ def func(P1, S1, P2, S2): args, ) == func_eval - ) \ No newline at end of file + ) diff --git a/tests/test_function_registry.py b/tests/test_function_registry.py new file mode 100644 index 0000000..ef7387e --- /dev/null +++ b/tests/test_function_registry.py @@ -0,0 +1,335 @@ +import yaml + +from substrait.gen.proto.type_pb2 import Type +from substrait.function_registry import FunctionRegistry, covers +from substrait.derivation_expression import _parse + +content = """%YAML 1.2 +--- +scalar_functions: + - name: "test_fn" + description: "" + impls: + - args: + - value: i8 + variadic: + min: 2 + return: i8 + - name: "test_fn_variadic_any" + description: "" + impls: + - args: + - value: any1 + variadic: + min: 2 + return: any1 + - name: "add" + description: "Add two values."
+ impls: + - args: + - name: x + value: i8 + - name: y + value: i8 + options: + overflow: + values: [ SILENT, SATURATE, ERROR ] + return: i8 + - args: + - name: x + value: i8 + - name: y + value: i8 + - name: z + value: any + options: + overflow: + values: [ SILENT, SATURATE, ERROR ] + return: i16 + - args: + - name: x + value: any1 + - name: y + value: any1 + - name: z + value: any2 + options: + overflow: + values: [ SILENT, SATURATE, ERROR ] + return: any2 + - name: "test_decimal" + impls: + - args: + - name: x + value: decimal<P1,S1> + - name: y + value: decimal<S1,S2> + return: decimal<P1 + 1,S2 + 1> + - name: "test_enum" + impls: + - args: + - name: op + options: [ INTACT, FLIP ] + - name: x + value: i8 + return: i8 + - name: "add_declared" + description: "Add two values." + impls: + - args: + - name: x + value: i8 + - name: y + value: i8 + nullability: DECLARED_OUTPUT + return: i8? + - name: "add_discrete" + description: "Add two values." + impls: + - args: + - name: x + value: i8? + - name: y + value: i8 + nullability: DISCRETE + return: i8? + - name: "test_decimal_discrete" + impls: + - args: + - name: x + value: decimal?<P1,S1> + - name: y + value: decimal<S1,S2> + nullability: DISCRETE + return: decimal?<P1 + 1,S2 + 1>
+""" + + +registry = FunctionRegistry() + +registry.register_extension_dict(yaml.safe_load(content), uri="test") + + +def i8(nullable=False): + return Type( + i8=Type.I8( + nullability=Type.NULLABILITY_REQUIRED + if not nullable + else Type.NULLABILITY_NULLABLE + ) + ) + + +def i16(nullable=False): + return Type( + i16=Type.I16( + nullability=Type.NULLABILITY_REQUIRED + if not nullable + else Type.NULLABILITY_NULLABLE + ) + ) + + +def bool(nullable=False): + return Type( + bool=Type.Boolean( + nullability=Type.NULLABILITY_REQUIRED + if not nullable + else Type.NULLABILITY_NULLABLE + ) + ) + + +def decimal(precision, scale, nullable=False): + return Type( + decimal=Type.Decimal( + scale=scale, + precision=precision, + nullability=Type.NULLABILITY_REQUIRED + if not nullable + else Type.NULLABILITY_NULLABLE, + ) + ) + + +def test_non_existing_uri(): + assert ( + registry.lookup_function( + uri="non_existent", function_name="add", signature=[i8(), i8()] + ) + is None + ) + + +def test_non_existing_function(): + assert ( + registry.lookup_function( + uri="test", function_name="sub", signature=[i8(), i8()] + ) + is None + ) + + +def test_non_existing_function_signature(): + assert ( + registry.lookup_function(uri="test", function_name="add", signature=[i8()]) + is None + ) + + +def test_exact_match(): + assert registry.lookup_function( + uri="test", function_name="add", signature=[i8(), i8()] + )[1] == Type(i8=Type.I8(nullability=Type.NULLABILITY_REQUIRED)) + + +def test_wildcard_match(): + assert registry.lookup_function( + uri="test", function_name="add", signature=[i8(), i8(), bool()] + )[1] == Type(i16=Type.I16(nullability=Type.NULLABILITY_REQUIRED)) + + +def test_wildcard_match_fails_with_constraits(): + assert ( + registry.lookup_function( + uri="test", function_name="add", signature=[i8(), i16(), i16()] + ) + is None + ) + + +def test_wildcard_match_with_constraits(): + assert ( + registry.lookup_function( + uri="test", function_name="add", signature=[i16(), 
i16(), i8()] + )[1] + == i8() + ) + + +def test_variadic(): + assert ( + registry.lookup_function( + uri="test", function_name="test_fn", signature=[i8(), i8(), i8()] + )[1] + == i8() + ) + + +def test_variadic_any(): + assert ( + registry.lookup_function( + uri="test", + function_name="test_fn_variadic_any", + signature=[i16(), i16(), i16()], + )[1] + == i16() + ) + + +def test_variadic_fails_min_constraint(): + assert ( + registry.lookup_function(uri="test", function_name="test_fn", signature=[i8()]) + is None + ) + + +def test_decimal_happy_path(): + assert registry.lookup_function( + uri="test", + function_name="test_decimal", + signature=[decimal(10, 8), decimal(8, 6)], + )[1] == decimal(11, 7) + + +def test_decimal_violates_constraint(): + assert ( + registry.lookup_function( + uri="test", + function_name="test_decimal", + signature=[decimal(10, 8), decimal(12, 10)], + ) + is None + ) + + +def test_decimal_happy_path_discrete(): + assert registry.lookup_function( + uri="test", + function_name="test_decimal_discrete", + signature=[decimal(10, 8, nullable=True), decimal(8, 6)], + )[1] == decimal(11, 7, nullable=True) + + +def test_enum_with_valid_option(): + assert ( + registry.lookup_function( + uri="test", + function_name="test_enum", + signature=["FLIP", i8()], + )[1] + == i8() + ) + + +def test_enum_with_nonexistent_option(): + assert ( + registry.lookup_function( + uri="test", + function_name="test_enum", + signature=["NONEXISTENT", i8()], + ) + is None + ) + + +def test_function_with_nullable_args(): + assert registry.lookup_function( + uri="test", function_name="add", signature=[i8(nullable=True), i8()] + )[1] == i8(nullable=True) + + +def test_function_with_declared_output_nullability(): + assert registry.lookup_function( + uri="test", function_name="add_declared", signature=[i8(), i8()] + )[1] == i8(nullable=True) + + +def test_function_with_discrete_nullability(): + assert registry.lookup_function( + uri="test", function_name="add_discrete", 
signature=[i8(nullable=True), i8()] + )[1] == i8(nullable=True) + + +def test_function_with_discrete_nullability_nonexisting(): + assert ( + registry.lookup_function( + uri="test", function_name="add_discrete", signature=[i8(), i8()] + ) + is None + ) + + +def test_covers(): + params = {} + assert covers(i8(), _parse("i8"), params) + assert params == {} + + +def test_covers_nullability(): + assert not covers(i8(nullable=True), _parse("i8"), {}, check_nullability=True) + assert covers(i8(nullable=True), _parse("i8?"), {}, check_nullability=True) + + +def test_covers_decimal(): + assert not covers(decimal(10, 8), _parse("decimal<11, A>"), {}) + + +def test_covers_decimal_happy_path(): + params = {} + assert covers(decimal(10, 8), _parse("decimal<10, A>"), params) + assert params == {"A": 8} + + +def test_covers_any(): + assert covers(decimal(10, 8), _parse("any"), {})