diff --git a/heir_py/loop.py b/heir_py/loop.py new file mode 100644 index 000000000..3910c9050 --- /dev/null +++ b/heir_py/loop.py @@ -0,0 +1,24 @@ +"""Example of HEIR Python usage.""" + +from heir_py import pipeline + + +def foo(a): + """An example function with a static loop.""" + result = 2 + for i in range(3): + result = a + result + return result + + +# to replace with decorator +_heir_foo = pipeline.run_compiler(foo).module + +cc = _heir_foo.foo__generate_crypto_context() +kp = cc.KeyGen() +_heir_foo.foo__configure_crypto_context(cc, kp.secretKey) +arg0_enc = _heir_foo.foo__encrypt__arg0(cc, 2, kp.publicKey) +res_enc = _heir_foo.foo(cc, arg0_enc) +res = _heir_foo.foo__decrypt__result0(cc, res_enc, kp.secretKey) + +print(res) # should be 8 diff --git a/heir_py/mlir_emitter.py b/heir_py/mlir_emitter.py index 03324e96e..dec2ade50 100644 --- a/heir_py/mlir_emitter.py +++ b/heir_py/mlir_emitter.py @@ -2,10 +2,98 @@ import operator import textwrap +from dataclasses import dataclass from numba.core import ir +def get_constant(var, ssa_ir): + # Get constant value defining this var, else raise error + assert var.name in ssa_ir._definitions + vardef = ssa_ir._definitions[var.name][0] + if type(vardef) != ir.Const: + raise ValueError("expected constant variable") + return vardef.value + + +class HeaderInfo: + + def __init__(self, header_block): + body = header_block.body + assert len(body) == 5 + self.phi_var = body[3].target + self.body_id = body[4].truebr + self.next_id = body[4].falsebr + + +class RangeArgs: + + def __init__(self, range_call, ssa_ir): + args = range_call.value.args + self.stop = get_constant(args[0], ssa_ir) + self.start = 0 + self.step = 1 + if len(args) > 1: + self.stop = get_constant(args[1], ssa_ir) + if len(args) > 2: + self.step = get_constant(args[2], ssa_ir) + + +@dataclass +class Loop: + header_id: int + header: HeaderInfo + range: RangeArgs + inits: list[str] + + +def build_loop_from_call(index, body, blocks, ssa_ir): + # Build a loop from a range call starting at index + assert type(body[index + 1] == ir.Assign) + assert type(body[index + 2] == ir.Assign) + assert type(body[index + 3] == ir.Jump) + + header_id = body[index + 3].target + header = HeaderInfo(blocks[header_id]) + range_args = RangeArgs(body[index], ssa_ir) + + # Loop body must start with assigning the local iter var + loop_body = blocks[header.body_id].body + assert loop_body[0].value == header.phi_var + + inits = [] + for instr in loop_body[1:]: + if type(instr) == ir.Assign and not instr.target.is_temp: + inits.append(instr.target) + if len(inits) > 1: + raise NotImplementedError("Multiple iter_args not supported") + + return Loop( + header_id, + header, + range_args, + inits, + ) + + +def is_range_call(instr, ssa_ir): + # Returns true if the IR instruction is a call to a range + if type(instr) != ir.Assign or type(instr.value) != ir.Expr: + return False + if instr.value.op != "call": + return False + func = instr.value.func + + assert func.name in ssa_ir._definitions + func_def_list = ssa_ir._definitions[func.name] + assert len(func_def_list) == 1 + func_def = func_def_list[0] + + if type(func_def) != ir.Global: + return False + return func_def.name == "range" + + class TextualMlirEmitter: def __init__(self, ssa_ir): @@ -13,6 +101,9 @@ def __init__(self, ssa_ir): self.temp_var_id = 0 self.numba_names_to_ssa_var_names = {} self.globals_map = {} + self.loops = {} + self.printed_blocks = {} + self.omit_block_header = {} def emit(self): func_name = self.ssa_ir.func_id.func_name @@ -25,7 +116,7 @@ def emit(self): # TODO(#1162): get inferred or explicit return types return_types_str = "i64" - body = self.emit_body() + body = self.emit_blocks() mlir_func = f"""func.func @{func_name}({args_str}) -> ({return_types_str}) {{ {textwrap.indent(body, ' ')} @@ -33,30 +124,52 @@ def emit(self): """ return mlir_func - def emit_body(self): + def emit_blocks(self): blocks = self.ssa_ir.blocks - str_blocks = [] - first = True + # collect loops and block header needs + self.omit_block_header[sorted(blocks.items())[0][0]] = True for block_id, block in sorted(blocks.items()): - instructions = [] - for instr in block.body: - result = self.emit_instruction(instr) - if result: - instructions.append(result) + for i in range(len(block.body)): + # Detect a range call + instr = block.body[i] + if is_range_call(instr, self.ssa_ir): + loop = build_loop_from_call(i, block.body, blocks, self.ssa_ir) + self.loops[instr.target] = loop + self.omit_block_header[loop.header.next_id] = True - if first: - first = False + # print blocks + str_blocks = [] + for block_id, block in sorted(blocks.items()): + if block_id in self.printed_blocks: + continue + if block_id in self.omit_block_header: block_header = "" else: block_header = f"^bb{block_id}:\n" - str_blocks.append( - block_header + textwrap.indent("\n".join(instructions), " ") + block_header + textwrap.indent(self.emit_block(block), " ") ) + self.printed_blocks[block_id] = True return "\n".join(str_blocks) + def emit_block(self, block): + instructions = [] + for i in range(len(block.body)): + instr = block.body[i] + if type(instr) == ir.Assign and instr.target in self.loops: + # We hit a range call for a loop + result = self.emit_loop(instr.target) + instructions.append(result) + # Exit instructions, should be the end of the block + break + else: + result = self.emit_instruction(instr) + if result: + instructions.append(result) + return "\n".join(instructions) + def emit_instruction(self, instr): match instr: case ir.Assign(): @@ -65,6 +178,10 @@ def emit_instruction(self, instr): return self.emit_branch(instr) case ir.Return(): return self.emit_return(instr) + case ir.Jump(): + # TODO ignore + assert instr.is_terminator + return raise NotImplementedError("Unsupported instruction: " + str(instr)) def get_or_create_name(self, var): @@ -125,6 +242,10 @@ def emit_assign(self, assign): case ir.Global(): self.globals_map[assign.target.name] = assign.value.name return "" + case ir.Var(): + # FIXME: keep track of loop var, and assign + self.forward_name(from_var=assign.target, to_var=assign.value) + return "" raise NotImplementedError() def emit_expr(self, expr): @@ -134,16 +255,16 @@ def emit_expr(self, expr): raise NotImplementedError() elif expr.op == "unary": raise NotImplementedError() - - # these are all things numba has hooks for upstream, but we didn't implement - # in the prototype - elif expr.op == "pair_first": raise NotImplementedError() elif expr.op == "pair_second": raise NotImplementedError() elif expr.op in ("getiter", "iternext"): raise NotImplementedError() + + # these are all things numba has hooks for upstream, but we didn't implement + # in the prototype + elif expr.op == "exhaust_iter": raise NotImplementedError() elif expr.op == "getattr": @@ -191,6 +312,49 @@ def emit_branch(self, branch): condvar = self.get_name(branch.cond) return f"cf.cond_br {condvar}, ^bb{branch.truebr}, ^bb{branch.falsebr}" + def emit_loop(self, target): + # Note right now loops that use the %i value directly will need an index_cast to an i64 element. + loop = self.loops[target] + resultvar = self.get_or_create_name(target) + itvar = self.get_or_create_name(loop.header.phi_var) # create var for i + for_str = ( + f"affine.for {itvar} = {loop.range.start} to {loop.range.stop} step" + f" {loop.range.step}" + ) + + if len(loop.inits) == 1: + # Note: we must generalize to inits > 1 + init_val = self.get_name(loop.inits[0]) + # Within the loop, forward the name for the init val to a new temp var + iter_arg = "itarg" + self.numba_names_to_ssa_var_names[loop.inits[0].name] = iter_arg + for_str = ( + f"{resultvar} = {for_str} iter_args(%{iter_arg} = {init_val}) ->" + " (i64)" + ) + self.printed_blocks[loop.header_id] = True + + # TODO(#1412): support nested loops. + loop_block = self.ssa_ir.blocks[loop.header.body_id] + for instr in loop_block.body: + if type(instr) == ir.Assign and instr.target in self.loops: + raise NotImplementedError("Nested loops are not supported") + + body_str = self.emit_block(loop_block) + if len(loop.inits) == 1: + # Yield the iter arg + yield_var = self.get_name(loop.inits[0]) + yield_str = f"affine.yield {yield_var} : i64" + body_str += "\n" + yield_str + + # After we emit the body, we need to update the printed block map and replace all uses of the iterarg after the block with the result of the for loop. + if loop.inits: + self.forward_name(loop.inits[0], target) + self.printed_blocks[loop.header.body_id] = True + + result = for_str + " {\n" + textwrap.indent(body_str, " ") + "\n}" + return result + def emit_return(self, ret): var = self.get_name(ret.value) # TODO(#1162): replace i64 with inferred or explicit return type