Skip to content

Commit ec31b45

Browse files
committed
feat: add registry for extension functions
1 parent 56e7b1e commit ec31b45

File tree

3 files changed

+538
-2
lines changed

3 files changed

+538
-2
lines changed

pyproject.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ dynamic = ["version"]
1212
write_to = "src/substrait/_version.py"
1313

1414
[project.optional-dependencies]
15-
extensions = ["antlr4-python3-runtime"]
15+
extensions = ["antlr4-python3-runtime", "pyyaml"]
1616
gen_proto = ["protobuf == 3.20.1", "protoletariat >= 2.0.0"]
17-
test = ["pytest >= 7.0.0", "antlr4-python3-runtime"]
17+
test = ["pytest >= 7.0.0", "antlr4-python3-runtime", "pyyaml"]
1818

1919
[tool.pytest.ini_options]
2020
pythonpath = "src"

src/substrait/function_registry.py

+316
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
from substrait.gen.proto.parameterized_types_pb2 import ParameterizedType
2+
from substrait.gen.proto.type_pb2 import Type
3+
from importlib.resources import files as importlib_files
4+
import itertools
5+
from collections import defaultdict
6+
from collections.abc import Mapping
7+
from pathlib import Path
8+
from typing import Any, Optional, Union
9+
from .derivation_expression import evaluate
10+
11+
import yaml
12+
import re
13+
14+
# Exact-name abbreviations applied after prefix handling.
_normalized_key_names = {
    "binary": "vbin",
    "interval_compound": "icompound",
    "interval_day": "iday",
    "interval_year": "iyear",
    "string": "str",
    "timestamp": "ts",
    "timestamp_tz": "tstz",
}

# Prefix -> abbreviation for types whose parameters do not matter to an
# extension function signature. Only the first matching entry is used, so a
# longer prefix must be listed before any shorter prefix it extends
# (``precision_timestamp_tz`` before ``precision_timestamp``).
_type_prefix_abbreviations = (
    ("fixedchar", "fchar"),
    ("varchar", "vchar"),
    ("fixedbinary", "fbin"),
    ("decimal", "dec"),
    ("precision_timestamp_tz", "ptstz"),
    ("precision_timestamp", "pts"),
    ("struct", "struct"),
    ("list", "list"),
    ("map", "map"),
    ("any", "any"),
    ("boolean", "bool"),
)


def normalize_substrait_type_names(typ: str) -> str:
    """Reduce a Substrait type string to its short, parameter-free name.

    Leading/trailing ``?`` markers are stripped and the name is lower-cased.
    A type starting with one of the known prefixes is replaced by that
    prefix's abbreviation (dropping any ``<...>`` parameters); otherwise the
    exact-name table is consulted, falling back to the lower-cased input.

    Args:
        typ: Type string as written in an extension definition,
            e.g. ``"decimal<P,S>"`` or ``"i8?"``.

    Returns:
        The abbreviated type name, e.g. ``"dec"`` or ``"i8"``.
    """
    # First strip off any punctuation
    typ = typ.strip("?").lower()

    # Stop at the first matching prefix; without this, the ``_tz`` variant
    # would be swallowed by the shorter ``precision_timestamp`` prefix.
    for prefix, abbr in _type_prefix_abbreviations:
        if typ.startswith(prefix):
            typ = abbr
            break

    # Then pass through the dictionary of mappings, defaulting to just the
    # existing string
    return _normalized_key_names.get(typ, typ)
51+
52+
53+
# Module-wide counter starting at 1; each FunctionEntry takes its ``anchor``
# from it via ``next(id_generator)``.
id_generator = itertools.count(1)
54+
55+
56+
def to_integer_option(txt: str):
    """Build a ``ParameterizedType.IntegerOption`` from one textual token.

    Text for which ``str.isnumeric()`` is true becomes a ``literal`` option
    holding ``int(txt)``; any other text is taken as the name of an integer
    parameter.
    """
    if not txt.isnumeric():
        named_parameter = ParameterizedType.IntegerParameter(name=txt)
        return ParameterizedType.IntegerOption(parameter=named_parameter)

    return ParameterizedType.IntegerOption(literal=int(txt))
63+
64+
65+
def to_parameterized_type(dtype: str):
    """Parse a type string into a ``ParameterizedType`` message.

    Simple types are matched by exact name. Parameterized types are matched
    by prefix, and their parameters are obtained by splitting the string on
    runs of non-word characters; each parameter becomes an integer option
    (literal or named parameter) via ``to_integer_option``.

    Args:
        dtype: Type string without nullability marker, e.g. ``"i32"``,
            ``"decimal<P,S>"``, ``"list<any1>"``.

    Returns:
        The corresponding ``ParameterizedType``.

    Raises:
        ValueError: If a parameterized spelling does not split into the
            number of tokens expected for that type (tuple unpacking fails).
        Exception: If ``dtype`` matches none of the supported spellings.
    """
    # --- simple types: exact name match -----------------------------------
    if dtype == "boolean":
        return ParameterizedType(bool=Type.Boolean())
    if dtype == "i8":
        return ParameterizedType(i8=Type.I8())
    if dtype == "i16":
        return ParameterizedType(i16=Type.I16())
    if dtype == "i32":
        return ParameterizedType(i32=Type.I32())
    if dtype == "i64":
        return ParameterizedType(i64=Type.I64())
    if dtype == "fp32":
        return ParameterizedType(fp32=Type.FP32())
    if dtype == "fp64":
        return ParameterizedType(fp64=Type.FP64())
    if dtype == "timestamp":
        return ParameterizedType(timestamp=Type.Timestamp())
    if dtype == "timestamp_tz":
        return ParameterizedType(timestamp_tz=Type.TimestampTZ())
    if dtype == "date":
        return ParameterizedType(date=Type.Date())
    if dtype == "time":
        return ParameterizedType(time=Type.Time())
    if dtype == "interval_year":
        return ParameterizedType(interval_year=Type.IntervalYear())
    if dtype == "string":
        return ParameterizedType(string=Type.String())

    # --- parameterized types: prefix match ---------------------------------
    if dtype.startswith(("decimal", "DECIMAL")):
        # "decimal<P,S>" splits into ["decimal", "P", "S", ""]
        (_, precision, scale, _) = re.split(r"\W+", dtype)

        return ParameterizedType(
            decimal=ParameterizedType.ParameterizedDecimal(
                scale=to_integer_option(scale), precision=to_integer_option(precision)
            )
        )
    if dtype.startswith("varchar"):
        (_, length, _) = re.split(r"\W+", dtype)

        return ParameterizedType(
            varchar=ParameterizedType.ParameterizedVarChar(
                length=to_integer_option(length)
            )
        )
    # The ``_tz`` variant must be tested before the plain variant: every
    # "precision_timestamp_tz..." string also starts with
    # "precision_timestamp" and would otherwise never reach this branch.
    if dtype.startswith("precision_timestamp_tz"):
        (_, precision, _) = re.split(r"\W+", dtype)

        return ParameterizedType(
            precision_timestamp_tz=ParameterizedType.ParameterizedPrecisionTimestampTZ(
                precision=to_integer_option(precision)
            )
        )
    if dtype.startswith("precision_timestamp"):
        (_, precision, _) = re.split(r"\W+", dtype)

        return ParameterizedType(
            precision_timestamp=ParameterizedType.ParameterizedPrecisionTimestamp(
                precision=to_integer_option(precision)
            )
        )
    if dtype.startswith("fixedchar"):
        (_, length, _) = re.split(r"\W+", dtype)

        return ParameterizedType(
            fixed_char=ParameterizedType.ParameterizedFixedChar(
                length=to_integer_option(length)
            )
        )
    if dtype.startswith("list"):
        # Drops the first five characters and the last one, i.e. assumes the
        # spelling "list<...>"; the element type is parsed recursively.
        inner_dtype = dtype[5:-1]
        return ParameterizedType(
            list=ParameterizedType.ParameterizedList(
                type=to_parameterized_type(inner_dtype)
            )
        )
    if dtype.startswith("interval_day"):
        (_, precision, _) = re.split(r"\W+", dtype)

        return ParameterizedType(
            interval_day=ParameterizedType.ParameterizedIntervalDay(
                precision=to_integer_option(precision)
            )
        )
    if dtype.startswith("any"):
        # The whole string (e.g. "any1") is used as the type parameter name.
        return ParameterizedType(
            type_parameter=ParameterizedType.TypeParameter(name=dtype)
        )
    if dtype.startswith("u!") or dtype == "geometry":
        # User-defined types are represented by an empty placeholder message.
        return ParameterizedType(
            user_defined=ParameterizedType.ParameterizedUserDefined()
        )

    raise Exception(f"Unkownn type - {dtype}")
157+
158+
159+
def violates_integer_option(
    actual: int, option: "ParameterizedType.IntegerOption", parameters: dict
):
    """Check a concrete integer against a literal-or-parameter constraint.

    Args:
        actual: The concrete value taken from the argument type.
        option: The constraint; its ``integer_type`` oneof is either
            ``literal`` or ``parameter``.
        parameters: Bindings collected so far, keyed by parameter name.
            A previously unseen parameter name is bound to ``actual``
            (the dict is mutated in place).

    Returns:
        ``True`` if ``actual`` differs from the literal, or from the value
        already bound to the named parameter; ``False`` otherwise.
    """
    integer_type = option.WhichOneof("integer_type")

    if integer_type == "literal":
        # A literal is a pure comparison. It must not fall through to the
        # parameter logic: that would bind the empty parameter name and make
        # a later, different literal look like a conflict.
        return actual != option.literal

    if integer_type != "parameter":
        # Neither alternative is set: nothing to compare against or bind.
        return False

    parameter_name = option.parameter.name
    if parameter_name in parameters:
        return parameters[parameter_name] != actual

    parameters[parameter_name] = actual
    return False
174+
175+
176+
def covers(dtype: Type, parameterized_type: ParameterizedType, parameters: dict):
    """Tell whether ``parameterized_type`` accepts the concrete ``dtype``.

    A type parameter named ``"any"`` accepts everything. Any other type
    parameter accepts ``dtype`` if it is unbound or bound to a type whose
    deterministic serialization equals that of ``dtype``; in both cases the
    name is (re)bound to ``dtype`` in ``parameters``. For all other kinds the
    ``kind`` oneof must be the same on both sides, and for ``decimal`` the
    scale and precision are additionally checked with
    ``violates_integer_option``.

    Returns:
        ``True`` when the type is accepted, ``False`` otherwise.
    """
    expected_kind = parameterized_type.WhichOneof("kind")

    if expected_kind == "type_parameter":
        name = parameterized_type.type_parameter.name
        if name == "any":
            return True

        if name in parameters:
            bound = parameters[name].SerializeToString(deterministic=True)
            offered = dtype.SerializeToString(deterministic=True)
            if bound != offered:
                return False

        parameters[name] = dtype
        return True

    kind = dtype.WhichOneof("kind")
    if kind != expected_kind:
        return False

    # TODO handle all types
    if kind != "decimal":
        return True

    # Scale first; precision is only looked at when the scale is acceptable.
    if violates_integer_option(
        dtype.decimal.scale, parameterized_type.decimal.scale, parameters
    ):
        return False
    if violates_integer_option(
        dtype.decimal.precision, parameterized_type.decimal.precision, parameters
    ):
        return False

    return True
210+
211+
212+
class FunctionEntry:
    """One implementation (``impl``) of a named extension function.

    Attributes are filled from the impl mapping:

    * ``arguments`` holds, per declared argument, either a parameterized
      type (for ``value`` arguments) or the raw ``options`` entry (for
      arguments that only carry a ``name``).
    * ``normalized_inputs`` holds the matching short names, with ``"req"``
      standing in for option arguments.
    * ``rtn`` is the impl's ``return`` entry, later handed to ``evaluate``.
    * ``anchor`` is drawn from the module-level ``id_generator``.
    """

    def __init__(self, uri: str, name: str, impl: Mapping[str, Any]) -> None:
        """Build the entry for ``name`` from ``impl``.

        Raises:
            KeyError: If ``impl`` has no ``"return"`` key.
        """
        self.name = name
        self.normalized_inputs: list = []
        self.uri: str = uri
        self.anchor = next(id_generator)
        self.arguments = []
        self.rtn = impl["return"]
        self.nullability = impl.get("nullability", False)
        self.variadic = impl.get("variadic", False)
        # ``or []`` keeps a missing or empty "args" entry a no-op.
        for val in impl.get("args") or []:
            if typ := val.get("value"):
                self.arguments.append(to_parameterized_type(typ.strip("?")))
                self.normalized_inputs.append(normalize_substrait_type_names(typ))
            elif val.get("name"):
                self.arguments.append(val.get("options"))
                self.normalized_inputs.append("req")

    def __repr__(self) -> str:
        return f"{self.name}:{'_'.join(self.normalized_inputs)}"

    def satisfies_signature(self, signature: tuple) -> Optional[str]:
        """Match ``signature`` against this impl and derive its return type.

        String items in ``signature`` are treated as option values and must
        be contained in the corresponding ``options``; every other item is
        checked with ``covers`` against the declared parameterized type.
        For a variadic impl the first declared argument is repeated for
        every item of ``signature``.

        Returns:
            ``None`` when the signature does not match; otherwise the result
            of ``evaluate(self.rtn, parameters)``.
            NOTE(review): the annotation says ``str`` but ``lookup_function``
            treats the result as a ``Type`` - confirm against ``evaluate``.
        """
        if self.variadic:
            min_args_allowed = self.variadic.get("min", 0)
            if len(signature) < min_args_allowed:
                return None
            if not self.arguments:
                # Nothing to repeat: a variadic impl without a declared
                # argument cannot match anything.
                return None
            inputs = [self.arguments[0]] * len(signature)
        else:
            inputs = self.arguments
            if len(inputs) != len(signature):
                return None

        parameters = {}

        for expected, actual in zip(inputs, signature):
            # Typed slots hold protobuf messages (``covers`` calls
            # ``WhichOneof`` on them); option slots hold the raw options.
            is_typed = hasattr(expected, "WhichOneof")
            if isinstance(actual, str):
                # An option value only fits an option slot that lists it.
                if is_typed or expected is None or actual not in expected:
                    return None
            elif not is_typed or not covers(actual, expected, parameters):
                return None

        return evaluate(self.rtn, parameters)
258+
259+
260+
class FunctionRegistry:
    """Index of extension function implementations, keyed by URI and name.

    On construction every ``functions*.yaml`` resource of the
    ``substrait.extensions`` package is registered under a GitHub URI built
    from the file name; the bare file name is recorded as an alias for that
    URI.
    """

    def __init__(self) -> None:
        # uri -> function name -> list of FunctionEntry
        self._function_mapping: dict = defaultdict(dict)
        self.id_generator = itertools.count(1)

        # short name (file name) -> full URI
        self.uri_aliases = {}

        for fpath in importlib_files("substrait.extensions").glob(  # type: ignore
            "functions*.yaml"
        ):
            uri = f"https://github.com/substrait-io/substrait/blob/main/extensions/{fpath.name}"
            self.uri_aliases[fpath.name] = uri
            self.register_extension_yaml(fpath, uri)

    def register_extension_yaml(
        self,
        fname: Union[str, Path],
        uri: str,
    ) -> None:
        """Load the YAML file ``fname`` and register its functions under ``uri``."""
        fname = Path(fname)
        # Binary mode lets the YAML loader determine the encoding itself
        # instead of depending on the platform's default text encoding.
        with open(fname, "rb") as f:  # type: ignore
            extension_definitions = yaml.safe_load(f)

        self.register_extension_dict(extension_definitions, uri)

    def register_extension_dict(self, definitions: dict, uri: str) -> None:
        """Register every ``impls`` entry found in ``definitions`` under ``uri``.

        Each top-level value that is a list is scanned for mapping entries
        with a ``name`` and optional ``impls``. Top-level values that are not
        lists, and list items that are not mappings, are ignored, as is an
        empty or ``None`` document.
        """
        if not definitions:
            # e.g. an empty YAML document loads as None
            return

        for named_functions in definitions.values():
            if not isinstance(named_functions, (list, tuple)):
                continue
            for function in named_functions:
                if not isinstance(function, Mapping):
                    continue
                for impl in function.get("impls", []):
                    func = FunctionEntry(uri, function["name"], impl)
                    self._function_mapping.setdefault(func.uri, {}).setdefault(
                        function["name"], []
                    ).append(func)

    # TODO add an optional return type check
    def lookup_function(
        self, uri: str, function_name: str, signature: tuple
    ) -> "Optional[tuple[FunctionEntry, Type]]":
        """Find the first registered impl of ``function_name`` matching ``signature``.

        ``uri`` may be a full URI or one of the recorded aliases.

        Returns:
            ``(entry, return_value)`` for the first impl whose
            ``satisfies_signature`` result is not ``None``, or ``None`` when
            the URI or name is unknown or no impl matches.
        """
        uri = self.uri_aliases.get(uri, uri)

        # ``.get`` avoids creating empty entries in the defaultdict on a miss.
        functions = self._function_mapping.get(uri, {}).get(function_name, [])
        for f in functions:
            assert isinstance(f, FunctionEntry)
            rtn = f.satisfies_signature(signature)
            if rtn is not None:
                return (f, rtn)

        return None

0 commit comments

Comments
 (0)