Skip to content

Commit

Permalink
frontend: support simple loops structs in python frontend
Browse files Browse the repository at this point in the history
Signed-off-by: Asra <[email protected]>
  • Loading branch information
asraa committed Feb 14, 2025
1 parent e9e1248 commit 16fc8a7
Show file tree
Hide file tree
Showing 2 changed files with 205 additions and 17 deletions.
24 changes: 24 additions & 0 deletions heir_py/loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Example of HEIR Python usage."""

from heir_py import pipeline


def foo(a):
"""An example function with a static loop."""
result = 2
for i in range(3):
result = a + result
return result


# to replace with decorator
_heir_foo = pipeline.run_compiler(foo).module

cc = _heir_foo.foo__generate_crypto_context()
kp = cc.KeyGen()
_heir_foo.foo__configure_crypto_context(cc, kp.secretKey)
arg0_enc = _heir_foo.foo__encrypt__arg0(cc, 2, kp.publicKey)
res_enc = _heir_foo.foo(cc, arg0_enc)
res = _heir_foo.foo__decrypt__result0(cc, res_enc, kp.secretKey)

print(res) # should be 8
198 changes: 181 additions & 17 deletions heir_py/mlir_emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,108 @@

import operator
import textwrap
from dataclasses import dataclass

from numba.core import ir


def get_constant(var, ssa_ir):
# Get constant value defining this var, else raise error
assert var.name in ssa_ir._definitions
vardef = ssa_ir._definitions[var.name][0]
if type(vardef) != ir.Const:
raise ValueError("expected constant variable")
return vardef.value


class HeaderInfo:

def __init__(self, header_block):
body = header_block.body
assert len(body) == 5
self.phi_var = body[3].target
self.body_id = body[4].truebr
self.next_id = body[4].falsebr


class RangeArgs:

def __init__(self, range_call, ssa_ir):
args = range_call.value.args
self.stop = get_constant(args[0], ssa_ir)
self.start = 0
self.step = 1
if len(args) > 1:
self.stop = get_constant(args[1], ssa_ir)
if len(args) > 2:
self.step = get_constant(args[2], ssa_ir)


@dataclass
class Loop:
header_id: int
header: HeaderInfo
range: RangeArgs
inits: list[str]


def build_loop_from_call(index, body, blocks, ssa_ir):
# Build a loop from a range call starting at index
assert type(body[index + 1] == ir.Assign)
assert type(body[index + 2] == ir.Assign)
assert type(body[index + 3] == ir.Jump)

header_id = body[index + 3].target
header = HeaderInfo(blocks[header_id])
range_args = RangeArgs(body[index], ssa_ir)

# Loop body must start with assigning the local iter var
loop_body = blocks[header.body_id].body
assert loop_body[0].value == header.phi_var

inits = []
for instr in loop_body[1:]:
if type(instr) == ir.Assign and not instr.target.is_temp:
inits.append(instr.target)
if len(inits) > 1:
raise NotImplementedError("Multiple iter_args not supported")

return Loop(
header_id,
header,
range_args,
inits,
)


def is_range_call(instr, ssa_ir):
# Returns true if the IR instruction is a call to a range
if type(instr) != ir.Assign or type(instr.value) != ir.Expr:
return False
if instr.value.op != "call":
return False
func = instr.value.func

assert func.name in ssa_ir._definitions
func_def_list = ssa_ir._definitions[func.name]
assert len(func_def_list) == 1
func_def = func_def_list[0]

if type(func_def) != ir.Global:
return False
return func_def.name == "range"


class TextualMlirEmitter:

def __init__(self, ssa_ir):
self.ssa_ir = ssa_ir
self.temp_var_id = 0
self.numba_names_to_ssa_var_names = {}
self.globals_map = {}
self.loops = {}
self.printed_blocks = {}
self.omit_block_header = {}

def emit(self):
func_name = self.ssa_ir.func_id.func_name
Expand All @@ -25,38 +116,60 @@ def emit(self):
# TODO(#1162): get inferred or explicit return types
return_types_str = "i64"

body = self.emit_body()
body = self.emit_blocks()

mlir_func = f"""func.func @{func_name}({args_str}) -> ({return_types_str}) {{
{textwrap.indent(body, ' ')}
}}
"""
return mlir_func

def emit_body(self):
def emit_blocks(self):
blocks = self.ssa_ir.blocks
str_blocks = []
first = True

# collect loops and block header needs
self.omit_block_header[sorted(blocks.items())[0][0]] = True
for block_id, block in sorted(blocks.items()):
instructions = []
for instr in block.body:
result = self.emit_instruction(instr)
if result:
instructions.append(result)
for i in range(len(block.body)):
# Detect a range call
instr = block.body[i]
if is_range_call(instr, self.ssa_ir):
loop = build_loop_from_call(i, block.body, blocks, self.ssa_ir)
self.loops[instr.target] = loop
self.omit_block_header[loop.header.next_id] = True

if first:
first = False
# print blocks
str_blocks = []
for block_id, block in sorted(blocks.items()):
if block_id in self.printed_blocks:
continue
if block_id in self.omit_block_header:
block_header = ""
else:
block_header = f"^bb{block_id}:\n"

str_blocks.append(
block_header + textwrap.indent("\n".join(instructions), " ")
block_header + textwrap.indent(self.emit_block(block), " ")
)
self.printed_blocks[block_id] = True

return "\n".join(str_blocks)

def emit_block(self, block):
instructions = []
for i in range(len(block.body)):
instr = block.body[i]
if type(instr) == ir.Assign and instr.target in self.loops:
# We hit a range call for a loop
result = self.emit_loop(instr.target)
instructions.append(result)
# Exit instructions, should be the end of the block
break
else:
result = self.emit_instruction(instr)
if result:
instructions.append(result)
return "\n".join(instructions)

def emit_instruction(self, instr):
match instr:
case ir.Assign():
Expand All @@ -65,6 +178,10 @@ def emit_instruction(self, instr):
return self.emit_branch(instr)
case ir.Return():
return self.emit_return(instr)
case ir.Jump():
# TODO ignore
assert instr.is_terminator
return
raise NotImplementedError("Unsupported instruction: " + str(instr))

def get_or_create_name(self, var):
Expand Down Expand Up @@ -125,6 +242,10 @@ def emit_assign(self, assign):
case ir.Global():
self.globals_map[assign.target.name] = assign.value.name
return ""
case ir.Var():
# FIXME: keep track of loop var, and assign
self.forward_name(from_var=assign.target, to_var=assign.value)
return ""
raise NotImplementedError()

def emit_expr(self, expr):
Expand All @@ -134,16 +255,16 @@ def emit_expr(self, expr):
raise NotImplementedError()
elif expr.op == "unary":
raise NotImplementedError()

# these are all things numba has hooks for upstream, but we didn't implement
# in the prototype

elif expr.op == "pair_first":
raise NotImplementedError()
elif expr.op == "pair_second":
raise NotImplementedError()
elif expr.op in ("getiter", "iternext"):
raise NotImplementedError()

# these are all things numba has hooks for upstream, but we didn't implement
# in the prototype

elif expr.op == "exhaust_iter":
raise NotImplementedError()
elif expr.op == "getattr":
Expand Down Expand Up @@ -191,6 +312,49 @@ def emit_branch(self, branch):
condvar = self.get_name(branch.cond)
return f"cf.cond_br {condvar}, ^bb{branch.truebr}, ^bb{branch.falsebr}"

def emit_loop(self, target):
# Note right now loops that use the %i value directly will need an index_cast to an i64 element.
loop = self.loops[target]
resultvar = self.get_or_create_name(target)
itvar = self.get_or_create_name(loop.header.phi_var) # create var for i
for_str = (
f"affine.for {itvar} = {loop.range.start} to {loop.range.stop} step"
f" {loop.range.step}"
)

if len(loop.inits) == 1:
# Note: we must generalize to inits > 1
init_val = self.get_name(loop.inits[0])
# Within the loop, forward the name for the init val to a new temp var
iter_arg = "itarg"
self.numba_names_to_ssa_var_names[loop.inits[0].name] = iter_arg
for_str = (
f"{resultvar} = {for_str} iter_args(%{iter_arg} = {init_val}) ->"
" (i64)"
)
self.printed_blocks[loop.header_id] = True

# TODO(#1412): support nested loops.
loop_block = self.ssa_ir.blocks[loop.header.body_id]
for instr in loop_block.body:
if type(instr) == ir.Assign and instr.target in self.loops:
raise NotImplementedError("Nested loops are not supported")

body_str = self.emit_block(loop_block)
if len(loop.inits) == 1:
# Yield the iter arg
yield_var = self.get_name(loop.inits[0])
yield_str = f"affine.yield {yield_var} : i64"
body_str += "\n" + yield_str

# After we emit the body, we need to update the printed block map and replace all uses of the iterarg after the block with the result of the for loop.
if loop.inits:
self.forward_name(loop.inits[0], target)
self.printed_blocks[loop.header.body_id] = True

result = for_str + " {\n" + textwrap.indent(body_str, " ") + "\n}"
return result

def emit_return(self, ret):
var = self.get_name(ret.value)
# TODO(#1162): replace i64 with inferred or explicit return type
Expand Down

0 comments on commit 16fc8a7

Please sign in to comment.