|
| 1 | +from substrait.gen.proto.parameterized_types_pb2 import ParameterizedType |
| 2 | +from substrait.gen.proto.type_pb2 import Type |
| 3 | +from importlib.resources import files as importlib_files |
| 4 | +import itertools |
| 5 | +from collections import defaultdict |
| 6 | +from collections.abc import Mapping |
| 7 | +from pathlib import Path |
| 8 | +from typing import Any, Optional, Union |
| 9 | +from .derivation_expression import evaluate |
| 10 | + |
| 11 | +import yaml |
| 12 | +import re |
| 13 | + |
| 14 | +_normalized_key_names = { |
| 15 | + "binary": "vbin", |
| 16 | + "interval_compound": "icompound", |
| 17 | + "interval_day": "iday", |
| 18 | + "interval_year": "iyear", |
| 19 | + "string": "str", |
| 20 | + "timestamp": "ts", |
| 21 | + "timestamp_tz": "tstz", |
| 22 | +} |
| 23 | + |
| 24 | + |
| 25 | +def normalize_substrait_type_names(typ: str) -> str: |
| 26 | + # First strip off any punctuation |
| 27 | + typ = typ.strip("?").lower() |
| 28 | + |
| 29 | + # Common prefixes whose information does not matter to an extension function |
| 30 | + # signature |
| 31 | + for complex_type, abbr in [ |
| 32 | + ("fixedchar", "fchar"), |
| 33 | + ("varchar", "vchar"), |
| 34 | + ("fixedbinary", "fbin"), |
| 35 | + ("decimal", "dec"), |
| 36 | + ("precision_timestamp", "pts"), |
| 37 | + ("precision_timestamp_tz", "ptstz"), |
| 38 | + ("struct", "struct"), |
| 39 | + ("list", "list"), |
| 40 | + ("map", "map"), |
| 41 | + ("any", "any"), |
| 42 | + ("boolean", "bool"), |
| 43 | + ]: |
| 44 | + if typ.lower().startswith(complex_type): |
| 45 | + typ = abbr |
| 46 | + |
| 47 | + # Then pass through the dictionary of mappings, defaulting to just the |
| 48 | + # existing string |
| 49 | + typ = _normalized_key_names.get(typ.lower(), typ.lower()) |
| 50 | + return typ |
| 51 | + |
| 52 | + |
| 53 | +id_generator = itertools.count(1) |
| 54 | + |
| 55 | + |
| 56 | +def to_integer_option(txt: str): |
| 57 | + if txt.isnumeric(): |
| 58 | + return ParameterizedType.IntegerOption(literal=int(txt)) |
| 59 | + else: |
| 60 | + return ParameterizedType.IntegerOption( |
| 61 | + parameter=ParameterizedType.IntegerParameter(name=txt) |
| 62 | + ) |
| 63 | + |
| 64 | + |
| 65 | +def to_parameterized_type(dtype: str): |
| 66 | + if dtype == "boolean": |
| 67 | + return ParameterizedType(bool=Type.Boolean()) |
| 68 | + elif dtype == "i8": |
| 69 | + return ParameterizedType(i8=Type.I8()) |
| 70 | + elif dtype == "i16": |
| 71 | + return ParameterizedType(i16=Type.I16()) |
| 72 | + elif dtype == "i32": |
| 73 | + return ParameterizedType(i32=Type.I32()) |
| 74 | + elif dtype == "i64": |
| 75 | + return ParameterizedType(i64=Type.I64()) |
| 76 | + elif dtype == "fp32": |
| 77 | + return ParameterizedType(fp32=Type.FP32()) |
| 78 | + elif dtype == "fp64": |
| 79 | + return ParameterizedType(fp64=Type.FP64()) |
| 80 | + elif dtype == "timestamp": |
| 81 | + return ParameterizedType(timestamp=Type.Timestamp()) |
| 82 | + elif dtype == "timestamp_tz": |
| 83 | + return ParameterizedType(timestamp_tz=Type.TimestampTZ()) |
| 84 | + elif dtype == "date": |
| 85 | + return ParameterizedType(date=Type.Date()) |
| 86 | + elif dtype == "time": |
| 87 | + return ParameterizedType(time=Type.Time()) |
| 88 | + elif dtype == "interval_year": |
| 89 | + return ParameterizedType(interval_year=Type.IntervalYear()) |
| 90 | + elif dtype.startswith("decimal") or dtype.startswith("DECIMAL"): |
| 91 | + (_, precision, scale, _) = re.split(r"\W+", dtype) |
| 92 | + |
| 93 | + return ParameterizedType( |
| 94 | + decimal=ParameterizedType.ParameterizedDecimal( |
| 95 | + scale=to_integer_option(scale), precision=to_integer_option(precision) |
| 96 | + ) |
| 97 | + ) |
| 98 | + elif dtype.startswith("varchar"): |
| 99 | + (_, length, _) = re.split(r"\W+", dtype) |
| 100 | + |
| 101 | + return ParameterizedType( |
| 102 | + varchar=ParameterizedType.ParameterizedVarChar( |
| 103 | + length=to_integer_option(length) |
| 104 | + ) |
| 105 | + ) |
| 106 | + elif dtype.startswith("precision_timestamp"): |
| 107 | + (_, precision, _) = re.split(r"\W+", dtype) |
| 108 | + |
| 109 | + return ParameterizedType( |
| 110 | + precision_timestamp=ParameterizedType.ParameterizedPrecisionTimestamp( |
| 111 | + precision=to_integer_option(precision) |
| 112 | + ) |
| 113 | + ) |
| 114 | + elif dtype.startswith("precision_timestamp_tz"): |
| 115 | + (_, precision, _) = re.split(r"\W+", dtype) |
| 116 | + |
| 117 | + return ParameterizedType( |
| 118 | + precision_timestamp_tz=ParameterizedType.ParameterizedPrecisionTimestampTZ( |
| 119 | + precision=to_integer_option(precision) |
| 120 | + ) |
| 121 | + ) |
| 122 | + elif dtype.startswith("fixedchar"): |
| 123 | + (_, length, _) = re.split(r"\W+", dtype) |
| 124 | + |
| 125 | + return ParameterizedType( |
| 126 | + fixed_char=ParameterizedType.ParameterizedFixedChar( |
| 127 | + length=to_integer_option(length) |
| 128 | + ) |
| 129 | + ) |
| 130 | + elif dtype == "string": |
| 131 | + return ParameterizedType(string=Type.String()) |
| 132 | + elif dtype.startswith("list"): |
| 133 | + inner_dtype = dtype[5:-1] |
| 134 | + return ParameterizedType( |
| 135 | + list=ParameterizedType.ParameterizedList( |
| 136 | + type=to_parameterized_type(inner_dtype) |
| 137 | + ) |
| 138 | + ) |
| 139 | + elif dtype.startswith("interval_day"): |
| 140 | + (_, precision, _) = re.split(r"\W+", dtype) |
| 141 | + |
| 142 | + return ParameterizedType( |
| 143 | + interval_day=ParameterizedType.ParameterizedIntervalDay( |
| 144 | + precision=to_integer_option(precision) |
| 145 | + ) |
| 146 | + ) |
| 147 | + elif dtype.startswith("any"): |
| 148 | + return ParameterizedType( |
| 149 | + type_parameter=ParameterizedType.TypeParameter(name=dtype) |
| 150 | + ) |
| 151 | + elif dtype.startswith("u!") or dtype == "geometry": |
| 152 | + return ParameterizedType( |
| 153 | + user_defined=ParameterizedType.ParameterizedUserDefined() |
| 154 | + ) |
| 155 | + else: |
| 156 | + raise Exception(f"Unkownn type - {dtype}") |
| 157 | + |
| 158 | + |
| 159 | +def violates_integer_option( |
| 160 | + actual: int, option: ParameterizedType.IntegerOption, parameters: dict |
| 161 | +): |
| 162 | + integer_type = option.WhichOneof("integer_type") |
| 163 | + |
| 164 | + if integer_type == "literal" and actual != option.literal: |
| 165 | + return True |
| 166 | + else: |
| 167 | + parameter_name = option.parameter.name |
| 168 | + if parameter_name in parameters and parameters[parameter_name] != actual: |
| 169 | + return True |
| 170 | + else: |
| 171 | + parameters[parameter_name] = actual |
| 172 | + |
| 173 | + return False |
| 174 | + |
| 175 | + |
| 176 | +def covers(dtype: Type, parameterized_type: ParameterizedType, parameters: dict): |
| 177 | + expected_kind = parameterized_type.WhichOneof("kind") |
| 178 | + |
| 179 | + if expected_kind == "type_parameter": |
| 180 | + parameter_name = parameterized_type.type_parameter.name |
| 181 | + if parameter_name == "any": |
| 182 | + return True |
| 183 | + else: |
| 184 | + if parameter_name in parameters and parameters[ |
| 185 | + parameter_name |
| 186 | + ].SerializeToString(deterministic=True) != dtype.SerializeToString( |
| 187 | + deterministic=True |
| 188 | + ): |
| 189 | + return False |
| 190 | + else: |
| 191 | + parameters[parameter_name] = dtype |
| 192 | + return True |
| 193 | + |
| 194 | + kind = dtype.WhichOneof("kind") |
| 195 | + |
| 196 | + if kind != expected_kind: |
| 197 | + return False |
| 198 | + |
| 199 | + if kind == "decimal": |
| 200 | + if violates_integer_option( |
| 201 | + dtype.decimal.scale, parameterized_type.decimal.scale, parameters |
| 202 | + ) or violates_integer_option( |
| 203 | + dtype.decimal.precision, parameterized_type.decimal.precision, parameters |
| 204 | + ): |
| 205 | + return False |
| 206 | + |
| 207 | + # TODO handle all types |
| 208 | + |
| 209 | + return True |
| 210 | + |
| 211 | + |
| 212 | +class FunctionEntry: |
| 213 | + def __init__(self, uri: str, name: str, impl: Mapping[str, Any]) -> None: |
| 214 | + self.name = name |
| 215 | + self.normalized_inputs: list = [] |
| 216 | + self.uri: str = uri |
| 217 | + self.anchor = next(id_generator) |
| 218 | + self.arguments = [] |
| 219 | + self.rtn = impl["return"] |
| 220 | + self.nullability = impl.get("nullability", False) |
| 221 | + self.variadic = impl.get("variadic", False) |
| 222 | + if input_args := impl.get("args", []): |
| 223 | + for val in input_args: |
| 224 | + if typ := val.get("value"): |
| 225 | + self.arguments.append(to_parameterized_type(typ.strip("?"))) |
| 226 | + self.normalized_inputs.append(normalize_substrait_type_names(typ)) |
| 227 | + elif arg_name := val.get("name", None): |
| 228 | + self.arguments.append(val.get("options")) |
| 229 | + self.normalized_inputs.append("req") |
| 230 | + |
| 231 | + def __repr__(self) -> str: |
| 232 | + return f"{self.name}:{'_'.join(self.normalized_inputs)}" |
| 233 | + |
| 234 | + def satisfies_signature(self, signature: tuple) -> Optional[str]: |
| 235 | + if self.variadic: |
| 236 | + min_args_allowed = self.variadic.get("min", 0) |
| 237 | + if len(signature) < min_args_allowed: |
| 238 | + return None |
| 239 | + inputs = [self.arguments[0]] * len(signature) |
| 240 | + else: |
| 241 | + inputs = self.arguments |
| 242 | + if len(inputs) != len(signature): |
| 243 | + return None |
| 244 | + |
| 245 | + zipped_args = list(zip(inputs, signature)) |
| 246 | + |
| 247 | + parameters = {} |
| 248 | + |
| 249 | + for x, y in zipped_args: |
| 250 | + if type(y) == str: |
| 251 | + if y not in x: |
| 252 | + return None |
| 253 | + else: |
| 254 | + if not covers(y, x, parameters): |
| 255 | + return None |
| 256 | + |
| 257 | + return evaluate(self.rtn, parameters) |
| 258 | + |
| 259 | + |
| 260 | +class FunctionRegistry: |
| 261 | + def __init__(self) -> None: |
| 262 | + self._function_mapping: dict = defaultdict(dict) |
| 263 | + self.id_generator = itertools.count(1) |
| 264 | + |
| 265 | + self.uri_aliases = {} |
| 266 | + |
| 267 | + for fpath in importlib_files("substrait.extensions").glob( # type: ignore |
| 268 | + "functions*.yaml" |
| 269 | + ): |
| 270 | + uri = f"https://github.com/substrait-io/substrait/blob/main/extensions/{fpath.name}" |
| 271 | + self.uri_aliases[fpath.name] = uri |
| 272 | + self.register_extension_yaml(fpath, uri) |
| 273 | + |
| 274 | + def register_extension_yaml( |
| 275 | + self, |
| 276 | + fname: Union[str, Path], |
| 277 | + uri: str, |
| 278 | + ) -> None: |
| 279 | + fname = Path(fname) |
| 280 | + with open(fname) as f: # type: ignore |
| 281 | + extension_definitions = yaml.safe_load(f) |
| 282 | + |
| 283 | + self.register_extension_dict(extension_definitions, uri) |
| 284 | + |
| 285 | + def register_extension_dict(self, definitions: dict, uri: str) -> None: |
| 286 | + for named_functions in definitions.values(): |
| 287 | + for function in named_functions: |
| 288 | + for impl in function.get("impls", []): |
| 289 | + func = FunctionEntry(uri, function["name"], impl) |
| 290 | + if ( |
| 291 | + func.uri in self._function_mapping |
| 292 | + and function["name"] in self._function_mapping[func.uri] |
| 293 | + ): |
| 294 | + self._function_mapping[func.uri][function["name"]].append(func) |
| 295 | + else: |
| 296 | + self._function_mapping[func.uri][function["name"]] = [func] |
| 297 | + |
| 298 | + # TODO add an optional return type check |
| 299 | + def lookup_function( |
| 300 | + self, uri: str, function_name: str, signature: tuple |
| 301 | + ) -> Optional[tuple[FunctionEntry, Type]]: |
| 302 | + uri = self.uri_aliases.get(uri, uri) |
| 303 | + |
| 304 | + if ( |
| 305 | + uri not in self._function_mapping |
| 306 | + or function_name not in self._function_mapping[uri] |
| 307 | + ): |
| 308 | + return None |
| 309 | + functions = self._function_mapping[uri][function_name] |
| 310 | + for f in functions: |
| 311 | + assert isinstance(f, FunctionEntry) |
| 312 | + rtn = f.satisfies_signature(signature) |
| 313 | + if rtn is not None: |
| 314 | + return (f, rtn) |
| 315 | + |
| 316 | + return None |
0 commit comments