diff --git a/pythran/optimizations/pattern_transform.py b/pythran/optimizations/pattern_transform.py index 3055cb0c2..685437ddf 100644 --- a/pythran/optimizations/pattern_transform.py +++ b/pythran/optimizations/pattern_transform.py @@ -66,27 +66,37 @@ ast.Num(n=-1)], keywords=[])), - # X * X => X ** 2 - (ast.BinOp(left=Placeholder(0), op=ast.Mult(), right=Placeholder(0)), - lambda: ast.BinOp(left=Placeholder(0), op=ast.Pow(), right=ast.Num(n=2))), - - # a + "..." + b => "...".join((a, b)) - (ast.BinOp(left=ast.BinOp(left=Placeholder(0), - op=ast.Add(), - right=ast.Str(Placeholder(1))), - op=ast.Add(), - right=Placeholder(2)), - lambda: ast.Call(func=ast.Attribute( - ast.Attribute( - ast.Name('__builtin__', ast.Load(), None), - 'str', - ast.Load()), - 'join', ast.Load()), - args=[ast.Str(Placeholder(1)), - ast.Tuple([Placeholder(0), Placeholder(2)], ast.Load())], - keywords=[])), ] +# Dictionary with ast operator name as keys for each a list of tuple of +# (left, right, replacement) is defined. +# replacement have to be a lambda function to have a new ast to replace when +# replacement is inserted in main ast +know_pattern_BinOp = { + ast.Mult.__name__ : [ + (Placeholder(0), ast.Num(1), lambda: Placeholder(0)), # X * 1 => X + (ast.Num(1), Placeholder(0), lambda: Placeholder(0)), # 1 * X => X + (Placeholder(0), Placeholder(0), lambda: ast.BinOp(left=Placeholder(0), op=ast.Pow(), right=ast.Num(n=2))), # X * X => X ** 2 + ], + ast.Add.__name__ : [ + (Placeholder(0), ast.Num(0), lambda: Placeholder(0)), # X + 0 => X + (ast.Num(0), Placeholder(0), lambda: Placeholder(0)), # 0 + X => X + ( # a + "..." + b => "...".join((a, b)) + ast.BinOp(left=Placeholder(0), op=ast.Add(), right=ast.Str(Placeholder(1))), + Placeholder(2), + lambda: ast.Call(func=ast.Attribute( + ast.Attribute(ast.Name('__builtin__', ast.Load(), None),'str',ast.Load()),'join', ast.Load()), + args=[ast.Str(Placeholder(1)),ast.Tuple([Placeholder(0), Placeholder(2)], ast.Load())], + keywords=[]) + ), + + ], + ast.Sub.__name__ : [ + (Placeholder(0), ast.Num(0), lambda: Placeholder(0)), # X - 0 => X + (ast.Num(0), Placeholder(0), lambda: ast.UnaryOp(op=ast.USub(), operand=Placeholder(0))), # 0 - X => -X + ], +} + class PlaceholderReplace(Transformation): @@ -125,3 +135,21 @@ def visit(self, node): node = PlaceholderReplace(check.placeholders).visit(replace()) self.update = True return super(PatternTransform, self).visit(node) + + def visit_BinOp(self, node): + """ + Special method for BinOp. + Try to replace if node match the given pattern. + """ + self.generic_visit(node) + op_name = node.op.__class__.__name__ + if op_name in know_pattern_BinOp: + for left, right, replace in know_pattern_BinOp[op_name]: + check_left = Check(node.left, dict()) + if check_left.visit(left): + check_right = Check(node.right, check_left.placeholders) + if check_right.visit(right): + node = PlaceholderReplace(check_right.placeholders).visit(replace()) + self.update = True + break + return node