Skip to content

Commit 09ba379

Browse files
committed
Make an attempt to inline recursive functions if some arguments literal
Now attempt recursive inlining if the arguments aren't all dynamic. Monitor inlining depth and raise a recursion error if it's getting out of hand. This also requires fixing all simplifier name binding to unwind the bindings if an exception is raised. New tests of this functionality and generally tidying-up of simplifier call tests. Fixes #69
1 parent 23f3832 commit 09ba379

File tree

3 files changed

+110
-56
lines changed

3 files changed

+110
-56
lines changed

src/flitter/language/tree.pyx

Lines changed: 55 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ from .vm cimport Program, Instruction, InstructionInt, InstructionVector, OpCode
2020
logger = name_patch(logger, __name__)
2121

2222
cdef Literal NoOp = Literal(null_)
23+
cdef int64_t MAX_RECURSIVE_CALL_DEPTH = 500
2324

2425

2526
cdef bint sequence_pack(list expressions):
@@ -231,8 +232,10 @@ cdef class Import(Expression):
231232
cdef Expression expr = self.expr
232233
if let_names:
233234
expr = Let(tuple(PolyBinding((name,), value if isinstance(value, Function) else Literal(value)) for name, value in let_names.items()), expr)
234-
expr = expr._simplify(context)
235-
context.names = saved
235+
try:
236+
expr = expr._simplify(context)
237+
finally:
238+
context.names = saved
236239
if not remaining:
237240
return expr
238241
if filename is self.filename and expr is self.expr:
@@ -992,29 +995,33 @@ cdef class Call(Expression):
992995
cdef bint literal_func = isinstance(function, Literal)
993996
if literal_func and not (<Literal>function).value.objects:
994997
return NoOp
995-
cdef bint literal_args = True
998+
cdef bint all_literal_args=True, all_dynamic_args=True
996999
cdef Expression arg, sarg, expr
9971000
cdef list args = []
9981001
if self.args:
9991002
for arg in self.args:
10001003
sarg = arg._simplify(context)
10011004
touched |= sarg is not arg
10021005
args.append(sarg)
1003-
if not isinstance(sarg, Literal):
1004-
literal_args = False
1006+
if isinstance(sarg, Literal):
1007+
all_dynamic_args = False
1008+
else:
1009+
all_literal_args = False
10051010
cdef list keyword_args = []
10061011
cdef Binding binding
10071012
if self.keyword_args:
10081013
for binding in self.keyword_args:
10091014
arg = binding.expr._simplify(context)
10101015
touched |= arg is not binding.expr
10111016
keyword_args.append(Binding(binding.name, arg))
1012-
if not isinstance(arg, Literal):
1013-
literal_args = False
1017+
if isinstance(arg, Literal):
1018+
all_dynamic_args = False
1019+
else:
1020+
all_literal_args = False
10141021
cdef list bindings
10151022
cdef dict kwargs
10161023
cdef int64_t i
1017-
if func_expr is not None and not func_expr.captures and (not func_expr.recursive or literal_args):
1024+
if func_expr is not None and not func_expr.captures and not (func_expr.recursive and all_dynamic_args):
10181025
kwargs = {binding.name: binding.expr for binding in keyword_args}
10191026
bindings = []
10201027
for i, binding in enumerate(func_expr.parameters):
@@ -1026,11 +1033,26 @@ cdef class Call(Expression):
10261033
bindings.append(PolyBinding((binding.name,), binding.expr))
10271034
else:
10281035
bindings.append(PolyBinding((binding.name,), Literal(null_)))
1029-
expr = Let(tuple(bindings), func_expr.body)._simplify(context)
1030-
return expr
1036+
if func_expr.recursive:
1037+
if context.call_depth == 0:
1038+
context.call_depth = 1
1039+
try:
1040+
return Let(tuple(bindings), func_expr.body)._simplify(context)
1041+
except RecursionError:
1042+
pass
1043+
context.call_depth = 0
1044+
elif context.call_depth == MAX_RECURSIVE_CALL_DEPTH:
1045+
raise RecursionError()
1046+
else:
1047+
context.call_depth += 1
1048+
expr = Let(tuple(bindings), func_expr.body)._simplify(context)
1049+
context.call_depth -= 1
1050+
return expr
1051+
else:
1052+
return Let(tuple(bindings), func_expr.body)._simplify(context)
10311053
cdef list vector_args, results
10321054
cdef Literal literal_arg
1033-
if literal_func and literal_args:
1055+
if literal_func and all_literal_args:
10341056
vector_args = [literal_arg.value for literal_arg in args]
10351057
kwargs = {binding.name: (<Literal>binding.expr).value for binding in keyword_args}
10361058
results = []
@@ -1364,9 +1386,12 @@ cdef class Let(Expression):
13641386
context.names[name] = None
13651387
remaining.append(PolyBinding(binding.names, expr))
13661388
cdef bint resimplify = shadowed and shadowed & discarded
1367-
cdef Expression sbody = body._simplify(context)
1368-
touched |= sbody is not body
1369-
context.names = saved
1389+
cdef Expression sbody
1390+
try:
1391+
sbody = body._simplify(context)
1392+
touched |= sbody is not body
1393+
finally:
1394+
context.names = saved
13701395
if isinstance(sbody, Literal):
13711396
return sbody
13721397
if isinstance(sbody, Let):
@@ -1432,19 +1457,23 @@ cdef class For(Expression):
14321457
if not isinstance(source, Literal):
14331458
for name in self.names:
14341459
context.names[name] = None
1435-
body = self.body._simplify(context)
1436-
context.names = saved
1460+
try:
1461+
body = self.body._simplify(context)
1462+
finally:
1463+
context.names = saved
14371464
if source is self.source and body is self.body:
14381465
return self
14391466
return For(self.names, source, body)
14401467
values = (<Literal>source).value
14411468
cdef int64_t i=0, n=values.length
1442-
while i < n:
1443-
for name in self.names:
1444-
context.names[name] = values.item(i) if i < n else null_
1445-
i += 1
1446-
remaining.append(self.body._simplify(context))
1447-
context.names = saved
1469+
try:
1470+
while i < n:
1471+
for name in self.names:
1472+
context.names[name] = values.item(i) if i < n else null_
1473+
i += 1
1474+
remaining.append(self.body._simplify(context))
1475+
finally:
1476+
context.names = saved
14481477
sequence_pack(remaining)
14491478
if not remaining:
14501479
return NoOp
@@ -1611,7 +1640,10 @@ cdef class Function(Expression):
16111640
for parameter in parameters:
16121641
context.names[parameter.name] = None
16131642
bound_names.add(parameter.name)
1614-
body = self.body._simplify(context)
1643+
try:
1644+
body = self.body._simplify(context)
1645+
finally:
1646+
context.names = saved
16151647
cdef set captures = body._unbound_names(bound_names)
16161648
touched |= body is not self.body
16171649
cdef recursive = self.name in captures
@@ -1621,7 +1653,6 @@ cdef class Function(Expression):
16211653
cdef tuple captures_t = tuple(captures)
16221654
cdef bint inlineable = literal and not captures_t
16231655
touched |= inlineable != self.inlineable
1624-
context.names = saved
16251656
touched |= captures_t != self.captures
16261657
if not touched:
16271658
return self

src/flitter/model.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,3 +245,4 @@ cdef class Context:
245245
cdef readonly object stack
246246
cdef readonly object lnames
247247
cdef readonly set dependencies
248+
cdef readonly int64_t call_depth

tests/test_simplifier.py

Lines changed: 54 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import unittest
88
import unittest.mock
99

10+
from flitter import configure_logger
1011
from flitter.model import Vector, Node, StateDict, null, true, false
1112
from flitter.language import functions
1213
from flitter.language.tree import (Literal, Name, Sequence,
@@ -20,6 +21,9 @@
2021
Binding, PolyBinding, IfCondition)
2122

2223

24+
configure_logger('ERROR')
25+
26+
2327
class SimplifierTestCase(unittest.TestCase):
2428
def assertSimplifiesTo(self, x, y, state=None, dynamic=None, static=None, with_errors=None, with_dependencies=None, unbound=None):
2529
aliases = set()
@@ -1036,58 +1040,76 @@ def test_non_callable_literal(self):
10361040

10371041
def test_simple_named_inlining(self):
10381042
"""Calls to names that resolve to Function objects are inlined as let expressions"""
1039-
func = Function('func', (Binding('x', Literal(null)),), Add(Name('x'), Literal(5)), captures=(), inlineable=True)
1040-
self.assertSimplifiesTo(Call(Name('func'), (Add(Literal(1), Name('y')),), ()),
1043+
f = Function('f', (Binding('x', Literal(null)),), Add(Name('x'), Literal(5)), captures=(), inlineable=True)
1044+
self.assertSimplifiesTo(Call(Name('f'), (Add(Literal(1), Name('y')),), ()),
10411045
Let((PolyBinding(('x',), Add(Literal(1), Name('y'))),), Add(Name('x'), Literal(5))),
1042-
static={'func': func}, dynamic={'y'})
1046+
static={'f': f}, dynamic={'y'})
10431047

10441048
def test_simple_anonymous_inlining(self):
10451049
"""Direct calls to anonymous functions are inlined as let expressions"""
1046-
func = Function('<anon>', (Binding('x', Literal(null)),), Add(Name('x'), Literal(5)), captures=(), inlineable=True)
1047-
self.assertSimplifiesTo(Call(func, (Add(Literal(1), Name('y')),), ()),
1050+
f = Function('<anon>', (Binding('x', Literal(null)),), Add(Name('x'), Literal(5)), captures=(), inlineable=True)
1051+
self.assertSimplifiesTo(Call(f, (Add(Literal(1), Name('y')),), ()),
10481052
Let((PolyBinding(('x',), Add(Literal(1), Name('y'))),), Add(Name('x'), Literal(5))),
10491053
dynamic={'y'})
10501054

10511055
def test_simple_inlined_missing_parameter_default(self):
10521056
"""An inlined function with a default parameter value used"""
1053-
func = Function('func', (Binding('x', None), Binding('y', Literal(1))), Add(Name('x'), Name('y')), captures=(), inlineable=True)
1054-
self.assertSimplifiesTo(Call(Name('func'), (Literal(5),), ()),
1057+
f = Function('f', (Binding('x', None), Binding('y', Literal(1))), Add(Name('x'), Name('y')), captures=(), inlineable=True)
1058+
self.assertSimplifiesTo(Call(Name('f'), (Literal(5),), ()),
10551059
Literal(6),
1056-
static={'func': func})
1060+
static={'f': f})
10571061

10581062
def test_simple_inlined_keyword_argument(self):
10591063
"""An inlined function with a keyword argument given"""
1060-
func = Function('func', (Binding('x', None), Binding('y', Literal(1))), Add(Name('x'), Name('y')), captures=(), inlineable=True)
1061-
self.assertSimplifiesTo(Call(Name('func'), (Literal(1),), (Binding('y', Literal(2)),)),
1064+
f = Function('f', (Binding('x', None), Binding('y', Literal(1))), Add(Name('x'), Name('y')), captures=(), inlineable=True)
1065+
self.assertSimplifiesTo(Call(Name('f'), (Literal(1),), (Binding('y', Literal(2)),)),
10621066
Literal(3),
1063-
static={'func': func})
1067+
static={'f': f})
10641068

10651069
def test_simple_inlined_missing_default_parameter_used(self):
10661070
"""An inlined function with a keyword argument given"""
1067-
func = Function('func', (Binding('x', None), Binding('y', Literal(1))), Add(Name('x'), Name('y')), captures=(), inlineable=True)
1068-
self.assertSimplifiesTo(Call(Name('func'), (), ()),
1071+
f = Function('f', (Binding('x', None), Binding('y', Literal(1))), Add(Name('x'), Name('y')), captures=(), inlineable=True)
1072+
self.assertSimplifiesTo(Call(Name('f'), (), ()),
10691073
Literal(null),
1070-
static={'func': func})
1071-
1072-
def test_inlineable_recursive_non_literal(self):
1073-
"""Calls to inlineable, recursive functions are *not* inlined if arguments are not all literal"""
1074-
func = Function(
1075-
'func',
1076-
(Binding('x', Literal(null)),),
1077-
IfElse((IfCondition(GreaterThan(Name('x'), Literal(0)), Add(Name('x'), Call(Name('func'), (Subtract(Name('x'), Literal(1)),)))),), Literal(0)),
1078-
captures=(), inlineable=True, recursive=True
1079-
)
1080-
self.assertSimplifiesTo(Call(Name('func'), (Name('y'),)), Call(Name('func'), (Name('y'),)), static={'func': func}, dynamic={'y'})
1081-
1082-
def test_inlineable_recursive_literal(self):
1083-
"""Calls to inlineable, recursive functions *are* inlined if arguments are all literal"""
1084-
func = Function(
1085-
'func',
1086-
(Binding('x', Literal(null)),),
1087-
IfElse((IfCondition(GreaterThan(Name('x'), Literal(0)), Add(Name('x'), Call(Name('func'), (Subtract(Name('x'), Literal(1)),)))),), Literal(0)),
1074+
static={'f': f})
1075+
1076+
1077+
class TestRecursiveCall(SimplifierTestCase):
1078+
def setUp(self):
1079+
self.f = Function(
1080+
'f',
1081+
(Binding('x', Literal(null)), Binding('y', Literal(null))),
1082+
IfElse((IfCondition(GreaterThan(Name('x'), Literal(0)),
1083+
Add(Name('y'), Call(Name('f'), (Subtract(Name('x'), Literal(1)), Divide(Name('y'), Literal(2)))))),),
1084+
Literal(0)),
10881085
captures=(), inlineable=True, recursive=True
10891086
)
1090-
self.assertSimplifiesTo(Call(Name('func'), (Literal(5),)), Literal(15), static={'func': func})
1087+
1088+
def test_inlineable_recursive_all_non_literal(self):
1089+
"""A call to a recursive function with no literal arguments will not be inlined"""
1090+
self.assertSimplifiesTo(Call(Name('f'), (Name('z'), Name('w'))),
1091+
Call(Name('f'), (Name('z'), Name('w'))),
1092+
static={'f': self.f}, dynamic={'z', 'w'})
1093+
1094+
def test_inlineable_recursive_literal_bound(self):
1095+
"""A call to a recursive function with at least literal arguments will cause an inline attempt.
1096+
In this case the literal argument determines the bounds and so recursive inlining will succeed."""
1097+
self.assertSimplifiesTo(Call(Name('f'), (Literal(2), Name('w'))),
1098+
Add(Name('w'), Let((PolyBinding(('y',), Multiply(Literal(0.5), Name('w'))),), Positive(Name('y')))),
1099+
static={'f': self.f}, dynamic={'w'})
1100+
1101+
def test_inlineable_recursive_dynamic_bound(self):
1102+
"""A call to a recursive function with at least literal arguments will cause an inline attempt.
1103+
In this case the literal argument does not determine the bounds and so recursive inlining will fail."""
1104+
self.assertSimplifiesTo(Call(Name('f'), (Name('z'), Literal(1))),
1105+
Call(Name('f'), (Name('z'), Literal(1))),
1106+
static={'f': self.f}, dynamic={'z'})
1107+
1108+
def test_inlineable_recursive_all_literal(self):
1109+
"""A call to a recursive function with all literal arguments should be fully simplified"""
1110+
self.assertSimplifiesTo(Call(Name('f'), (Literal(4), Literal(32))),
1111+
Literal(32+16+8+4+0),
1112+
static={'f': self.f})
10911113

10921114

10931115
class TestFastFunctions(SimplifierTestCase):

0 commit comments

Comments
 (0)