Skip to content

Commit 040e26a

Browse files
authored
support annotated assign (#344)
closes #324
1 parent df6efeb commit 040e26a

23 files changed

Lines changed: 384 additions & 110 deletions

File tree

src/kirin/dialects/ilist/interp.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,7 @@ def _range(self, interp, frame: Frame, stmt: Range):
1717

1818
@impl(New)
1919
def new(self, interp, frame: Frame, stmt: New):
20-
elem_type = types.Any
21-
if stmt.values:
22-
elem_type = stmt.values[0].type
23-
for each in stmt.values[1:]:
24-
elem_type = elem_type.join(each.type)
25-
return (IList(list(frame.get_values(stmt.values)), elem=elem_type),)
20+
return (IList(list(frame.get_values(stmt.values)), elem=stmt.elem_type),)
2621

2722
@impl(Len, types.PyClass(IList))
2823
def len(self, interp, frame: Frame, stmt: Len):

src/kirin/dialects/ilist/lowering.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,32 @@ def lower_List(self, state: lowering.State, node: ast.List) -> lowering.Result:
1919
else:
2020
typ = types.Any
2121

22-
return state.current_frame.push(ilist.New(values=tuple(elts)))
22+
return state.current_frame.push(ilist.New(values=tuple(elts), elem_type=typ))
23+
24+
@lowering.akin(ilist.IList)
25+
def lower_Call_IList(
26+
self, state: lowering.State, node: ast.Call
27+
) -> lowering.Result:
28+
if len(node.args) != 1:
29+
raise lowering.BuildError("IList constructor only takes one argument")
30+
value = node.args[0]
31+
if not isinstance(value, ast.List):
32+
raise lowering.BuildError("IList constructor only takes a list")
33+
34+
if len(node.keywords) > 1:
35+
raise lowering.BuildError(
36+
"IList constructor only takes one keyword argument"
37+
)
38+
39+
if len(node.keywords) == 1:
40+
kw = node.keywords[0]
41+
if kw.arg != "elem":
42+
raise lowering.BuildError(
43+
"IList constructor only takes elem keyword argument"
44+
)
45+
elem = self.get_hint(state, kw.value)
46+
elts = tuple(state.lower(each).expect_one() for each in value.elts)
47+
stmt = ilist.New(values=tuple(elts), elem_type=elem)
48+
return state.current_frame.push(stmt)
49+
else:
50+
return self.lower_List(state, value)

src/kirin/dialects/ilist/stmts.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,25 +25,22 @@ class Range(ir.Statement):
2525
class New(ir.Statement):
2626
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
2727
values: tuple[ir.SSAValue, ...] = info.argument(ElemT)
28+
elem_type: types.TypeAttribute = info.attribute()
2829
result: ir.ResultValue = info.result(IListType[ElemT])
2930

3031
def __init__(
3132
self,
3233
values: Sequence[ir.SSAValue],
34+
elem_type: types.TypeAttribute | None = None,
3335
) -> None:
34-
# get elem type
35-
if not values:
36+
if not elem_type:
3637
elem_type = types.Any
37-
else:
38-
elem_type = values[0].type
39-
for v in values:
40-
elem_type = elem_type.join(v.type)
41-
4238
result_type = IListType[elem_type, types.Literal(len(values))]
4339
super().__init__(
4440
args=values,
4541
result_types=(result_type,),
4642
args_slice={"values": slice(0, len(values))},
43+
attributes={"elem_type": elem_type},
4744
)
4845

4946

src/kirin/dialects/lowering/func.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -115,16 +115,3 @@ def assert_simple_arguments(self, node: ast.arguments) -> None:
115115

116116
if node.posonlyargs:
117117
raise lowering.BuildError("positional-only arguments are not supported")
118-
119-
@staticmethod
120-
def get_hint(state: lowering.State, node: ast.expr | None):
121-
if node is None:
122-
return types.Any
123-
124-
try:
125-
t = state.get_global(node).data
126-
return types.hint2type(t)
127-
except Exception as e: # noqa: E722
128-
raise lowering.BuildError(
129-
f"expect a type hint, got {ast.unparse(node)}"
130-
) from e

src/kirin/dialects/lowering/range.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
@py.register
1616
class PyLowering(lowering.FromPythonAST):
1717

18+
@lowering.akin(range)
1819
def lower_Call_range(
1920
self, state: lowering.State, node: ast.Call
2021
) -> lowering.Result:
@@ -24,6 +25,7 @@ def lower_Call_range(
2425
@ilist.register
2526
class IListLowering(lowering.FromPythonAST):
2627

28+
@lowering.akin(range)
2729
def lower_Call_range(
2830
self, state: lowering.State, node: ast.Call
2931
) -> lowering.Result:

src/kirin/dialects/py/assign.py

Lines changed: 96 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,31 @@ class SetItem(ir.Statement):
4949
index: ir.SSAValue = info.argument(print=False)
5050

5151

52+
@statement(dialect=dialect)
53+
class SetAttribute(ir.Statement):
54+
name = "setattr"
55+
traits = frozenset({lowering.FromPythonCall()})
56+
obj: ir.SSAValue = info.argument(print=False)
57+
attr: str = info.attribute()
58+
value: ir.SSAValue = info.argument(print=False)
59+
60+
61+
@statement(dialect=dialect)
62+
class TypeAssert(ir.Statement):
63+
traits = frozenset({lowering.FromPythonCall()})
64+
got: ir.SSAValue = info.argument(print=False)
65+
expected: types.TypeAttribute = info.attribute()
66+
result: ir.ResultValue = info.result()
67+
68+
def __init__(self, got: ir.SSAValue, *, expected: types.TypeAttribute):
69+
super().__init__(
70+
args=(got,),
71+
attributes={"expected": expected},
72+
result_types=(expected,),
73+
args_slice={"got": 0},
74+
)
75+
76+
5277
@dialect.register
5378
class Concrete(interp.MethodTable):
5479

@@ -59,7 +84,37 @@ def alias(self, interp, frame: interp.Frame, stmt: Alias):
5984
@interp.impl(SetItem)
6085
def setindex(self, interp, frame: interp.Frame, stmt: SetItem):
6186
frame.get(stmt.obj)[frame.get(stmt.index)] = frame.get(stmt.value)
62-
return (None,)
87+
88+
@interp.impl(SetAttribute)
89+
def set_attribute(self, interp, frame: interp.Frame, stmt: SetAttribute):
90+
obj = frame.get(stmt.obj)
91+
value = frame.get(stmt.value)
92+
setattr(obj, stmt.attr, value)
93+
94+
# NOTE: we don't do much runtime type checking here, object with generic
95+
# types will unlikely work here.
96+
# TODO: consider runtime type checking by boxing the value
97+
@interp.impl(TypeAssert)
98+
def type_assert(self, interp_, frame: interp.Frame, stmt: TypeAssert):
99+
got = frame.get(stmt.got)
100+
got_type = types.PyClass(type(got))
101+
if not got_type.is_subseteq(stmt.expected):
102+
raise interp.WrapException(
103+
TypeError(f"Expected {stmt.expected}, got {got_type}")
104+
)
105+
return (frame.get(stmt.got),)
106+
107+
108+
@dialect.register(key="typeinfer")
109+
class TypeInfer(interp.MethodTable):
110+
@interp.impl(TypeAssert)
111+
def type_assert(
112+
self, interp_, frame: interp.Frame[types.TypeAttribute], stmt: TypeAssert
113+
):
114+
got = frame.get(stmt.got)
115+
if got.is_subseteq(stmt.expected):
116+
return (got.meet(stmt.expected),)
117+
return (types.Bottom,)
63118

64119

65120
@dialect.register
@@ -79,16 +134,43 @@ def lower_Assign(self, state: lowering.State, node: ast.Assign) -> lowering.Resu
79134
current_frame.defs[lhs_name] = current_frame.push(stmt).result
80135
case _:
81136
for target, value in zip(node.targets, result.data):
82-
match target:
83-
# NOTE: if the name exists new ssa value will be
84-
# used in the future to shadow the old one
85-
case ast.Name(name, ast.Store()):
86-
value.name = name
87-
current_frame.defs[name] = value
88-
case ast.Subscript(obj, slice):
89-
obj = state.lower(obj).expect_one()
90-
slice = state.lower(slice).expect_one()
91-
stmt = SetItem(obj=obj, index=slice, value=value)
92-
current_frame.push(stmt)
93-
case _:
94-
raise lowering.BuildError(f"unsupported target {target}")
137+
self.assign_item(state, target, value)
138+
139+
def lower_AnnAssign(
140+
self, state: lowering.State, node: ast.AnnAssign
141+
) -> lowering.Result:
142+
type_hint = self.get_hint(state, node.annotation)
143+
value = state.lower(node.value).expect_one()
144+
stmt = state.current_frame.push(TypeAssert(got=value, expected=type_hint))
145+
self.assign_item(state, node.target, stmt.result)
146+
147+
def lower_AugAssign(
148+
self, state: lowering.State, node: ast.AugAssign
149+
) -> lowering.Result:
150+
self.assign_item(state, node.target, state.lower(node.value).expect_one())
151+
152+
@staticmethod
153+
def assign_item(state: lowering.State, target, value: ir.SSAValue):
154+
current_frame = state.current_frame
155+
match target:
156+
case ast.Name(name, ast.Store()):
157+
value.name = name
158+
current_frame.defs[name] = value
159+
case ast.Attribute(obj, attr, ast.Store()):
160+
obj = state.lower(obj).expect_one()
161+
stmt = SetAttribute(obj, value, attr=attr)
162+
current_frame.push(stmt)
163+
case ast.Subscript(obj, slice, ast.Store()):
164+
obj = state.lower(obj).expect_one()
165+
slice = state.lower(slice).expect_one()
166+
stmt = SetItem(obj=obj, index=slice, value=value)
167+
current_frame.push(stmt)
168+
case _:
169+
raise lowering.BuildError(f"unsupported target {target}")
170+
171+
@staticmethod
172+
def assert_assign_value_type(value: ir.SSAValue, type_hint: types.TypeAttribute):
173+
value_type = value.type.meet(type_hint)
174+
if value_type is value_type.bottom():
175+
raise lowering.BuildError(f"Cannot assign {value.type} to {type_hint}")
176+
return value_type

src/kirin/dialects/py/attr.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,14 @@ class Lowering(lowering.FromPythonAST):
4040
def lower_Attribute(
4141
self, state: lowering.State, node: ast.Attribute
4242
) -> lowering.Result:
43-
from kirin.dialects.py import Constant
4443

4544
if not isinstance(node.ctx, ast.Load):
4645
raise lowering.BuildError(f"unsupported attribute context {node.ctx}")
4746

4847
# NOTE: eagerly load global variables
4948
value = state.get_global(node, no_raise=True)
5049
if value is not None:
51-
return state.current_frame.push(Constant(value.data))
50+
return state.lower(ast.Constant(value.data)).expect_one()
5251

5352
value = state.lower(node.value).expect_one()
5453
return state.current_frame.push(GetAttr(obj=value, attrname=node.attr))

src/kirin/dialects/py/builtin.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,11 @@ class Sum(ir.Statement):
3939
@dialect.register
4040
class Lowering(lowering.FromPythonAST):
4141

42+
@lowering.akin(abs)
4243
def lower_Call_abs(self, state: lowering.State, node: Call) -> lowering.Result:
4344
return state.current_frame.push(Abs(state.lower(node.args[0]).expect_one()))
4445

46+
@lowering.akin(sum)
4547
def lower_Call_sum(self, state: lowering.State, node: Call) -> lowering.Result:
4648
return state.current_frame.push(Sum(state.lower(node.args[0]).expect_one()))
4749

src/kirin/dialects/py/iterable.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,13 +66,15 @@ def next_(self, interp, frame: interp.Frame, stmt: Next):
6666
@dialect.register
6767
class Lowering(lowering.FromPythonAST):
6868

69+
@lowering.akin(iter)
6970
def lower_Call_iter(self, state: lowering.State, node: Call) -> lowering.Result:
7071
if len(node.args) != 1:
7172
raise lowering.BuildError("iter() takes exactly 1 argument")
7273
return state.current_frame.push(
7374
Iter(state.lower(node.args[0]).expect_one()),
7475
)
7576

77+
@lowering.akin(next)
7678
def lower_Call_next(self, state: lowering.State, node: Call) -> lowering.Result:
7779
if len(node.args) == 2:
7880
raise lowering.BuildError(

src/kirin/dialects/py/len.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,5 +49,6 @@ def len(self, interp, frame: interp.Frame, stmt: Len):
4949
@dialect.register
5050
class Lowering(lowering.FromPythonAST):
5151

52+
@lowering.akin(len)
5253
def lower_Call_len(self, state: lowering.State, node: ast.Call) -> lowering.Result:
5354
return state.current_frame.push(Len(state.lower(node.args[0]).expect_one()))

0 commit comments

Comments
 (0)