Skip to content

Commit 7701e8e

Browse files
committed
A bunch of simplifier performance fixes
Most importantly, unused let bindings are *no longer removed*. It turns out that this was incredibly expensive to do for large programs containing a lot of top-level `let`s because of the amount of time spent calling `_unbound_names()`. The additional benefit of it isn't worth that cost.
1 parent 26bc870 commit 7701e8e

File tree

2 files changed

+72
-87
lines changed

2 files changed

+72
-87
lines changed

src/flitter/language/tree.pyx

Lines changed: 69 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ logger = name_patch(logger, __name__)
2121

2222
cdef Literal NoOp = Literal(null_)
2323
cdef int64_t MAX_RECURSIVE_CALL_DEPTH = 500
24+
cdef set EmptySet = set()
25+
cdef Vector Two = Vector(2)
2426

2527

2628
cdef bint sequence_pack(list expressions):
@@ -84,7 +86,7 @@ cdef class Expression:
8486
return expr
8587

8688
def unbound_names(self, bound_names=None):
87-
return self._unbound_names(set(bound_names) if bound_names is not None else set())
89+
return self._unbound_names(set(bound_names) if bound_names is not None else EmptySet)
8890

8991
cdef void _compile(self, Program program, list lnames):
9092
raise NotImplementedError()
@@ -170,7 +172,7 @@ cdef class Export(Expression):
170172
return self
171173

172174
cdef set _unbound_names(self, set names):
173-
return set([None])
175+
return EmptySet
174176

175177
def __repr__(self):
176178
return f'Export({self.static_exports!r})'
@@ -243,7 +245,7 @@ cdef class Import(Expression):
243245
return Import(tuple(remaining), filename, expr)
244246

245247
cdef set _unbound_names(self, set names):
246-
return self.filename._unbound_names(names) | self.expr._unbound_names(names | set(self.names))
248+
return self.filename._unbound_names(names).union(self.expr._unbound_names(names.union(self.names)))
247249

248250
def __repr__(self):
249251
return f'Import({self.names!r}, {self.filename!r}, {self.expr!r})'
@@ -291,7 +293,7 @@ cdef class Sequence(Expression):
291293
cdef set unbound = set()
292294
cdef Expression expr
293295
for expr in self.expressions:
294-
unbound |= expr._unbound_names(names)
296+
unbound.update(expr._unbound_names(names))
295297
return unbound
296298

297299
def __repr__(self):
@@ -311,7 +313,7 @@ cdef class Literal(Expression):
311313
return self
312314

313315
cdef set _unbound_names(self, set names):
314-
return set()
316+
return EmptySet
315317

316318
def __repr__(self):
317319
return f'Literal({self.value!r})'
@@ -357,10 +359,9 @@ cdef class Name(Expression):
357359
return self
358360

359361
cdef set _unbound_names(self, set names):
360-
cdef set unbound = set()
361362
if self.name not in names:
362-
unbound.add(self.name)
363-
return unbound
363+
return set([self.name])
364+
return EmptySet
364365

365366
def __repr__(self):
366367
return f'Name({self.name!r})'
@@ -427,7 +428,7 @@ cdef class Range(Expression):
427428
return Range(start, stop, step)
428429

429430
cdef set _unbound_names(self, set names):
430-
return self.start._unbound_names(names) | self.stop._unbound_names(names) | self.step._unbound_names(names)
431+
return self.start._unbound_names(names).union(self.stop._unbound_names(names)).union(self.step._unbound_names(names))
431432

432433
def __repr__(self):
433434
return f'Range({self.start!r}, {self.stop!r}, {self.step!r})'
@@ -574,11 +575,14 @@ cdef class BinaryOperation(Expression):
574575
cdef Expression _simplify(self, Context context):
575576
cdef Expression left = self.left._simplify(context)
576577
cdef Expression right = self.right._simplify(context)
577-
cdef bint literal_left = isinstance(left, Literal)
578-
cdef bint literal_right = isinstance(right, Literal)
578+
cdef bint literal_left = type(left) is Literal
579+
cdef bint literal_right = type(right) is Literal
579580
cdef Expression expr
581+
cdef Literal literal
580582
if literal_left and literal_right:
581-
return Literal(self.op((<Literal>left).value, (<Literal>right).value))
583+
literal = Literal.__new__(Literal)
584+
literal.value = self.op((<Literal>left).value, (<Literal>right).value)
585+
return literal
582586
elif literal_left:
583587
if (expr := self.constant_left((<Literal>left).value, right)) is not None:
584588
return expr._simplify(context)
@@ -589,7 +593,11 @@ cdef class BinaryOperation(Expression):
589593
return expr._simplify(context)
590594
if left is self.left and right is self.right:
591595
return self
592-
return type(self)(left, right)
596+
cdef type T = type(self)
597+
cdef BinaryOperation binary = <BinaryOperation>T.__new__(T)
598+
binary.left = left
599+
binary.right = right
600+
return binary
593601

594602
cdef Vector op(self, Vector left, Vector right):
595603
raise NotImplementedError()
@@ -604,7 +612,13 @@ cdef class BinaryOperation(Expression):
604612
return None
605613

606614
cdef set _unbound_names(self, set names):
607-
return self.left._unbound_names(names) | self.right._unbound_names(names)
615+
cdef set left = self.left._unbound_names(names)
616+
cdef set right = self.right._unbound_names(names)
617+
if not left:
618+
return right
619+
if not right:
620+
return left
621+
return left.union(right)
608622

609623
def __repr__(self):
610624
return f'{self.__class__.__name__}({self.left!r}, {self.right!r})'
@@ -613,11 +627,13 @@ cdef class BinaryOperation(Expression):
613627
cdef class MathsBinaryOperation(BinaryOperation):
614628
cdef Expression _simplify(self, Context context):
615629
cdef Expression expr=BinaryOperation._simplify(self, context)
630+
cdef MathsBinaryOperation binary
616631
if isinstance(expr, MathsBinaryOperation):
617-
if isinstance(expr.left, Positive):
618-
return (<Expression>type(expr)(expr.left.expr, expr.right))._simplify(context)
619-
elif isinstance(expr.right, Positive):
620-
return (<Expression>type(expr)(expr.left, expr.right.expr))._simplify(context)
632+
binary = <MathsBinaryOperation>expr
633+
if type(binary.left) is Positive:
634+
return (<Expression>type(binary)((<Positive>binary.left).expr, binary.right))._simplify(context)
635+
elif type(binary.right) is Positive:
636+
return (<Expression>type(binary)(binary.left, (<Positive>binary.right).expr))._simplify(context)
621637
return expr
622638

623639

@@ -633,9 +649,9 @@ cdef class Add(MathsBinaryOperation):
633649
program.add()
634650

635651
cdef Expression constant_left(self, Vector left, Expression right):
636-
if left.eq(null_):
652+
if left.length == 0:
637653
return NoOp
638-
if left.eq(false_):
654+
if left.eq(false_) is true_:
639655
return Positive(right)
640656

641657
cdef Expression constant_right(self, Expression left, Vector right):
@@ -658,13 +674,13 @@ cdef class Subtract(MathsBinaryOperation):
658674
program.sub()
659675

660676
cdef Expression constant_left(self, Vector left, Expression right):
661-
if left.eq(null_):
677+
if left.length == 0:
662678
return NoOp
663-
if left.eq(false_):
679+
if left.eq(false_) is true_:
664680
return Negative(right)
665681

666682
cdef Expression constant_right(self, Expression left, Vector right):
667-
if right.eq(false_):
683+
if right.eq(false_) is true_:
668684
return Positive(left)
669685
return Add(left, Literal(right.neg()))
670686

@@ -681,11 +697,11 @@ cdef class Multiply(MathsBinaryOperation):
681697
program.mul()
682698

683699
cdef Expression constant_left(self, Vector left, Expression right):
684-
if left.eq(null_):
700+
if left.length == 0:
685701
return NoOp
686-
if left.eq(true_):
702+
if left.eq(true_) is true_:
687703
return Positive(right)
688-
if left.eq(minusone_):
704+
if left.eq(minusone_) is true_:
689705
return Negative(right)
690706
cdef MathsBinaryOperation maths
691707
if isinstance(right, Add) or isinstance(right, Subtract):
@@ -717,13 +733,13 @@ cdef class Divide(MathsBinaryOperation):
717733
program.truediv()
718734

719735
cdef Expression constant_left(self, Vector left, Expression right):
720-
if left.eq(null_):
736+
if left.length == 0:
721737
return NoOp
722738

723739
cdef Expression constant_right(self, Expression left, Vector right):
724-
if right.eq(null_):
740+
if right.length == 0:
725741
return NoOp
726-
if right.eq(true_):
742+
if right.eq(true_) is true_:
727743
return Positive(left)
728744
return Multiply(Literal(true_.truediv(right)), left)
729745

@@ -736,13 +752,13 @@ cdef class FloorDivide(MathsBinaryOperation):
736752
program.floordiv()
737753

738754
cdef Expression constant_left(self, Vector left, Expression right):
739-
if left.eq(null_):
755+
if left.length == 0:
740756
return NoOp
741757

742758
cdef Expression constant_right(self, Expression left, Vector right):
743-
if right.eq(null_):
759+
if right.length == 0:
744760
return NoOp
745-
if right.eq(true_):
761+
if right.eq(true_) is true_:
746762
return Floor(left)
747763

748764

@@ -754,13 +770,13 @@ cdef class Modulo(MathsBinaryOperation):
754770
program.mod()
755771

756772
cdef Expression constant_left(self, Vector left, Expression right):
757-
if left.eq(null_):
773+
if left.length == 0:
758774
return NoOp
759775

760776
cdef Expression constant_right(self, Expression left, Vector right):
761-
if right.eq(null_):
777+
if right.length == 0:
762778
return NoOp
763-
if right.eq(true_):
779+
if right.eq(true_) is true_:
764780
return Fract(left)
765781

766782

@@ -770,21 +786,21 @@ cdef class Power(MathsBinaryOperation):
770786

771787
cdef void _compile_op(self, Program program):
772788
cdef Instruction instr = program.last_instruction()
773-
if instr.code == OpCode.Literal and (<InstructionVector>instr).value.eq(Vector(2)):
789+
if instr.code == OpCode.Literal and (<InstructionVector>instr).value.eq(Two) is true_:
774790
program.pop_instruction()
775791
program.dup()
776792
program.mul()
777793
else:
778794
program.pow()
779795

780796
cdef Expression constant_left(self, Vector left, Expression right):
781-
if left.eq(null_):
797+
if left.length == 0:
782798
return NoOp
783799

784800
cdef Expression constant_right(self, Expression left, Vector right):
785-
if right.eq(null_):
801+
if right.length == 0:
786802
return NoOp
787-
if right.eq(true_):
803+
if right.eq(true_) is true_:
788804
return Positive(left)
789805

790806

@@ -946,7 +962,7 @@ cdef class Slice(Expression):
946962
return Slice(expr, index)
947963

948964
cdef set _unbound_names(self, set names):
949-
return self.expr._unbound_names(names) | self.index._unbound_names(names)
965+
return self.expr._unbound_names(names).union(self.index._unbound_names(names))
950966

951967
def __repr__(self):
952968
return f'Slice({self.expr!r}, {self.index!r})'
@@ -1177,15 +1193,18 @@ cdef class Attributes(NodeModifier):
11771193
cdef Expression node = self
11781194
cdef list bindings = []
11791195
cdef Attributes attrs
1180-
cdef Binding binding
1196+
cdef Binding binding, simplified
11811197
cdef Expression expr
11821198
cdef bint touched = False
11831199
while isinstance(node, Attributes):
11841200
attrs = <Attributes>node
11851201
for binding in reversed(attrs.bindings):
11861202
expr = binding.expr._simplify(context)
1187-
bindings.append(Binding(binding.name, expr))
11881203
touched |= expr is not binding.expr
1204+
simplified = Binding.__new__(Binding)
1205+
simplified.name = binding.name
1206+
simplified.expr = expr
1207+
bindings.append(simplified)
11891208
node = attrs.node
11901209
node = node._simplify(context)
11911210
touched |= node is not self.node
@@ -1200,7 +1219,7 @@ cdef class Attributes(NodeModifier):
12001219
for obj in nodes.objects:
12011220
objects.append((<Node>obj).copy() if isinstance(obj, Node) else obj)
12021221
while bindings and isinstance((<Binding>bindings[-1]).expr, Literal):
1203-
binding = bindings.pop()
1222+
binding = <Binding>bindings.pop()
12041223
value = (<Literal>binding.expr).value
12051224
for obj in objects:
12061225
if isinstance(obj, Node):
@@ -1279,7 +1298,7 @@ cdef class Append(NodeModifier):
12791298
return Append(node, children)
12801299

12811300
cdef set _unbound_names(self, set names):
1282-
return self.node._unbound_names(names) | self.children._unbound_names(names)
1301+
return self.node._unbound_names(names).union(self.children._unbound_names(names))
12831302

12841303
def __repr__(self):
12851304
return f'{self.__class__.__name__}({self.node!r}, {self.children!r})'
@@ -1349,7 +1368,7 @@ cdef class Let(Expression):
13491368
cdef Vector value
13501369
cdef str name, existing_name
13511370
cdef int64_t i, j, n
1352-
cdef touched = False
1371+
cdef bint touched = False
13531372
cdef set shadowed=set(), discarded=set()
13541373
while isinstance(body, Let):
13551374
bindings.extend((<Let>body).bindings)
@@ -1375,7 +1394,7 @@ cdef class Let(Expression):
13751394
if isinstance(expr, Literal):
13761395
value = (<Literal>expr).value
13771396
if n == 1:
1378-
name = binding.names[0]
1397+
name = <str>binding.names[0]
13791398
context.names[name] = value
13801399
discarded.add(name)
13811400
else:
@@ -1385,7 +1404,7 @@ cdef class Let(Expression):
13851404
touched = True
13861405
continue
13871406
if n == 1 and isinstance(expr, Name):
1388-
name = binding.names[0]
1407+
name = <str>binding.names[0]
13891408
if (<Name>expr).name == name:
13901409
touched = True
13911410
continue
@@ -1411,19 +1430,6 @@ cdef class Let(Expression):
14111430
context.names = saved
14121431
if isinstance(sbody, Literal):
14131432
return sbody
1414-
cdef set unbound = sbody._unbound_names(set())
1415-
cdef list original
1416-
if None not in unbound:
1417-
original = remaining
1418-
remaining = []
1419-
for binding in reversed(original):
1420-
for name in binding.names:
1421-
if name in unbound:
1422-
remaining.insert(0, binding)
1423-
unbound |= binding.expr.unbound_names(set())
1424-
break
1425-
else:
1426-
touched = True
14271433
if isinstance(sbody, Let):
14281434
remaining.extend((<Let>sbody).bindings)
14291435
sbody = (<Let>sbody).body
@@ -1514,7 +1520,7 @@ cdef class For(Expression):
15141520
cdef set _unbound_names(self, set names):
15151521
cdef set unbound = set()
15161522
unbound.update(self.source._unbound_names(names))
1517-
unbound.update(self.body._unbound_names(names | set(self.names)))
1523+
unbound.update(self.body._unbound_names(names.union(self.names)))
15181524
return unbound
15191525

15201526
def __repr__(self):
@@ -1676,14 +1682,14 @@ cdef class Function(Expression):
16761682
context.names = saved
16771683
cdef set captures = body._unbound_names(bound_names)
16781684
touched |= body is not self.body
1679-
cdef recursive = self.name in captures
1685+
cdef bint recursive = self.name in captures
16801686
if recursive:
16811687
captures.discard(self.name)
16821688
touched |= recursive != self.recursive
16831689
cdef tuple captures_t = tuple(captures)
16841690
cdef bint inlineable = literal and not captures_t
16851691
touched |= inlineable != self.inlineable
1686-
touched |= captures_t != self.captures
1692+
touched |= <bint>(captures_t != self.captures)
16871693
if not touched:
16881694
return self
16891695
return Function(self.name, tuple(parameters), body, captures_t, inlineable, recursive)

0 commit comments

Comments
 (0)