Skip to content

Commit df6efeb

Browse files
authored
Refactoring AST lowering transform (#343)
finally decide to just do it... gonna supercede #315 This PR refactors the AST lowering framework to make it more generic. So that other langauge AST can reuse the framework to generate SSA IR (e.g QASM2, etc.). We will also try to clean up the APIs in the process. ### Highlights - new error message of lowering <img width="682" alt="image" src="https://github.com/user-attachments/assets/b5591b7a-a3dc-44fa-b977-538a2049d691" /> - removing the `push_frame` and `pop_frame` in `lowering.State`. Now you can just write ```python with state.frame(stmts) as body_frame: ... ``` The main changes are mostly splitting the APIs into several classes. We still need to improve the precision of error reports here, but we will do that in a few following PRs. cc: @weinbe58
1 parent c26cd62 commit df6efeb

80 files changed

Lines changed: 1987 additions & 1833 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

src/kirin/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# re-exports the public API of the kirin package
22
from kirin import ir
3-
from kirin.decl import info, statement
43

54
from . import types as types
65

7-
__all__ = ["ir", "types", "statement", "info"]
6+
__all__ = ["ir", "types"]

src/kirin/dialects/debug.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,25 @@
22

33
import rich
44

5-
from kirin import ir, decl, types, interp, lowering, exceptions
5+
from kirin import ir, decl, types, interp, lowering
66

77
dialect = ir.Dialect("debug")
88

99

10-
class InfoLowering(ir.FromPythonCall):
10+
class InfoLowering(lowering.FromPythonCall):
1111

1212
def lower(
13-
self, stmt: type, state: lowering.LoweringState, node: ast.Call
13+
self, stmt: type, state: lowering.State, node: ast.Call
1414
) -> lowering.Result:
1515
if len(node.args) == 0:
16-
raise exceptions.DialectLoweringError(
17-
"info() requires at least one argument"
18-
)
16+
raise lowering.BuildError("info() requires at least one argument")
1917

20-
msg = state.visit(node.args[0]).expect_one()
18+
msg = state.lower(node.args[0]).expect_one()
2119
if len(node.args) > 1:
22-
inputs = tuple(state.visit(arg).expect_one() for arg in node.args[1:])
20+
inputs = tuple(state.lower(arg).expect_one() for arg in node.args[1:])
2321
else:
2422
inputs = ()
25-
return lowering.Result(state.append_stmt(Info(msg=msg, inputs=inputs)))
23+
state.current_frame.push(Info(msg=msg, inputs=inputs))
2624

2725

2826
@decl.statement(dialect=dialect)

src/kirin/dialects/ilist/_wrapper.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import typing
22

3-
from kirin.lowering import wraps
3+
from kirin import lowering
44

55
from . import stmts
66
from .runtime import IList
@@ -25,42 +25,42 @@ def range(start: int, stop: int) -> IList[int, typing.Any]: ...
2525
def range(start: int, stop: int, step: int) -> IList[int, typing.Any]: ...
2626

2727

28-
@wraps(stmts.Range)
28+
@lowering.wraps(stmts.Range)
2929
def range(start: int, stop: int, step: int) -> IList[int, typing.Any]: ...
3030

3131

32-
@wraps(stmts.Map)
32+
@lowering.wraps(stmts.Map)
3333
def map(
3434
fn: typing.Callable[[ElemT], OutElemT],
3535
collection: IList[ElemT, LenT] | list[ElemT],
3636
) -> IList[OutElemT, LenT]: ...
3737

3838

39-
@wraps(stmts.Foldr)
39+
@lowering.wraps(stmts.Foldr)
4040
def foldr(
4141
fn: typing.Callable[[ElemT, OutElemT], OutElemT],
4242
collection: IList[ElemT, LenT] | list[ElemT],
4343
init: OutElemT,
4444
) -> OutElemT: ...
4545

4646

47-
@wraps(stmts.Foldl)
47+
@lowering.wraps(stmts.Foldl)
4848
def foldl(
4949
fn: typing.Callable[[OutElemT, ElemT], OutElemT],
5050
collection: IList[ElemT, LenT] | list[ElemT],
5151
init: OutElemT,
5252
) -> OutElemT: ...
5353

5454

55-
@wraps(stmts.Scan)
55+
@lowering.wraps(stmts.Scan)
5656
def scan(
5757
fn: typing.Callable[[OutElemT, ElemT], tuple[OutElemT, ResultT]],
5858
collection: IList[ElemT, LenT] | list[ElemT],
5959
init: OutElemT,
6060
) -> tuple[OutElemT, IList[ResultT, LenT]]: ...
6161

6262

63-
@wraps(stmts.ForEach)
63+
@lowering.wraps(stmts.ForEach)
6464
def for_each(
6565
fn: typing.Callable[[ElemT], typing.Any],
6666
collection: IList[ElemT, LenT] | list[ElemT],
Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
11
import ast
22

3-
from kirin import types
4-
from kirin.lowering import Result, FromPythonAST, LoweringState
3+
from kirin import types, lowering
54

65
from . import stmts as ilist
76
from ._dialect import dialect
87

98

109
@dialect.register
11-
class PythonLowering(FromPythonAST):
10+
class PythonLowering(lowering.FromPythonAST):
1211

13-
def lower_List(self, state: LoweringState, node: ast.List) -> Result:
14-
elts = tuple(state.visit(each).expect_one() for each in node.elts)
12+
def lower_List(self, state: lowering.State, node: ast.List) -> lowering.Result:
13+
elts = tuple(state.lower(each).expect_one() for each in node.elts)
1514

1615
if len(elts):
1716
typ = elts[0].type
@@ -20,6 +19,4 @@ def lower_List(self, state: LoweringState, node: ast.List) -> Result:
2019
else:
2120
typ = types.Any
2221

23-
stmt = ilist.New(values=tuple(elts))
24-
state.append_stmt(stmt)
25-
return Result(stmt)
22+
return state.current_frame.push(ilist.New(values=tuple(elts)))

src/kirin/dialects/ilist/stmts.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Sequence
22

3-
from kirin import ir, types
3+
from kirin import ir, types, lowering
44
from kirin.decl import info, statement
55

66
from .runtime import IList
@@ -14,7 +14,7 @@
1414
@statement(dialect=dialect)
1515
class Range(ir.Statement):
1616
name = "range"
17-
traits = frozenset({ir.Pure(), ir.FromPythonRangeLike()})
17+
traits = frozenset({ir.Pure(), lowering.FromPythonRangeLike()})
1818
start: ir.SSAValue = info.argument(types.Int)
1919
stop: ir.SSAValue = info.argument(types.Int)
2020
step: ir.SSAValue = info.argument(types.Int)
@@ -23,7 +23,7 @@ class Range(ir.Statement):
2323

2424
@statement(dialect=dialect, init=False)
2525
class New(ir.Statement):
26-
traits = frozenset({ir.Pure(), ir.FromPythonCall()})
26+
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
2727
values: tuple[ir.SSAValue, ...] = info.argument(ElemT)
2828
result: ir.ResultValue = info.result(IListType[ElemT])
2929

@@ -49,7 +49,7 @@ def __init__(
4949

5050
@statement(dialect=dialect)
5151
class Push(ir.Statement):
52-
traits = frozenset({ir.Pure(), ir.FromPythonCall()})
52+
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
5353
lst: ir.SSAValue = info.argument(IListType[ElemT])
5454
value: ir.SSAValue = info.argument(IListType[ElemT])
5555
result: ir.ResultValue = info.result(IListType[ElemT])
@@ -60,7 +60,7 @@ class Push(ir.Statement):
6060

6161
@statement(dialect=dialect)
6262
class Map(ir.Statement):
63-
traits = frozenset({ir.MaybePure(), ir.FromPythonCall()})
63+
traits = frozenset({ir.MaybePure(), lowering.FromPythonCall()})
6464
purity: bool = info.attribute(default=False)
6565
fn: ir.SSAValue = info.argument(types.MethodType[[ElemT], OutElemT])
6666
collection: ir.SSAValue = info.argument(IListType[ElemT, ListLen])
@@ -69,7 +69,7 @@ class Map(ir.Statement):
6969

7070
@statement(dialect=dialect)
7171
class Foldr(ir.Statement):
72-
traits = frozenset({ir.MaybePure(), ir.FromPythonCall()})
72+
traits = frozenset({ir.MaybePure(), lowering.FromPythonCall()})
7373
purity: bool = info.attribute(default=False)
7474
fn: ir.SSAValue = info.argument(
7575
types.Generic(ir.Method, [ElemT, OutElemT], OutElemT)
@@ -81,7 +81,7 @@ class Foldr(ir.Statement):
8181

8282
@statement(dialect=dialect)
8383
class Foldl(ir.Statement):
84-
traits = frozenset({ir.MaybePure(), ir.FromPythonCall()})
84+
traits = frozenset({ir.MaybePure(), lowering.FromPythonCall()})
8585
purity: bool = info.attribute(default=False)
8686
fn: ir.SSAValue = info.argument(
8787
types.Generic(ir.Method, [OutElemT, ElemT], OutElemT)
@@ -97,7 +97,7 @@ class Foldl(ir.Statement):
9797

9898
@statement(dialect=dialect)
9999
class Scan(ir.Statement):
100-
traits = frozenset({ir.MaybePure(), ir.FromPythonCall()})
100+
traits = frozenset({ir.MaybePure(), lowering.FromPythonCall()})
101101
purity: bool = info.attribute(default=False)
102102
fn: ir.SSAValue = info.argument(
103103
types.Generic(ir.Method, [OutElemT, ElemT], types.Tuple[OutElemT, ResultT])
@@ -111,7 +111,7 @@ class Scan(ir.Statement):
111111

112112
@statement(dialect=dialect)
113113
class ForEach(ir.Statement):
114-
traits = frozenset({ir.MaybePure(), ir.FromPythonCall()})
114+
traits = frozenset({ir.MaybePure(), lowering.FromPythonCall()})
115115
purity: bool = info.attribute(default=False)
116116
fn: ir.SSAValue = info.argument(types.Generic(ir.Method, [ElemT], types.NoneType))
117117
collection: ir.SSAValue = info.argument(IListType[ElemT])

src/kirin/dialects/lowering/call.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from kirin import ir, types, lowering
44
from kirin.dialects import func
5-
from kirin.exceptions import DialectLoweringError
65

76
dialect = ir.Dialect("lowering.call")
87

@@ -11,38 +10,38 @@
1110
class Lowering(lowering.FromPythonAST):
1211

1312
def lower_Call_local(
14-
self, state: lowering.LoweringState, callee: ir.SSAValue, node: ast.Call
13+
self, state: lowering.State, callee: ir.SSAValue, node: ast.Call
1514
) -> lowering.Result:
1615
args, keywords = self.__lower_Call_args_kwargs(state, node)
1716
stmt = func.Call(callee, args, kwargs=keywords)
18-
return lowering.Result(state.append_stmt(stmt))
17+
return state.current_frame.push(stmt)
1918

2019
def lower_Call_global_method(
2120
self,
22-
state: lowering.LoweringState,
21+
state: lowering.State,
2322
method: ir.Method,
2423
node: ast.Call,
2524
) -> lowering.Result:
2625
args, keywords = self.__lower_Call_args_kwargs(state, node)
2726
stmt = func.Invoke(args, callee=method, kwargs=keywords)
2827
stmt.result.type = method.return_type or types.Any
29-
return lowering.Result(state.append_stmt(stmt))
28+
return state.current_frame.push(stmt)
3029

3130
def __lower_Call_args_kwargs(
3231
self,
33-
state: lowering.LoweringState,
32+
state: lowering.State,
3433
node: ast.Call,
3534
):
3635
args: list[ir.SSAValue] = []
3736
for arg in node.args:
3837
if isinstance(arg, ast.Starred): # TODO: support *args
39-
raise DialectLoweringError("starred arguments are not supported")
38+
raise lowering.BuildError("starred arguments are not supported")
4039
else:
41-
args.append(state.visit(arg).expect_one())
40+
args.append(state.lower(arg).expect_one())
4241

4342
keywords = []
4443
for kw in node.keywords:
4544
keywords.append(kw.arg)
46-
args.append(state.visit(kw.value).expect_one())
45+
args.append(state.lower(kw.value).expect_one())
4746

4847
return tuple(args), tuple(keywords)

0 commit comments

Comments
 (0)