Skip to content

Commit 4b4a6a0

Browse files
committed
More simplifier reworking
1 parent bd9551a commit 4b4a6a0

File tree

1 file changed

+118
-86
lines changed

1 file changed

+118
-86
lines changed

src/flitter/language/tree.pyx

Lines changed: 118 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ in unrolling loops, simpler does not necessarily mean "smaller."
1010
from loguru import logger
1111

1212
from libc.stdint cimport int64_t
13+
from cpython cimport PyObject
14+
from cpython.dict cimport PyDict_GetItem
1315

1416
from .. import name_patch
1517
from ..cache import SharedCache
@@ -28,14 +30,14 @@ cdef Vector Two = Vector(2)
2830
cdef bint sequence_pack(list expressions):
2931
cdef Expression expr
3032
cdef bint touched = False
31-
cdef list vectors, todo=[]
32-
while expressions:
33-
todo.append(expressions.pop())
33+
cdef list vectors, todo=expressions.copy()
34+
todo.reverse()
35+
expressions.clear()
3436
while todo:
3537
expr = <Expression>todo.pop()
36-
if todo and type(expr) is Literal:
38+
if type(expr) is Literal:
3739
vectors = [(<Literal>expr).value]
38-
while todo and type(todo[len(todo)-1]) is Literal:
40+
while todo and type(todo[-1]) is Literal:
3941
vectors.append((<Literal>todo.pop()).value)
4042
if len(vectors) > 1:
4143
expr = Literal(Vector._compose(vectors))
@@ -313,35 +315,48 @@ cdef class Name(Expression):
313315
self.unbound_names = frozenset([name])
314316

315317
cdef void _compile(self, Program program, list lnames):
316-
cdef int64_t i, n=len(lnames)-1
317-
for i in range(len(lnames)):
318-
if self.name == <str>lnames[n-i]:
318+
cdef int64_t i
319+
cdef PyObject* ptr
320+
for i, name in enumerate(reversed(lnames)):
321+
if self.name == <str>name:
319322
program.local_load(i)
320323
break
321324
else:
322-
if self.name in static_builtins:
323-
program.literal(static_builtins[self.name])
324-
elif self.name in dynamic_builtins:
325-
program.literal(dynamic_builtins[self.name])
325+
if (ptr := PyDict_GetItem(static_builtins, self.name)) != NULL:
326+
program.literal(<Vector>ptr)
327+
elif (ptr := PyDict_GetItem(dynamic_builtins, self.name)) != NULL:
328+
program.literal(<Vector>ptr)
326329
else:
327330
program.compiler_errors.add(f"Unbound name '{self.name}'")
328331
program.literal(null_)
329332

330333
cdef Expression _simplify(self, Context context):
331-
if self.name in context.names:
332-
value = context.names[self.name]
334+
cdef str name = self.name
335+
cdef PyObject* ptr
336+
cdef Literal literal
337+
if (ptr := PyDict_GetItem(context.names, name)) != NULL:
338+
value = <object>ptr
333339
if value is None or type(value) is Function:
334340
return self
335341
elif type(value) is Name:
336342
return (<Name>value)._simplify(context)
337343
elif type(value) is Vector:
338-
return Literal((<Vector>value).copy())
344+
literal = Literal.__new__(Literal)
345+
literal.value = (<Vector>value).copy()
346+
literal.unbound_names = EmptySet
347+
return literal
339348
else:
340-
return Literal(value)
341-
elif (value := static_builtins.get(self.name)) is not None:
342-
return Literal(value)
343-
elif self.name not in dynamic_builtins:
344-
context.errors.add(f"Unbound name '{self.name}'")
349+
literal = Literal.__new__(Literal)
350+
literal.value = Vector._coerce(value)
351+
literal.unbound_names = EmptySet
352+
return literal
353+
elif (ptr := PyDict_GetItem(static_builtins, name)) != NULL:
354+
literal = Literal.__new__(Literal)
355+
literal.value = <Vector>ptr
356+
literal.unbound_names = EmptySet
357+
return literal
358+
elif PyDict_GetItem(dynamic_builtins, name) == NULL:
359+
context.errors.add(f"Unbound name '{name}'")
345360
return NoOp
346361
return self
347362

@@ -957,34 +972,34 @@ cdef class Call(Expression):
957972
cdef set unbound = set()
958973
unbound.update(self.function.unbound_names)
959974
cdef Expression arg
960-
if self.args:
975+
if self.args is not None:
961976
for arg in self.args:
962977
unbound.update(arg.unbound_names)
963978
cdef Binding binding
964-
if self.keyword_args:
979+
if self.keyword_args is not None:
965980
for binding in self.keyword_args:
966981
unbound.update(binding.expr.unbound_names)
967982
self.unbound_names = frozenset(unbound)
968983

969984
cdef void _compile(self, Program program, list lnames):
970985
cdef Expression expr
971986
cdef list names = []
972-
if self.args:
987+
if self.args is not None:
973988
for expr in self.args:
974989
expr._compile(program, lnames)
975990
cdef Binding keyword_arg
976-
if self.keyword_args:
991+
if self.keyword_args is not None:
977992
for keyword_arg in self.keyword_args:
978993
names.append(keyword_arg.name)
979994
keyword_arg.expr._compile(program, lnames)
980995
if not names and type(self.function) is Literal \
981996
and (<Literal>self.function).value.length == 1 \
982997
and (<Literal>self.function).value.objects is not None \
983998
and not type(function := (<Literal>self.function).value.objects[0]) is Function:
984-
program.call_fast(function, len(self.args) if self.args else 0)
999+
program.call_fast(function, len(self.args) if self.args is not None else 0)
9851000
else:
9861001
self.function._compile(program, lnames)
987-
program.call(len(self.args) if self.args else 0, tuple(names) if names else None)
1002+
program.call(len(self.args) if self.args is not None else 0, tuple(names) if names else None)
9881003

9891004
cdef Expression _simplify(self, Context context):
9901005
cdef Expression function = self.function._simplify(context)
@@ -1002,7 +1017,7 @@ cdef class Call(Expression):
10021017
cdef bint all_literal_args=True, all_dynamic_args=True
10031018
cdef Expression arg, sarg, expr
10041019
cdef list args = []
1005-
if self.args:
1020+
if self.args is not None:
10061021
for arg in self.args:
10071022
sarg = arg._simplify(context)
10081023
touched |= sarg is not arg
@@ -1012,85 +1027,102 @@ cdef class Call(Expression):
10121027
else:
10131028
all_literal_args = False
10141029
cdef list keyword_args = []
1015-
cdef Binding binding
1016-
if self.keyword_args:
1030+
cdef Binding binding, sbinding
1031+
if self.keyword_args is not None:
10171032
for binding in self.keyword_args:
10181033
arg = binding.expr._simplify(context)
1019-
touched |= arg is not binding.expr
1020-
keyword_args.append(Binding(binding.name, arg))
1034+
if arg is not binding.expr:
1035+
sbinding = Binding.__new__(Binding)
1036+
sbinding.name = binding.name
1037+
sbinding.expr = arg
1038+
keyword_args.append(sbinding)
1039+
touched = True
1040+
else:
1041+
keyword_args.append(binding)
10211042
if type(arg) is Literal:
10221043
all_dynamic_args = False
10231044
else:
10241045
all_literal_args = False
10251046
cdef list bindings, renames
10261047
cdef dict kwargs
10271048
cdef int64_t i, j=0
1028-
cdef str temp_name
1049+
cdef str name
1050+
cdef tuple vector_args
1051+
cdef Vector result
1052+
cdef list results
1053+
cdef Literal literal
1054+
cdef PolyBinding polybinding
1055+
if literal_func and all_literal_args:
1056+
vector_args = tuple([literal.value for literal in args])
1057+
kwargs = {binding.name: (<Literal>binding.expr).value for binding in keyword_args}
1058+
results = []
1059+
if (<Literal>function).value.objects is not None:
1060+
for func in (<Literal>function).value.objects:
1061+
if callable(func):
1062+
try:
1063+
result = func(*vector_args, **kwargs)
1064+
except Exception as exc:
1065+
context.errors.add(f"Error calling {func.__name__}: {str(exc)}")
1066+
else:
1067+
results.append(result)
1068+
else:
1069+
context.errors.add(f"{func!r} is not callable")
1070+
elif (<Literal>function).value.numbers != NULL:
1071+
for i in range((<Literal>function).value.length):
1072+
context.errors.add(f"{(<Literal>function).value.numbers[i]!r} is not callable")
1073+
literal = Literal.__new__(Literal)
1074+
literal.value = Vector._compose(results)
1075+
literal.unbound_names = EmptySet
1076+
return literal
10291077
if func_expr is not None and not func_expr.captures and not (func_expr.recursive and all_dynamic_args):
10301078
kwargs = {binding.name: binding.expr for binding in keyword_args}
10311079
bindings = []
10321080
renames = []
10331081
for i, binding in enumerate(func_expr.parameters):
1082+
name = binding.name
10341083
if i < len(args):
10351084
expr = <Expression>args[i]
1036-
elif binding.name in kwargs:
1037-
expr = <Expression>kwargs[binding.name]
1085+
elif name in kwargs:
1086+
expr = <Expression>kwargs[name]
10381087
elif binding.expr is not None:
10391088
expr = binding.expr
10401089
else:
1041-
expr = Literal(null_)
1042-
if binding.name in context.names:
1043-
temp_name = f'__t{j}'
1090+
expr = NoOp
1091+
while name in context.names:
1092+
name = f'__t{j}'
10441093
j += 1
1045-
while temp_name in context.names:
1046-
temp_name = f'__t{j}'
1047-
j += 1
1048-
bindings.append(PolyBinding((temp_name,), expr))
1049-
renames.append(PolyBinding((binding.name,), Name(temp_name)))
1050-
else:
1051-
bindings.append(PolyBinding((binding.name,), expr))
1094+
polybinding = PolyBinding.__new__(PolyBinding)
1095+
polybinding.names = (name,)
1096+
polybinding.expr = expr
1097+
bindings.append(polybinding)
1098+
if name is not binding.name:
1099+
polybinding = PolyBinding.__new__(PolyBinding)
1100+
polybinding.names = (binding.name,)
1101+
polybinding.expr = Name(name)
1102+
renames.append(polybinding)
10521103
bindings.extend(renames)
1104+
expr = Let(tuple(bindings), func_expr.body)
10531105
if func_expr.recursive:
10541106
if context.call_depth == 0:
10551107
context.call_depth = 1
10561108
try:
1057-
return Let(tuple(bindings), func_expr.body)._simplify(context)
1109+
return expr._simplify(context)
10581110
except RecursionError:
1059-
pass
1111+
logger.trace("Abandoned inline attempt of recursive function: {}", func_expr.name)
10601112
context.call_depth = 0
10611113
elif context.call_depth == MAX_RECURSIVE_CALL_DEPTH:
10621114
raise RecursionError()
10631115
else:
10641116
context.call_depth += 1
1065-
expr = Let(tuple(bindings), func_expr.body)._simplify(context)
1117+
expr = expr._simplify(context)
10661118
context.call_depth -= 1
10671119
return expr
10681120
else:
1069-
return Let(tuple(bindings), func_expr.body)._simplify(context)
1070-
cdef list vector_args, results
1071-
cdef Literal literal_arg
1072-
if literal_func and all_literal_args:
1073-
vector_args = [literal_arg.value for literal_arg in args]
1074-
kwargs = {binding.name: (<Literal>binding.expr).value for binding in keyword_args}
1075-
results = []
1076-
if (<Literal>function).value.objects is not None:
1077-
for func in (<Literal>function).value.objects:
1078-
if callable(func):
1079-
try:
1080-
assert not hasattr(func, 'context_func')
1081-
results.append(func(*vector_args, **kwargs))
1082-
except Exception as exc:
1083-
context.errors.add(f"Error calling {func.__name__}: {str(exc)}")
1084-
else:
1085-
context.errors.add(f"{func!r} is not callable")
1086-
elif (<Literal>function).value.numbers != NULL:
1087-
for i in range((<Literal>function).value.length):
1088-
context.errors.add(f"{(<Literal>function).value.numbers[i]!r} is not callable")
1089-
return Literal(Vector._compose(results))
1121+
return expr._simplify(context)
10901122
if type(function) is Literal and len(args) == 1:
10911123
if (<Literal>function).value == static_builtins['ceil']:
10921124
return Ceil(args[0])
1093-
elif (<Literal>function).value == static_builtins['floor']:
1125+
if (<Literal>function).value == static_builtins['floor']:
10941126
return Floor(args[0])
10951127
if (<Literal>function).value == static_builtins['fract']:
10961128
return Fract(args[0])
@@ -1343,7 +1375,7 @@ cdef class Let(Expression):
13431375
cdef PolyBinding binding
13441376
cdef Expression expr, body=self.body
13451377
cdef Vector value
1346-
cdef str name, existing_name
1378+
cdef str name, rename
13471379
cdef int64_t i, j, n
13481380
cdef bint touched = False
13491381
cdef set shadowed=set(), discarded=set()
@@ -1352,20 +1384,20 @@ cdef class Let(Expression):
13521384
body = (<Let>body).body
13531385
touched = True
13541386
cdef dict renames = {}
1355-
for existing_name, existing_value in context.names.items():
1356-
if type(existing_value) is Name:
1357-
name = (<Name>existing_value).name
1358-
if name in renames:
1359-
(<list>renames[name]).append(existing_name)
1387+
for name, name_value in context.names.items():
1388+
if type(name_value) is Name:
1389+
rename = (<Name>name_value).name
1390+
if rename in renames:
1391+
(<list>renames[rename]).append(name)
13601392
else:
1361-
renames[name] = [existing_name]
1393+
renames[rename] = [name]
13621394
for i, binding in enumerate(bindings):
1363-
for name in binding.names:
1364-
if name not in shadowed and name in renames:
1365-
for existing_name in <list>renames.pop(name):
1366-
context.names[existing_name] = None
1367-
remaining.append(PolyBinding((existing_name,), Name(name)))
1368-
shadowed.add(name)
1395+
for rename in binding.names:
1396+
if rename not in shadowed and rename in renames:
1397+
for name in <list>renames.pop(rename):
1398+
context.names[name] = None
1399+
remaining.append(PolyBinding((name,), Name(rename)))
1400+
shadowed.add(rename)
13691401
touched = True
13701402
n = len(binding.names)
13711403
expr = binding.expr._simplify(context)
@@ -1388,13 +1420,13 @@ cdef class Let(Expression):
13881420
continue
13891421
if n == 1 and type(expr) is Name:
13901422
name = <str>binding.names[0]
1391-
existing_name = (<Name>expr).name
1392-
if existing_name == name:
1423+
rename = (<Name>expr).name
1424+
if name == rename:
13931425
touched = True
13941426
continue
13951427
for j in range(i+1, len(bindings)):
1396-
if existing_name in (<Binding>bindings[j]).names:
1397-
shadowed.add(existing_name)
1428+
if rename in (<Binding>bindings[j]).names:
1429+
shadowed.add(rename)
13981430
break
13991431
else:
14001432
context.names[name] = expr

0 commit comments

Comments
 (0)