@@ -21,6 +21,8 @@ logger = name_patch(logger, __name__)
2121
2222cdef Literal NoOp = Literal(null_)
2323cdef int64_t MAX_RECURSIVE_CALL_DEPTH = 500
24+ cdef set EmptySet = set ()
25+ cdef Vector Two = Vector(2 )
2426
2527
2628cdef bint sequence_pack(list expressions):
@@ -84,7 +86,7 @@ cdef class Expression:
8486 return expr
8587
8688 def unbound_names (self , bound_names = None ):
87- return self ._unbound_names(set (bound_names) if bound_names is not None else set () )
89+ return self ._unbound_names(set (bound_names) if bound_names is not None else EmptySet )
8890
8991 cdef void _compile(self , Program program, list lnames):
9092 raise NotImplementedError ()
@@ -170,7 +172,7 @@ cdef class Export(Expression):
170172 return self
171173
172174 cdef set _unbound_names(self , set names):
173- return set ([ None ])
175+ return EmptySet
174176
175177 def __repr__ (self ):
176178 return f' Export({self.static_exports!r})'
@@ -243,7 +245,7 @@ cdef class Import(Expression):
243245 return Import(tuple (remaining), filename, expr)
244246
245247 cdef set _unbound_names(self , set names):
246- return self .filename._unbound_names(names) | self .expr._unbound_names(names | set (self .names))
248+ return self .filename._unbound_names(names). union ( self .expr._unbound_names(names.union (self .names) ))
247249
248250 def __repr__ (self ):
249251 return f' Import({self.names!r}, {self.filename!r}, {self.expr!r})'
@@ -291,7 +293,7 @@ cdef class Sequence(Expression):
291293 cdef set unbound = set ()
292294 cdef Expression expr
293295 for expr in self .expressions:
294- unbound |= expr._unbound_names(names)
296+ unbound.update( expr._unbound_names(names) )
295297 return unbound
296298
297299 def __repr__ (self ):
@@ -311,7 +313,7 @@ cdef class Literal(Expression):
311313 return self
312314
313315 cdef set _unbound_names(self , set names):
314- return set ()
316+ return EmptySet
315317
316318 def __repr__ (self ):
317319 return f' Literal({self.value!r})'
@@ -357,10 +359,9 @@ cdef class Name(Expression):
357359 return self
358360
359361 cdef set _unbound_names(self , set names):
360- cdef set unbound = set ()
361362 if self .name not in names:
362- unbound.add( self .name)
363- return unbound
363+ return set ([ self .name] )
364+ return EmptySet
364365
365366 def __repr__ (self ):
366367 return f' Name({self.name!r})'
@@ -427,7 +428,7 @@ cdef class Range(Expression):
427428 return Range(start, stop, step)
428429
429430 cdef set _unbound_names(self , set names):
430- return self .start._unbound_names(names) | self .stop._unbound_names(names) | self .step._unbound_names(names)
431+ return self .start._unbound_names(names). union ( self .stop._unbound_names(names)). union ( self .step._unbound_names(names) )
431432
432433 def __repr__ (self ):
433434 return f' Range({self.start!r}, {self.stop!r}, {self.step!r})'
@@ -574,11 +575,14 @@ cdef class BinaryOperation(Expression):
574575 cdef Expression _simplify(self , Context context):
575576 cdef Expression left = self .left._simplify(context)
576577 cdef Expression right = self .right._simplify(context)
577- cdef bint literal_left = isinstance (left, Literal)
578- cdef bint literal_right = isinstance (right, Literal)
578+ cdef bint literal_left = type (left) is Literal
579+ cdef bint literal_right = type (right) is Literal
579580 cdef Expression expr
581+ cdef Literal literal
580582 if literal_left and literal_right:
581- return Literal(self .op((< Literal> left).value, (< Literal> right).value))
583+ literal = Literal.__new__ (Literal)
584+ literal.value = self .op((< Literal> left).value, (< Literal> right).value)
585+ return literal
582586 elif literal_left:
583587 if (expr := self .constant_left((< Literal> left).value, right)) is not None :
584588 return expr._simplify(context)
@@ -589,7 +593,11 @@ cdef class BinaryOperation(Expression):
589593 return expr._simplify(context)
590594 if left is self .left and right is self .right:
591595 return self
592- return type (self )(left, right)
596+ cdef type T = type (self )
597+ cdef BinaryOperation binary = < BinaryOperation> T.__new__ (T)
598+ binary.left = left
599+ binary.right = right
600+ return binary
593601
594602 cdef Vector op(self , Vector left, Vector right):
595603 raise NotImplementedError ()
@@ -604,7 +612,13 @@ cdef class BinaryOperation(Expression):
604612 return None
605613
606614 cdef set _unbound_names(self , set names):
607- return self .left._unbound_names(names) | self .right._unbound_names(names)
615+ cdef set left = self .left._unbound_names(names)
616+ cdef set right = self .right._unbound_names(names)
617+ if not left:
618+ return right
619+ if not right:
620+ return left
621+ return left.union(right)
608622
609623 def __repr__ (self ):
610624 return f' {self.__class__.__name__}({self.left!r}, {self.right!r})'
@@ -613,11 +627,13 @@ cdef class BinaryOperation(Expression):
613627cdef class MathsBinaryOperation(BinaryOperation):
614628 cdef Expression _simplify(self , Context context):
615629 cdef Expression expr= BinaryOperation._simplify(self , context)
630+ cdef MathsBinaryOperation binary
616631 if isinstance (expr, MathsBinaryOperation):
617- if isinstance (expr.left, Positive):
618- return (< Expression> type (expr)(expr.left.expr, expr.right))._simplify(context)
619- elif isinstance (expr.right, Positive):
620- return (< Expression> type (expr)(expr.left, expr.right.expr))._simplify(context)
632+ binary = < MathsBinaryOperation> expr
633+ if type (binary.left) is Positive:
634+ return (< Expression> type (binary)((< Positive> binary.left).expr, binary.right))._simplify(context)
635+ elif type (binary.right) is Positive:
636+ return (< Expression> type (binary)(binary.left, (< Positive> binary.right).expr))._simplify(context)
621637 return expr
622638
623639
@@ -633,9 +649,9 @@ cdef class Add(MathsBinaryOperation):
633649 program.add()
634650
635651 cdef Expression constant_left(self , Vector left, Expression right):
636- if left.eq(null_) :
652+ if left.length == 0 :
637653 return NoOp
638- if left.eq(false_):
654+ if left.eq(false_) is true_ :
639655 return Positive(right)
640656
641657 cdef Expression constant_right(self , Expression left, Vector right):
@@ -658,13 +674,13 @@ cdef class Subtract(MathsBinaryOperation):
658674 program.sub()
659675
660676 cdef Expression constant_left(self , Vector left, Expression right):
661- if left.eq(null_) :
677+ if left.length == 0 :
662678 return NoOp
663- if left.eq(false_):
679+ if left.eq(false_) is true_ :
664680 return Negative(right)
665681
666682 cdef Expression constant_right(self , Expression left, Vector right):
667- if right.eq(false_):
683+ if right.eq(false_) is true_ :
668684 return Positive(left)
669685 return Add(left, Literal(right.neg()))
670686
@@ -681,11 +697,11 @@ cdef class Multiply(MathsBinaryOperation):
681697 program.mul()
682698
683699 cdef Expression constant_left(self , Vector left, Expression right):
684- if left.eq(null_) :
700+ if left.length == 0 :
685701 return NoOp
686- if left.eq(true_):
702+ if left.eq(true_) is true_ :
687703 return Positive(right)
688- if left.eq(minusone_):
704+ if left.eq(minusone_) is true_ :
689705 return Negative(right)
690706 cdef MathsBinaryOperation maths
691707 if isinstance (right, Add) or isinstance (right, Subtract):
@@ -717,13 +733,13 @@ cdef class Divide(MathsBinaryOperation):
717733 program.truediv()
718734
719735 cdef Expression constant_left(self , Vector left, Expression right):
720- if left.eq(null_) :
736+ if left.length == 0 :
721737 return NoOp
722738
723739 cdef Expression constant_right(self , Expression left, Vector right):
724- if right.eq(null_) :
740+ if right.length == 0 :
725741 return NoOp
726- if right.eq(true_):
742+ if right.eq(true_) is true_ :
727743 return Positive(left)
728744 return Multiply(Literal(true_.truediv(right)), left)
729745
@@ -736,13 +752,13 @@ cdef class FloorDivide(MathsBinaryOperation):
736752 program.floordiv()
737753
738754 cdef Expression constant_left(self , Vector left, Expression right):
739- if left.eq(null_) :
755+ if left.length == 0 :
740756 return NoOp
741757
742758 cdef Expression constant_right(self , Expression left, Vector right):
743- if right.eq(null_) :
759+ if right.length == 0 :
744760 return NoOp
745- if right.eq(true_):
761+ if right.eq(true_) is true_ :
746762 return Floor(left)
747763
748764
@@ -754,13 +770,13 @@ cdef class Modulo(MathsBinaryOperation):
754770 program.mod()
755771
756772 cdef Expression constant_left(self , Vector left, Expression right):
757- if left.eq(null_) :
773+ if left.length == 0 :
758774 return NoOp
759775
760776 cdef Expression constant_right(self , Expression left, Vector right):
761- if right.eq(null_) :
777+ if right.length == 0 :
762778 return NoOp
763- if right.eq(true_):
779+ if right.eq(true_) is true_ :
764780 return Fract(left)
765781
766782
@@ -770,21 +786,21 @@ cdef class Power(MathsBinaryOperation):
770786
771787 cdef void _compile_op(self , Program program):
772788 cdef Instruction instr = program.last_instruction()
773- if instr.code == OpCode.Literal and (< InstructionVector> instr).value.eq(Vector( 2 )) :
789+ if instr.code == OpCode.Literal and (< InstructionVector> instr).value.eq(Two) is true_ :
774790 program.pop_instruction()
775791 program.dup()
776792 program.mul()
777793 else :
778794 program.pow()
779795
780796 cdef Expression constant_left(self , Vector left, Expression right):
781- if left.eq(null_) :
797+ if left.length == 0 :
782798 return NoOp
783799
784800 cdef Expression constant_right(self , Expression left, Vector right):
785- if right.eq(null_) :
801+ if right.length == 0 :
786802 return NoOp
787- if right.eq(true_):
803+ if right.eq(true_) is true_ :
788804 return Positive(left)
789805
790806
@@ -946,7 +962,7 @@ cdef class Slice(Expression):
946962 return Slice(expr, index)
947963
948964 cdef set _unbound_names(self , set names):
949- return self .expr._unbound_names(names) | self .index._unbound_names(names)
965+ return self .expr._unbound_names(names). union ( self .index._unbound_names(names) )
950966
951967 def __repr__ (self ):
952968 return f' Slice({self.expr!r}, {self.index!r})'
@@ -1177,15 +1193,18 @@ cdef class Attributes(NodeModifier):
11771193 cdef Expression node = self
11781194 cdef list bindings = []
11791195 cdef Attributes attrs
1180- cdef Binding binding
1196+ cdef Binding binding, simplified
11811197 cdef Expression expr
11821198 cdef bint touched = False
11831199 while isinstance (node, Attributes):
11841200 attrs = < Attributes> node
11851201 for binding in reversed (attrs.bindings):
11861202 expr = binding.expr._simplify(context)
1187- bindings.append(Binding(binding.name, expr))
11881203 touched |= expr is not binding.expr
1204+ simplified = Binding.__new__ (Binding)
1205+ simplified.name = binding.name
1206+ simplified.expr = expr
1207+ bindings.append(simplified)
11891208 node = attrs.node
11901209 node = node._simplify(context)
11911210 touched |= node is not self .node
@@ -1200,7 +1219,7 @@ cdef class Attributes(NodeModifier):
12001219 for obj in nodes.objects:
12011220 objects.append((< Node> obj).copy() if isinstance (obj, Node) else obj)
12021221 while bindings and isinstance ((< Binding> bindings[- 1 ]).expr, Literal):
1203- binding = bindings.pop()
1222+ binding = < Binding > bindings.pop()
12041223 value = (< Literal> binding.expr).value
12051224 for obj in objects:
12061225 if isinstance (obj, Node):
@@ -1279,7 +1298,7 @@ cdef class Append(NodeModifier):
12791298 return Append(node, children)
12801299
12811300 cdef set _unbound_names(self , set names):
1282- return self .node._unbound_names(names) | self .children._unbound_names(names)
1301+ return self .node._unbound_names(names). union ( self .children._unbound_names(names) )
12831302
12841303 def __repr__ (self ):
12851304 return f' {self.__class__.__name__}({self.node!r}, {self.children!r})'
@@ -1349,7 +1368,7 @@ cdef class Let(Expression):
13491368 cdef Vector value
13501369 cdef str name, existing_name
13511370 cdef int64_t i, j, n
1352- cdef touched = False
1371+ cdef bint touched = False
13531372 cdef set shadowed= set (), discarded= set ()
13541373 while isinstance (body, Let):
13551374 bindings.extend((< Let> body).bindings)
@@ -1375,7 +1394,7 @@ cdef class Let(Expression):
13751394 if isinstance (expr, Literal):
13761395 value = (< Literal> expr).value
13771396 if n == 1 :
1378- name = binding.names[0 ]
1397+ name = < str > binding.names[0 ]
13791398 context.names[name] = value
13801399 discarded.add(name)
13811400 else :
@@ -1385,7 +1404,7 @@ cdef class Let(Expression):
13851404 touched = True
13861405 continue
13871406 if n == 1 and isinstance (expr, Name):
1388- name = binding.names[0 ]
1407+ name = < str > binding.names[0 ]
13891408 if (< Name> expr).name == name:
13901409 touched = True
13911410 continue
@@ -1411,19 +1430,6 @@ cdef class Let(Expression):
14111430 context.names = saved
14121431 if isinstance (sbody, Literal):
14131432 return sbody
1414- cdef set unbound = sbody._unbound_names(set ())
1415- cdef list original
1416- if None not in unbound:
1417- original = remaining
1418- remaining = []
1419- for binding in reversed (original):
1420- for name in binding.names:
1421- if name in unbound:
1422- remaining.insert(0 , binding)
1423- unbound |= binding.expr.unbound_names(set ())
1424- break
1425- else :
1426- touched = True
14271433 if isinstance (sbody, Let):
14281434 remaining.extend((< Let> sbody).bindings)
14291435 sbody = (< Let> sbody).body
@@ -1514,7 +1520,7 @@ cdef class For(Expression):
15141520 cdef set _unbound_names(self , set names):
15151521 cdef set unbound = set ()
15161522 unbound.update(self .source._unbound_names(names))
1517- unbound.update(self .body._unbound_names(names | set (self .names)))
1523+ unbound.update(self .body._unbound_names(names.union (self .names)))
15181524 return unbound
15191525
15201526 def __repr__ (self ):
@@ -1676,14 +1682,14 @@ cdef class Function(Expression):
16761682 context.names = saved
16771683 cdef set captures = body._unbound_names(bound_names)
16781684 touched |= body is not self .body
1679- cdef recursive = self .name in captures
1685+ cdef bint recursive = self .name in captures
16801686 if recursive:
16811687 captures.discard(self .name)
16821688 touched |= recursive != self .recursive
16831689 cdef tuple captures_t = tuple (captures)
16841690 cdef bint inlineable = literal and not captures_t
16851691 touched |= inlineable != self .inlineable
1686- touched |= captures_t != self .captures
1692+ touched |= < bint > ( captures_t != self .captures)
16871693 if not touched:
16881694 return self
16891695 return Function(self .name, tuple (parameters), body, captures_t, inlineable, recursive)
0 commit comments