Skip to content

Attempt at fixing copy propagation in LVN pass. #116

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions examples/lvn.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,65 @@ def read_first(instrs):
return read


def rename_nonlocal_duplicates(block):
"""Renames any variables that
(1) share a name with an argument in an `id` op where the argument
is not defined in the current block, and
(2) appear after the `id` op.
This avoids a wrong copy propagation. For example,
1 @main {
2 a: int = const 42;
3 .lbl:
4 b: int = id a;
5 a: int = const 5;
6 print b;
7 }

Here, we replace line 5 with:
_a: int = const 5;
to ensure that the copy propagation leads to the correct value of `a`.
"""
all_vars = set(instr['dest'] for instr in block if 'dest' in instr)

def fresh_id(v):
"""Appends `_` to `v` until it is uniquely named in the local scope."""
while v in all_vars:
v = '_{}'.format(v)
return v

def rename_until_next_assign(old_var, new_var, instrs):
"""Renames all instances of `old_var` to `new_var` until the next
assignment of `old_var`.
"""
for instr in instrs:
if 'args' in instr and old_var in instr['args']:
instr['args'] = [new_var if a == old_var else a for a in instr['args']]
if instr.get('dest') == old_var:
return

# A running set to track which variables have been assigned locally.
current_vars = set()
# Mapping from `id` op of a non-local variable to its line number.
id2line = {}

for index, instr in enumerate(block):
if instr.get('op') == 'id':
id_arg = instr['args'][0]
if id_arg not in current_vars and id_arg != instr['dest']:
# Argument of this `id` op is not defined locally.
id2line[id_arg] = index
if 'dest' not in instr:
continue
dest = instr['dest']
current_vars.add(dest)
if dest in id2line and id2line[dest] < index:
# Local assignment shares a name with
# previous `id` of a non-local variable.
old_var = instr['dest']
instr['dest'] = fresh_id(dest)
rename_until_next_assign(old_var, instr['dest'], block[index + 1:])


def lvn_block(block, lookup, canonicalize, fold):
"""Use local value numbering to optimize a basic block. Modify the
instructions in place.
Expand Down Expand Up @@ -96,6 +155,9 @@ def lvn_block(block, lookup, canonicalize, fold):
# Track constant values for values assigned with `const`.
num2const = {}

# Update names to variables that are used in `id`, yet defined outside the local scope.
rename_nonlocal_duplicates(block)

# Initialize the table with numbers for input variables. These
# variables are their own canonical source.
for var in read_first(block):
Expand Down Expand Up @@ -199,6 +261,9 @@ def _lookup(value2num, value):
'le': lambda a, b: a <= b,
'ne': lambda a, b: a != b,
'eq': lambda a, b: a == b,
'or': lambda a, b: a or b,
'and': lambda a, b: a and b,
'not': lambda a: not a
}


Expand All @@ -212,6 +277,14 @@ def _fold(num2const, value):
# Equivalent arguments may be evaluated for equality.
# E.g. `eq x x`, where `x` is not a constant evaluates to `true`.
return value.op != 'ne'

if value.op in {'and', 'or'} and any(v in num2const for v in value.args):
# Short circuiting the logical operators `and` and `or` for two cases:
# (1) `and x c0` -> false, where `c0` a constant that evaluates to `false`.
# (2) `or x c1` -> true, where `c1` a constant that evaluates to `true`.
const_val = num2const[value.args[0] if value.args[0] in num2const else value.args[1]]
if (value.op == 'and' and not const_val) or (value.op == 'or' and const_val):
return const_val
return None
except ZeroDivisionError: # If we hit a dynamic error, bail!
return None
Expand Down
10 changes: 10 additions & 0 deletions examples/test/lvn/nonlocal-variable-reuse-2.bril
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# CMD: bril2json < {filename} | python3 ../../lvn.py -p | bril2txt
@main {
a: int = const 42;
.lbl:
# While `a` is outside the local scope, we essentially
# have a no-op here, so nonlocal re-naming isn't necessary.
a: int = id a;
a: int = const 5;
print a;
}
7 changes: 7 additions & 0 deletions examples/test/lvn/nonlocal-variable-reuse-2.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
@main {
a: int = const 42;
.lbl:
a: int = id a;
a: int = const 5;
print a;
}
13 changes: 13 additions & 0 deletions examples/test/lvn/nonlocal-variable-reuse-3.bril
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# CMD: bril2json < {filename} | python3 ../../lvn.py -p | bril2txt
@main {
# This should re-name within local scope only with `lvn`.
a: int = const 3;
b: int = id a;
a: int = id b;
print b;
print a;
a: int = const 4;
print a;
a: int = const 5;
print a;
}
11 changes: 11 additions & 0 deletions examples/test/lvn/nonlocal-variable-reuse-3.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
@main {
lvn.0: int = const 3;
b: int = const 3;
a: int = const 3;
print lvn.0;
print lvn.0;
lvn.1: int = const 4;
print lvn.1;
a: int = const 5;
print a;
}
9 changes: 9 additions & 0 deletions examples/test/lvn/nonlocal-variable-reuse.bril
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# CMD: bril2json < {filename} | python3 ../../lvn.py -p | bril2txt
@main {
a: int = const 42;
.lbl:
b: int = id a;
# This should be re-named since `a` is defined outside of the local scope.
a: int = const 5;
print b;
}
7 changes: 7 additions & 0 deletions examples/test/lvn/nonlocal-variable-reuse.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
@main {
a: int = const 42;
.lbl:
b: int = id a;
_a: int = const 5;
print a;
}