Skip to content

Commit f338e17

Browse files
authored
support binding for contextmanager (#346)
this PR supports binding of `with <stmt>` etc. by supporting `@contextmanager` etc. in the `with` lowering transform. removed `br` inside with lowering transform by default. One should manually insert a terminator after the lowering or implement a custom lowering as needed.
1 parent 040e26a commit f338e17

3 files changed

Lines changed: 57 additions & 4 deletions

File tree

src/kirin/lowering/python/lowering.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def visit_Call(self, state: State[ast.AST], node: ast.Call) -> Result:
167167
node.func, state.lineno_offset, state.col_offset
168168
)
169169

170-
global_callee_result = self.lower_global_no_raise(state, node.func)
170+
global_callee_result = state.get_global(node.func, no_raise=True)
171171
if global_callee_result is None:
172172
return self.visit_Call_local(state, node)
173173

@@ -253,6 +253,9 @@ def visit_With(self, state: State[ast.AST], node: ast.With) -> Result:
253253
raise BuildError("expected context expression to be a call")
254254

255255
global_callee = state.get_global(item.context_expr.func).data
256+
if isinstance(global_callee, Binding):
257+
global_callee = global_callee.parent
258+
256259
if not issubclass(global_callee, ir.Statement):
257260
raise BuildError(
258261
f"expected context expression to be a statement, got {global_callee}"

src/kirin/lowering/python/traits.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,9 +242,6 @@ def lower(
242242
raise BuildError(
243243
f"Expected exactly one block in region {region_name}"
244244
)
245-
body_frame.curr_region.blocks[0].stmts.append(
246-
cf.Branch(arguments=(), successor=body_frame.next_block)
247-
)
248245

249246
if len(body_frame.curr_region.blocks) != 1:
250247
raise BuildError(

test/lowering/test_with_binding.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from typing import Any, Generator
2+
from contextlib import contextmanager
3+
4+
from kirin import ir, lowering
5+
from kirin.decl import info, statement
6+
from kirin.prelude import structural_no_opt
7+
from kirin.dialects import ilist
8+
9+
dialect = ir.Dialect("with_binding")
10+
11+
12+
@statement(dialect=dialect)
13+
class ContextStatatement(ir.Statement):
14+
traits = frozenset({lowering.FromPythonWithSingleItem()})
15+
body: ir.Region = info.region(multi=False)
16+
17+
18+
@ir.dialect_group(structural_no_opt.add(dialect))
19+
def dummy(self):
20+
21+
def run_pass(mt):
22+
23+
return mt
24+
25+
return run_pass
26+
27+
28+
@lowering.wraps(ContextStatatement)
29+
@contextmanager
30+
def context_statement() -> Generator[Any, None, None]: ...
31+
32+
33+
@dummy
34+
def with_binding():
35+
x = 1
36+
37+
def fn(x):
38+
return x**2
39+
40+
with context_statement():
41+
with context_statement():
42+
x = ilist.map(fn, ilist.range(10))
43+
44+
return x
45+
46+
47+
def test_with_binding():
48+
stmt = with_binding.callable_region.blocks[0].stmts.at(-2)
49+
assert isinstance(stmt, ContextStatatement)
50+
assert len(stmt.body.blocks) == 1
51+
stmt = stmt.body.blocks[0].stmts.at(0)
52+
assert isinstance(stmt, ContextStatatement)
53+
assert len(stmt.body.blocks[0].stmts) == 5

0 commit comments

Comments
 (0)