Skip to content

Commit 16fc8a7

Browse files
committed
frontend: support simple loops structs in python frontend
Signed-off-by: Asra <[email protected]>
1 parent e9e1248 commit 16fc8a7

File tree

2 files changed

+205
-17
lines changed

2 files changed

+205
-17
lines changed

heir_py/loop.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
"""Example of HEIR Python usage."""
2+
3+
from heir_py import pipeline
4+
5+
6+
def foo(a):
7+
"""An example function with a static loop."""
8+
result = 2
9+
for i in range(3):
10+
result = a + result
11+
return result
12+
13+
14+
# to replace with decorator
15+
_heir_foo = pipeline.run_compiler(foo).module
16+
17+
cc = _heir_foo.foo__generate_crypto_context()
18+
kp = cc.KeyGen()
19+
_heir_foo.foo__configure_crypto_context(cc, kp.secretKey)
20+
arg0_enc = _heir_foo.foo__encrypt__arg0(cc, 2, kp.publicKey)
21+
res_enc = _heir_foo.foo(cc, arg0_enc)
22+
res = _heir_foo.foo__decrypt__result0(cc, res_enc, kp.secretKey)
23+
24+
print(res) # should be 8

heir_py/mlir_emitter.py

Lines changed: 181 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,108 @@
22

33
import operator
44
import textwrap
5+
from dataclasses import dataclass
56

67
from numba.core import ir
78

89

10+
def get_constant(var, ssa_ir):
11+
# Get constant value defining this var, else raise error
12+
assert var.name in ssa_ir._definitions
13+
vardef = ssa_ir._definitions[var.name][0]
14+
if type(vardef) != ir.Const:
15+
raise ValueError("expected constant variable")
16+
return vardef.value
17+
18+
19+
class HeaderInfo:
20+
21+
def __init__(self, header_block):
22+
body = header_block.body
23+
assert len(body) == 5
24+
self.phi_var = body[3].target
25+
self.body_id = body[4].truebr
26+
self.next_id = body[4].falsebr
27+
28+
29+
class RangeArgs:
30+
31+
def __init__(self, range_call, ssa_ir):
32+
args = range_call.value.args
33+
self.stop = get_constant(args[0], ssa_ir)
34+
self.start = 0
35+
self.step = 1
36+
if len(args) > 1:
37+
self.stop = get_constant(args[1], ssa_ir)
38+
if len(args) > 2:
39+
self.step = get_constant(args[2], ssa_ir)
40+
41+
42+
@dataclass
43+
class Loop:
44+
header_id: int
45+
header: HeaderInfo
46+
range: RangeArgs
47+
inits: list[str]
48+
49+
50+
def build_loop_from_call(index, body, blocks, ssa_ir):
51+
# Build a loop from a range call starting at index
52+
assert type(body[index + 1] == ir.Assign)
53+
assert type(body[index + 2] == ir.Assign)
54+
assert type(body[index + 3] == ir.Jump)
55+
56+
header_id = body[index + 3].target
57+
header = HeaderInfo(blocks[header_id])
58+
range_args = RangeArgs(body[index], ssa_ir)
59+
60+
# Loop body must start with assigning the local iter var
61+
loop_body = blocks[header.body_id].body
62+
assert loop_body[0].value == header.phi_var
63+
64+
inits = []
65+
for instr in loop_body[1:]:
66+
if type(instr) == ir.Assign and not instr.target.is_temp:
67+
inits.append(instr.target)
68+
if len(inits) > 1:
69+
raise NotImplementedError("Multiple iter_args not supported")
70+
71+
return Loop(
72+
header_id,
73+
header,
74+
range_args,
75+
inits,
76+
)
77+
78+
79+
def is_range_call(instr, ssa_ir):
80+
# Returns true if the IR instruction is a call to a range
81+
if type(instr) != ir.Assign or type(instr.value) != ir.Expr:
82+
return False
83+
if instr.value.op != "call":
84+
return False
85+
func = instr.value.func
86+
87+
assert func.name in ssa_ir._definitions
88+
func_def_list = ssa_ir._definitions[func.name]
89+
assert len(func_def_list) == 1
90+
func_def = func_def_list[0]
91+
92+
if type(func_def) != ir.Global:
93+
return False
94+
return func_def.name == "range"
95+
96+
997
class TextualMlirEmitter:
1098

1199
def __init__(self, ssa_ir):
12100
self.ssa_ir = ssa_ir
13101
self.temp_var_id = 0
14102
self.numba_names_to_ssa_var_names = {}
15103
self.globals_map = {}
104+
self.loops = {}
105+
self.printed_blocks = {}
106+
self.omit_block_header = {}
16107

17108
def emit(self):
18109
func_name = self.ssa_ir.func_id.func_name
@@ -25,38 +116,60 @@ def emit(self):
25116
# TODO(#1162): get inferred or explicit return types
26117
return_types_str = "i64"
27118

28-
body = self.emit_body()
119+
body = self.emit_blocks()
29120

30121
mlir_func = f"""func.func @{func_name}({args_str}) -> ({return_types_str}) {{
31122
{textwrap.indent(body, ' ')}
32123
}}
33124
"""
34125
return mlir_func
35126

36-
def emit_body(self):
127+
def emit_blocks(self):
37128
blocks = self.ssa_ir.blocks
38-
str_blocks = []
39-
first = True
40129

130+
# collect loops and block header needs
131+
self.omit_block_header[sorted(blocks.items())[0][0]] = True
41132
for block_id, block in sorted(blocks.items()):
42-
instructions = []
43-
for instr in block.body:
44-
result = self.emit_instruction(instr)
45-
if result:
46-
instructions.append(result)
133+
for i in range(len(block.body)):
134+
# Detect a range call
135+
instr = block.body[i]
136+
if is_range_call(instr, self.ssa_ir):
137+
loop = build_loop_from_call(i, block.body, blocks, self.ssa_ir)
138+
self.loops[instr.target] = loop
139+
self.omit_block_header[loop.header.next_id] = True
47140

48-
if first:
49-
first = False
141+
# print blocks
142+
str_blocks = []
143+
for block_id, block in sorted(blocks.items()):
144+
if block_id in self.printed_blocks:
145+
continue
146+
if block_id in self.omit_block_header:
50147
block_header = ""
51148
else:
52149
block_header = f"^bb{block_id}:\n"
53-
54150
str_blocks.append(
55-
block_header + textwrap.indent("\n".join(instructions), " ")
151+
block_header + textwrap.indent(self.emit_block(block), " ")
56152
)
153+
self.printed_blocks[block_id] = True
57154

58155
return "\n".join(str_blocks)
59156

157+
def emit_block(self, block):
158+
instructions = []
159+
for i in range(len(block.body)):
160+
instr = block.body[i]
161+
if type(instr) == ir.Assign and instr.target in self.loops:
162+
# We hit a range call for a loop
163+
result = self.emit_loop(instr.target)
164+
instructions.append(result)
165+
# Exit instructions, should be the end of the block
166+
break
167+
else:
168+
result = self.emit_instruction(instr)
169+
if result:
170+
instructions.append(result)
171+
return "\n".join(instructions)
172+
60173
def emit_instruction(self, instr):
61174
match instr:
62175
case ir.Assign():
@@ -65,6 +178,10 @@ def emit_instruction(self, instr):
65178
return self.emit_branch(instr)
66179
case ir.Return():
67180
return self.emit_return(instr)
181+
case ir.Jump():
182+
# TODO ignore
183+
assert instr.is_terminator
184+
return
68185
raise NotImplementedError("Unsupported instruction: " + str(instr))
69186

70187
def get_or_create_name(self, var):
@@ -125,6 +242,10 @@ def emit_assign(self, assign):
125242
case ir.Global():
126243
self.globals_map[assign.target.name] = assign.value.name
127244
return ""
245+
case ir.Var():
246+
# FIXME: keep track of loop var, and assign
247+
self.forward_name(from_var=assign.target, to_var=assign.value)
248+
return ""
128249
raise NotImplementedError()
129250

130251
def emit_expr(self, expr):
@@ -134,16 +255,16 @@ def emit_expr(self, expr):
134255
raise NotImplementedError()
135256
elif expr.op == "unary":
136257
raise NotImplementedError()
137-
138-
# these are all things numba has hooks for upstream, but we didn't implement
139-
# in the prototype
140-
141258
elif expr.op == "pair_first":
142259
raise NotImplementedError()
143260
elif expr.op == "pair_second":
144261
raise NotImplementedError()
145262
elif expr.op in ("getiter", "iternext"):
146263
raise NotImplementedError()
264+
265+
# these are all things numba has hooks for upstream, but we didn't implement
266+
# in the prototype
267+
147268
elif expr.op == "exhaust_iter":
148269
raise NotImplementedError()
149270
elif expr.op == "getattr":
@@ -191,6 +312,49 @@ def emit_branch(self, branch):
191312
condvar = self.get_name(branch.cond)
192313
return f"cf.cond_br {condvar}, ^bb{branch.truebr}, ^bb{branch.falsebr}"
193314

315+
def emit_loop(self, target):
316+
# Note right now loops that use the %i value directly will need an index_cast to an i64 element.
317+
loop = self.loops[target]
318+
resultvar = self.get_or_create_name(target)
319+
itvar = self.get_or_create_name(loop.header.phi_var) # create var for i
320+
for_str = (
321+
f"affine.for {itvar} = {loop.range.start} to {loop.range.stop} step"
322+
f" {loop.range.step}"
323+
)
324+
325+
if len(loop.inits) == 1:
326+
# Note: we must generalize to inits > 1
327+
init_val = self.get_name(loop.inits[0])
328+
# Within the loop, forward the name for the init val to a new temp var
329+
iter_arg = "itarg"
330+
self.numba_names_to_ssa_var_names[loop.inits[0].name] = iter_arg
331+
for_str = (
332+
f"{resultvar} = {for_str} iter_args(%{iter_arg} = {init_val}) ->"
333+
" (i64)"
334+
)
335+
self.printed_blocks[loop.header_id] = True
336+
337+
# TODO(#1412): support nested loops.
338+
loop_block = self.ssa_ir.blocks[loop.header.body_id]
339+
for instr in loop_block.body:
340+
if type(instr) == ir.Assign and instr.target in self.loops:
341+
raise NotImplementedError("Nested loops are not supported")
342+
343+
body_str = self.emit_block(loop_block)
344+
if len(loop.inits) == 1:
345+
# Yield the iter arg
346+
yield_var = self.get_name(loop.inits[0])
347+
yield_str = f"affine.yield {yield_var} : i64"
348+
body_str += "\n" + yield_str
349+
350+
# After we emit the body, we need to update the printed block map and replace all uses of the iterarg after the block with the result of the for loop.
351+
if loop.inits:
352+
self.forward_name(loop.inits[0], target)
353+
self.printed_blocks[loop.header.body_id] = True
354+
355+
result = for_str + " {\n" + textwrap.indent(body_str, " ") + "\n}"
356+
return result
357+
194358
def emit_return(self, ret):
195359
var = self.get_name(ret.value)
196360
# TODO(#1162): replace i64 with inferred or explicit return type

0 commit comments

Comments
 (0)